Commit 0045d0b7 authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Improve concat gather (#402)

* improve gather implementation to handle negative input indices

* clang format

* clang format

* improve concat to support neg axis input

* clang format

* fix cppcheck error

* clang format

* code cleanup

* clang format

* fix review comments

* clang format
parent d10628ee
......@@ -33,7 +33,9 @@ void eliminate_concat::apply(program& p) const
// we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
if(concat_op.axis == 0 ||
std::size_t axis_index =
(concat_op.axis < 0) ? (concat_op.axis + lens.size()) : concat_op.axis;
if(axis_index == 0 ||
std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
......
......@@ -18,7 +18,7 @@ namespace op {
struct concat
{
std::size_t axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -30,13 +30,15 @@ struct concat
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const
{
auto n_dims = args[0].get_shape().lens().size();
std::size_t axis_index = (axis < 0) ? axis + n_dims : axis;
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[axis] = 0;
std::vector<std::size_t> offset(n_dims, 0);
offset[axis_index] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[axis] += arg.get_shape().lens()[axis];
offset[axis_index] += arg.get_shape().lens()[axis_index];
}
return offsets;
}
......@@ -44,20 +46,21 @@ struct concat
{
if(inputs.empty())
{
MIGRAPHX_THROW("Number of input tensors should exceed 0");
MIGRAPHX_THROW("CONCAT: Number of input tensors should exceed 0");
}
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
std::size_t axis_index = (axis < 0) ? (first_shape_lens.size() + axis) : axis;
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{
if(l != axis)
if(l != axis_index)
{
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l];
}))
{
MIGRAPHX_THROW("Non-axis dimensions should match");
MIGRAPHX_THROW("CONCAT: Non-axis dimensions should match");
}
}
}
......@@ -65,11 +68,11 @@ struct concat
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
new_dim_axis += lens[axis_index];
}
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis;
new_lens[axis_index] = new_dim_axis;
return {type, new_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
......
......@@ -62,15 +62,18 @@ struct gather
{
argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index =
(axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;
auto lens = args[0].get_shape().lens();
int axis_index = (axis < 0) ? static_cast<int>(lens.size() + axis) : axis;
std::size_t axis_dim_size = lens[axis_index];
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(output_shape.scalar())
{
output[0] = data[indices.front()];
auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
output[0] = data[indices.front()];
}
else
{
......@@ -79,7 +82,9 @@ struct gather
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx[axis_index] = indices[data_idx[axis_index]];
auto in_index = indices[data_idx[axis_index]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis_index] = in_index;
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end());
});
......
......@@ -491,7 +491,13 @@ struct onnx_parser
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::size_t axis = parse_value(attributes.at("axis")).at<int>();
// change to hande axis to be negative values
if(!contains(attributes, "axis"))
{
MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
}
int axis = parse_value(attributes.at("axis")).at<int>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
......
......@@ -12,10 +12,11 @@ namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
{
auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
lens[axis_index] = arg2.get_shape().elements();
auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
auto axis_dim_size = lens[axis_index];
lens[axis_index] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements();
......@@ -26,7 +27,9 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements, 256)([=](auto i) {
auto idx = out_comp.multi(i);
idx[axis_index] = indices_ptr[idx[axis_index]];
auto in_index = indices_ptr[idx[axis_index]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis_index] = in_index;
output_ptr[i] = input[idx];
});
});
......
......@@ -357,7 +357,7 @@ struct tf_parser
{
// get index for axis within args
size_t axis_idx = attributes.at("N").i();
size_t axis = args[axis_idx]->eval().at<int64_t>();
int64_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis};
// return only first N arguments (assuming last index is the axis value)
return prog.add_instruction(
......@@ -664,8 +664,7 @@ struct tf_parser
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(
prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
return to_nhwc(prog.add_instruction(op::concat{axis}, unsqueezed_args));
}
instruction_ref
......
......@@ -59,7 +59,7 @@ TEST_CASE(concat_test)
{
{
migraphx::program p;
std::size_t axis = 1;
int axis = 1;
std::vector<int> data0 = {0, 1, 5, 6};
std::vector<int> data1 = {2, 3, 4, 7, 8, 9};
std::vector<int> data2 = {10, 20};
......@@ -80,9 +80,58 @@ TEST_CASE(concat_test)
EXPECT(
migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1})));
}
{
migraphx::program p;
int axis = -1;
std::vector<int> data0 = {0, 1, 5, 6};
std::vector<int> data1 = {2, 3, 4, 7, 8, 9};
std::vector<int> data2 = {10, 20};
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_literal(migraphx::literal{s0, data0});
auto l1 = p.add_literal(migraphx::literal{s1, data1});
auto l2 = p.add_literal(migraphx::literal{s2, data2});
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20};
std::vector<int> results_vector(2 * 6);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<std::size_t>({2, 6})));
EXPECT(
migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1})));
}
{
migraphx::program p;
int axis = 0;
std::vector<int> data0 = {0, 1, 2, 3};
std::vector<int> data1 = {4, 5, 6, 7, 8, 9};
std::vector<int> data2 = {10, 11};
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
auto l0 = p.add_literal(migraphx::literal{s0, data0});
auto l1 = p.add_literal(migraphx::literal{s1, data1});
auto l2 = p.add_literal(migraphx::literal{s2, data2});
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(6 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<std::size_t>({6, 2})));
EXPECT(
migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({2, 1})));
}
{
migraphx::program p;
std::size_t axis = 0;
int axis = -2;
std::vector<int> data0 = {0, 1, 2, 3};
std::vector<int> data1 = {4, 5, 6, 7, 8, 9};
std::vector<int> data2 = {10, 11};
......@@ -127,6 +176,26 @@ TEST_CASE(gather_test)
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{-3, -1};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
......@@ -188,6 +257,27 @@ TEST_CASE(gather_test)
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{-3};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f, 3.5f, 6.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
......
......@@ -1941,12 +1941,12 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
}
};
struct test_concat : verify_program<test_concat>
struct test_concat_axis_1 : verify_program<test_concat_axis_1>
{
migraphx::program create_program() const
{
migraphx::program p;
std::size_t axis = 1;
int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
......@@ -1958,12 +1958,29 @@ struct test_concat : verify_program<test_concat>
}
};
struct test_concat2 : verify_program<test_concat2>
struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1>
{
migraphx::program create_program() const
{
migraphx::program p;
std::size_t axis = 0;
int axis = -1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p;
}
};
struct test_concat_axis_0 : verify_program<test_concat_axis_0>
{
migraphx::program create_program() const
{
migraphx::program p;
int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
......@@ -1980,7 +1997,7 @@ struct test_concat_transpose : verify_program<test_concat_transpose>
migraphx::program create_program() const
{
migraphx::program p;
std::size_t axis = 1;
int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
......@@ -1998,7 +2015,7 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2>
migraphx::program create_program() const
{
migraphx::program p;
std::size_t axis = 1;
int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
......@@ -2016,7 +2033,7 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3>
migraphx::program create_program() const
{
migraphx::program p;
std::size_t axis = 1;
int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
......@@ -2035,7 +2052,7 @@ struct test_concat_relu : verify_program<test_concat_relu>
migraphx::program create_program() const
{
migraphx::program p;
std::size_t axis = 0;
int axis = 0;
migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
migraphx::shape s1{migraphx::shape::float_type, {3, 2}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2}};
......@@ -2134,6 +2151,22 @@ struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
}
};
struct test_gather_neg_indices : verify_program<test_gather_neg_indices>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{-2, -1, -1, -2};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
{
migraphx::program create_program() const
......@@ -2202,7 +2235,7 @@ void manual_identity()
void manual_test_concat_relu()
{
migraphx::program p;
std::size_t axis = 0;
int axis = 0;
std::vector<float> data0 = {0, 1, 2, 3};
std::vector<float> data1 = {4, 5, 6, 7, 8, 9};
std::vector<float> data2 = {10, 11};
......
......@@ -138,7 +138,7 @@ TEST_CASE(concat_test)
// add the literal using a vector in order to set stride to 1 (like in tf parser)
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
p.add_instruction(migraphx::op::concat{axis}, l0, l1);
auto prog = optimize_tf("concat_test.pb", false);
EXPECT(p == prog);
......@@ -341,7 +341,7 @@ TEST_CASE(pack_test)
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
p.add_instruction(migraphx::op::concat{static_cast<int>(axis)}, unsqueezed_args);
auto prog = optimize_tf("pack_test.pb", false);
EXPECT(p == prog);
......@@ -366,7 +366,7 @@ TEST_CASE(pack_test_nhwc)
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
p.add_instruction(migraphx::op::concat{static_cast<int>(nchw_axis)}, unsqueezed_args);
auto prog = optimize_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment