Commit 0045d0b7 authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Improve concat gather (#402)

* improve gather implementation to handle negative input indices

* clang format

* clang format

* improve concat to support neg axis input

* clang format

* fix cppcheck error

* clang format

* code cleanup

* clang format

* fix review comments

* clang format
parent d10628ee
...@@ -33,7 +33,9 @@ void eliminate_concat::apply(program& p) const ...@@ -33,7 +33,9 @@ void eliminate_concat::apply(program& p) const
// we only need to check the first input // we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens(); auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator()); auto concat_op = concat_opt.get_concat(ins->get_operator());
if(concat_op.axis == 0 || std::size_t axis_index =
(concat_op.axis < 0) ? (concat_op.axis + lens.size()) : concat_op.axis;
if(axis_index == 0 ||
std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; })) std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; }))
{ {
// Last input should be an allocation // Last input should be an allocation
......
...@@ -18,7 +18,7 @@ namespace op { ...@@ -18,7 +18,7 @@ namespace op {
struct concat struct concat
{ {
std::size_t axis = 0; int64_t axis = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -30,13 +30,15 @@ struct concat ...@@ -30,13 +30,15 @@ struct concat
std::vector<std::size_t> compute_offsets(const shape& output_shape, std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto n_dims = args[0].get_shape().lens().size();
std::size_t axis_index = (axis < 0) ? axis + n_dims : axis;
std::vector<std::size_t> offsets; std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0); std::vector<std::size_t> offset(n_dims, 0);
offset[axis] = 0; offset[axis_index] = 0;
for(const auto& arg : args) for(const auto& arg : args)
{ {
offsets.push_back(output_shape.index(offset)); offsets.push_back(output_shape.index(offset));
offset[axis] += arg.get_shape().lens()[axis]; offset[axis_index] += arg.get_shape().lens()[axis_index];
} }
return offsets; return offsets;
} }
...@@ -44,20 +46,21 @@ struct concat ...@@ -44,20 +46,21 @@ struct concat
{ {
if(inputs.empty()) if(inputs.empty())
{ {
MIGRAPHX_THROW("Number of input tensors should exceed 0"); MIGRAPHX_THROW("CONCAT: Number of input tensors should exceed 0");
} }
const auto& first_shape_lens = inputs.front().lens(); const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type(); const auto& type = inputs.front().type();
std::size_t axis_index = (axis < 0) ? (first_shape_lens.size() + axis) : axis;
for(std::size_t l = 0; l < first_shape_lens.size(); l++) for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{ {
if(l != axis) if(l != axis_index)
{ {
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) { if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l]; return s.lens()[l] == first_shape_lens[l];
})) }))
{ {
MIGRAPHX_THROW("Non-axis dimensions should match"); MIGRAPHX_THROW("CONCAT: Non-axis dimensions should match");
} }
} }
} }
...@@ -65,11 +68,11 @@ struct concat ...@@ -65,11 +68,11 @@ struct concat
for(const auto& input : inputs) for(const auto& input : inputs)
{ {
const auto& lens = input.lens(); const auto& lens = input.lens();
new_dim_axis += lens[axis]; new_dim_axis += lens[axis_index];
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens)); std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis; new_lens[axis_index] = new_dim_axis;
return {type, new_lens}; return {type, new_lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
......
...@@ -62,15 +62,18 @@ struct gather ...@@ -62,15 +62,18 @@ struct gather
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = auto lens = args[0].get_shape().lens();
(axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis; int axis_index = (axis < 0) ? static_cast<int>(lens.size() + axis) : axis;
std::size_t axis_dim_size = lens[axis_index];
// max dimension in axis // max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
if(output_shape.scalar()) if(output_shape.scalar())
{ {
output[0] = data[indices.front()]; auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
output[0] = data[indices.front()];
} }
else else
{ {
...@@ -79,7 +82,9 @@ struct gather ...@@ -79,7 +82,9 @@ struct gather
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens}; migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) { shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx; auto data_idx = out_idx;
data_idx[axis_index] = indices[data_idx[axis_index]]; auto in_index = indices[data_idx[axis_index]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis_index] = in_index;
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] = output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end()); data(data_idx.begin(), data_idx.end());
}); });
......
...@@ -491,7 +491,13 @@ struct onnx_parser ...@@ -491,7 +491,13 @@ struct onnx_parser
instruction_ref instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
std::size_t axis = parse_value(attributes.at("axis")).at<int>(); // change to hande axis to be negative values
if(!contains(attributes, "axis"))
{
MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
}
int axis = parse_value(attributes.at("axis")).at<int>();
op::concat op{axis}; op::concat op{axis};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
......
...@@ -12,10 +12,11 @@ namespace device { ...@@ -12,10 +12,11 @@ namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis) argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
{ {
auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis; auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
auto& input_shape = arg1.get_shape(); auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens(); auto lens = input_shape.lens();
lens[axis_index] = arg2.get_shape().elements(); auto axis_dim_size = lens[axis_index];
lens[axis_index] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens}; shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements(); std::size_t nelements = result.get_shape().elements();
...@@ -26,7 +27,9 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg ...@@ -26,7 +27,9 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements, 256)([=](auto i) { gs_launch(stream, nelements, 256)([=](auto i) {
auto idx = out_comp.multi(i); auto idx = out_comp.multi(i);
idx[axis_index] = indices_ptr[idx[axis_index]]; auto in_index = indices_ptr[idx[axis_index]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis_index] = in_index;
output_ptr[i] = input[idx]; output_ptr[i] = input[idx];
}); });
}); });
......
...@@ -357,7 +357,7 @@ struct tf_parser ...@@ -357,7 +357,7 @@ struct tf_parser
{ {
// get index for axis within args // get index for axis within args
size_t axis_idx = attributes.at("N").i(); size_t axis_idx = attributes.at("N").i();
size_t axis = args[axis_idx]->eval().at<int64_t>(); int64_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis}; op::concat op{axis};
// return only first N arguments (assuming last index is the axis value) // return only first N arguments (assuming last index is the axis value)
return prog.add_instruction( return prog.add_instruction(
...@@ -664,8 +664,7 @@ struct tf_parser ...@@ -664,8 +664,7 @@ struct tf_parser
args.end(), args.end(),
std::back_inserter(unsqueezed_args), std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); }); [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc( return to_nhwc(prog.add_instruction(op::concat{axis}, unsqueezed_args));
prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
} }
instruction_ref instruction_ref
......
...@@ -59,7 +59,7 @@ TEST_CASE(concat_test) ...@@ -59,7 +59,7 @@ TEST_CASE(concat_test)
{ {
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 1; int axis = 1;
std::vector<int> data0 = {0, 1, 5, 6}; std::vector<int> data0 = {0, 1, 5, 6};
std::vector<int> data1 = {2, 3, 4, 7, 8, 9}; std::vector<int> data1 = {2, 3, 4, 7, 8, 9};
std::vector<int> data2 = {10, 20}; std::vector<int> data2 = {10, 20};
...@@ -80,9 +80,58 @@ TEST_CASE(concat_test) ...@@ -80,9 +80,58 @@ TEST_CASE(concat_test)
EXPECT( EXPECT(
migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1}))); migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1})));
} }
{
migraphx::program p;
int axis = -1;
std::vector<int> data0 = {0, 1, 5, 6};
std::vector<int> data1 = {2, 3, 4, 7, 8, 9};
std::vector<int> data2 = {10, 20};
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_literal(migraphx::literal{s0, data0});
auto l1 = p.add_literal(migraphx::literal{s1, data1});
auto l2 = p.add_literal(migraphx::literal{s2, data2});
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20};
std::vector<int> results_vector(2 * 6);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<std::size_t>({2, 6})));
EXPECT(
migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1})));
}
{
migraphx::program p;
int axis = 0;
std::vector<int> data0 = {0, 1, 2, 3};
std::vector<int> data1 = {4, 5, 6, 7, 8, 9};
std::vector<int> data2 = {10, 11};
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
auto l0 = p.add_literal(migraphx::literal{s0, data0});
auto l1 = p.add_literal(migraphx::literal{s1, data1});
auto l2 = p.add_literal(migraphx::literal{s2, data2});
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(6 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<std::size_t>({6, 2})));
EXPECT(
migraphx::verify_range(result.get_shape().strides(), std::vector<std::size_t>({2, 1})));
}
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 0; int axis = -2;
std::vector<int> data0 = {0, 1, 2, 3}; std::vector<int> data0 = {0, 1, 2, 3};
std::vector<int> data1 = {4, 5, 6, 7, 8, 9}; std::vector<int> data1 = {4, 5, 6, 7, 8, 9};
std::vector<int> data2 = {10, 11}; std::vector<int> data2 = {10, 11};
...@@ -127,6 +176,26 @@ TEST_CASE(gather_test) ...@@ -127,6 +176,26 @@ TEST_CASE(gather_test)
EXPECT(migraphx::verify_range(res_data, golden)); EXPECT(migraphx::verify_range(res_data, golden));
} }
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{-3, -1};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{ {
migraphx::program p; migraphx::program p;
...@@ -188,6 +257,27 @@ TEST_CASE(gather_test) ...@@ -188,6 +257,27 @@ TEST_CASE(gather_test)
EXPECT(migraphx::verify_range(res_data, golden)); EXPECT(migraphx::verify_range(res_data, golden));
} }
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{-3};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f, 3.5f, 6.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{ {
migraphx::program p; migraphx::program p;
......
...@@ -1941,12 +1941,12 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride> ...@@ -1941,12 +1941,12 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
} }
}; };
struct test_concat : verify_program<test_concat> struct test_concat_axis_1 : verify_program<test_concat_axis_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
...@@ -1958,12 +1958,29 @@ struct test_concat : verify_program<test_concat> ...@@ -1958,12 +1958,29 @@ struct test_concat : verify_program<test_concat>
} }
}; };
struct test_concat2 : verify_program<test_concat2> struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 0; int axis = -1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p;
}
};
struct test_concat_axis_0 : verify_program<test_concat_axis_0>
{
migraphx::program create_program() const
{
migraphx::program p;
int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
...@@ -1980,7 +1997,7 @@ struct test_concat_transpose : verify_program<test_concat_transpose> ...@@ -1980,7 +1997,7 @@ struct test_concat_transpose : verify_program<test_concat_transpose>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
...@@ -1998,7 +2015,7 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2> ...@@ -1998,7 +2015,7 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
...@@ -2016,7 +2033,7 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3> ...@@ -2016,7 +2033,7 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
...@@ -2035,7 +2052,7 @@ struct test_concat_relu : verify_program<test_concat_relu> ...@@ -2035,7 +2052,7 @@ struct test_concat_relu : verify_program<test_concat_relu>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 0; int axis = 0;
migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
migraphx::shape s1{migraphx::shape::float_type, {3, 2}}; migraphx::shape s1{migraphx::shape::float_type, {3, 2}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2}}; migraphx::shape s2{migraphx::shape::float_type, {1, 2}};
...@@ -2134,6 +2151,22 @@ struct test_gather_neg_axis : verify_program<test_gather_neg_axis> ...@@ -2134,6 +2151,22 @@ struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
} }
}; };
struct test_gather_neg_indices : verify_program<test_gather_neg_indices>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{-2, -1, -1, -2};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_scalar_output : verify_program<test_gather_scalar_output> struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -2202,7 +2235,7 @@ void manual_identity() ...@@ -2202,7 +2235,7 @@ void manual_identity()
void manual_test_concat_relu() void manual_test_concat_relu()
{ {
migraphx::program p; migraphx::program p;
std::size_t axis = 0; int axis = 0;
std::vector<float> data0 = {0, 1, 2, 3}; std::vector<float> data0 = {0, 1, 2, 3};
std::vector<float> data1 = {4, 5, 6, 7, 8, 9}; std::vector<float> data1 = {4, 5, 6, 7, 8, 9};
std::vector<float> data2 = {10, 11}; std::vector<float> data2 = {10, 11};
......
...@@ -138,7 +138,7 @@ TEST_CASE(concat_test) ...@@ -138,7 +138,7 @@ TEST_CASE(concat_test)
// add the literal using a vector in order to set stride to 1 (like in tf parser) // add the literal using a vector in order to set stride to 1 (like in tf parser)
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis}); p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1); p.add_instruction(migraphx::op::concat{axis}, l0, l1);
auto prog = optimize_tf("concat_test.pb", false); auto prog = optimize_tf("concat_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -341,7 +341,7 @@ TEST_CASE(pack_test) ...@@ -341,7 +341,7 @@ TEST_CASE(pack_test)
[&](migraphx::instruction_ref arg) { [&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg); return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
}); });
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args); p.add_instruction(migraphx::op::concat{static_cast<int>(axis)}, unsqueezed_args);
auto prog = optimize_tf("pack_test.pb", false); auto prog = optimize_tf("pack_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -366,7 +366,7 @@ TEST_CASE(pack_test_nhwc) ...@@ -366,7 +366,7 @@ TEST_CASE(pack_test_nhwc)
[&](migraphx::instruction_ref arg) { [&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg); return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
}); });
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args); p.add_instruction(migraphx::op::concat{static_cast<int>(nchw_axis)}, unsqueezed_args);
auto prog = optimize_tf("pack_test_nhwc.pb", true); auto prog = optimize_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment