"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "f33f2298f72a97bb495c6cd60446ec92889b7333"
Unverified Commit e15b8333 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #193 from ROCmSoftwarePlatform/gather_operator

Gather operator
parents 66e6f824 b48c3ee1
...@@ -757,43 +757,49 @@ struct gather ...@@ -757,43 +757,49 @@ struct gather
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis; int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type(); auto type = inputs[0].type();
lens[axis_index] = inputs[1].elements(); lens.erase(lens.begin() + axis_index);
if(!inputs[1].scalar())
return {type, lens}; {
} auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
}
template <class T> // for scalar output
void compute_index(const T& out_idx, if(lens.empty())
const int axis_index,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
{
in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim)
{ {
MIGRAPHX_THROW("Gather: indices are out of range in input tensor"); return {type};
} }
in_idx[axis_index] = idx;
return {type, lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis; int axis_index =
(axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis_index]; visit_all(result, args[0])([&](auto output, auto data) {
std::vector<std::size_t> vec_indices; args[1].visit([&](auto indices) {
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); }); if(output_shape.scalar())
visit_all(result, args[0])([&](auto output, auto input) { {
std::vector<std::size_t> in_idx; output[0] = data[indices.front()];
shape_for_each(output.get_shape(), [&](const auto& idx) { }
this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx); else
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end()); {
auto out_lens = data.get_shape().lens();
out_lens[axis_index] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx[axis_index] = indices[data_idx[axis_index]];
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end());
});
}
}); });
}); });
......
...@@ -436,7 +436,15 @@ struct onnx_parser ...@@ -436,7 +436,15 @@ struct onnx_parser
attribute_map attributes, attribute_map attributes,
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
{ {
literal v = parse_value(attributes.at("value")); literal v = parse_value(attributes.at("value"));
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return prog.add_literal(v); return prog.add_literal(v);
} }
...@@ -463,6 +471,7 @@ struct onnx_parser ...@@ -463,6 +471,7 @@ struct onnx_parser
{ {
transb = parse_value(attributes.at("transB")).at<bool>(); transb = parse_value(attributes.at("transB")).at<bool>();
} }
std::vector<int64_t> perm = {1, 0}; std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0]; auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
...@@ -483,7 +492,10 @@ struct onnx_parser ...@@ -483,7 +492,10 @@ struct onnx_parser
return add_broadcastable_binary_op(l3, l4, op::add{}); return add_broadcastable_binary_op(l3, l4, op::add{});
} }
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
return dot_res;
} }
instruction_ref instruction_ref
......
...@@ -19,7 +19,7 @@ struct shape_impl ...@@ -19,7 +19,7 @@ struct shape_impl
shape_impl() : m_type(shape::float_type), m_standard(false) {} shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {} shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
......
...@@ -16,20 +16,24 @@ argument gather(hipStream_t stream, ...@@ -16,20 +16,24 @@ argument gather(hipStream_t stream,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
int axis) int axis)
{ {
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis; int axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) { visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { const auto* indices_ptr = device_cast(indices.data());
const auto* indices_ptr = device_cast(indices.data()); auto* out_ptr = device_cast(output.data());
auto* outptr = device_cast(output.data()); const auto* in_ptr = device_cast(input.data());
const auto* inptr = device_cast(input.data()); auto& input_shape = args[0].get_shape();
hip_tensor_descriptor<ndim> desc_input(input.get_shape()); auto lens = input_shape.lens();
hip_tensor_descriptor<ndim> desc_output(output.get_shape()); lens[axis_index] = args[1].get_shape().elements();
gs_launch(stream, nelements)([=](auto i) { migraphx::shape out_comp_shape{output_shape.type(), lens};
auto lens = desc_output.multi(i); visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
lens[axis_index] = indices_ptr[lens[axis_index]]; hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
outptr[i] = inptr[desc_input.linear(lens)]; hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
gs_launch(stream, nelements)([=](auto ii) {
auto in_idx = desc_output.multi(ii);
in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
}); });
}); });
}); });
......
...@@ -164,6 +164,48 @@ TEST_CASE(gather_test) ...@@ -164,6 +164,48 @@ TEST_CASE(gather_test)
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden)); EXPECT(migraphx::verify_range(res_data, golden));
} }
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f, 3.5f, 6.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
} }
TEST_CASE(squeeze_test) TEST_CASE(squeeze_test)
......
...@@ -1068,6 +1068,54 @@ struct test_gather_neg_axis ...@@ -1068,6 +1068,54 @@ struct test_gather_neg_axis
} }
}; };
struct test_gather_scalar_output
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_scalar_index
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_1d_index
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
void manual_identity() void manual_identity()
{ {
migraphx::program p; migraphx::program p;
...@@ -2904,6 +2952,9 @@ int main() ...@@ -2904,6 +2952,9 @@ int main()
verify_program<test_slice>(); verify_program<test_slice>();
verify_program<test_gather>(); verify_program<test_gather>();
verify_program<test_gather_neg_axis>(); verify_program<test_gather_neg_axis>();
verify_program<test_gather_scalar_output>();
verify_program<test_gather_scalar_index>();
verify_program<test_gather_1d_index>();
verify_program<test_rnn_forward>(); verify_program<test_rnn_forward>();
verify_program<test_rnn_forward10>(); verify_program<test_rnn_forward10>();
verify_program<test_rnn_reverse>(); verify_program<test_rnn_reverse>();
......
shape-gather-example:O
2value"Constant*
value**B const_tensor constantb
z

B
\ No newline at end of file
...@@ -521,6 +521,15 @@ TEST_CASE(constant_test) ...@@ -521,6 +521,15 @@ TEST_CASE(constant_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(constant_test_scalar)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar.onnx");
EXPECT(p == prog);
}
TEST_CASE(constant_fill_test) TEST_CASE(constant_fill_test)
{ {
{ {
......
...@@ -235,7 +235,7 @@ TEST_CASE(gather) ...@@ -235,7 +235,7 @@ TEST_CASE(gather)
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1; int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
indices); indices);
...@@ -245,7 +245,57 @@ TEST_CASE(gather) ...@@ -245,7 +245,57 @@ TEST_CASE(gather)
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4; int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
indices); indices);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment