"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "70837f1aafcd24d6ed4bb328eca87d74a7cc8604"
Commit cda6f573 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'gather_operator' into extend_gemm_op

parents c69b2d6b 4c1a1d63
...@@ -774,35 +774,12 @@ struct gather ...@@ -774,35 +774,12 @@ struct gather
return {type, lens}; return {type, lens};
} }
template <typename V>
std::vector<std::size_t> compute_data_index(const V& indices,
const int axis_index,
const std::vector<std::size_t>& out_idx) const
{
auto data_idx = out_idx;
std::size_t index{};
if(!indices.get_shape().scalar())
{
auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + indices.get_shape().lens().size();
std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it);
index = indices(ind_idx.begin(), ind_idx.end());
}
else
{
index = indices.front();
}
data_idx.insert(data_idx.begin() + axis_index, index);
return data_idx;
}
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (args[0].get_shape().lens().size() + axis) : axis; int axis_index =
(axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
...@@ -813,9 +790,13 @@ struct gather ...@@ -813,9 +790,13 @@ struct gather
} }
else else
{ {
shape_for_each(output.get_shape(), [&](const auto& out_idx) { auto out_lens = data.get_shape().lens();
auto data_idx = compute_data_index(indices, axis_index, out_idx); out_lens[axis_index] = indices.get_shape().elements();
output(out_idx.begin(), out_idx.end()) = migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx[axis_index] = indices[data_idx[axis_index]];
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end()); data(data_idx.begin(), data_idx.end());
}); });
} }
......
...@@ -23,28 +23,19 @@ argument gather(hipStream_t stream, ...@@ -23,28 +23,19 @@ argument gather(hipStream_t stream,
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* out_ptr = device_cast(output.data()); auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data()); const auto* in_ptr = device_cast(input.data());
if(output_shape.scalar()) auto& input_shape = args[0].get_shape();
{ auto lens = input_shape.lens();
gs_launch(stream, 1)( lens[axis_index] = args[1].get_shape().elements();
[=](auto i) { out_ptr[i] = in_ptr[static_cast<int>(indices_ptr[0])]; }); migraphx::shape out_comp_shape{output_shape.type(), lens};
} visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
else hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
{ hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
// if indices are a scalar, output has one dim smaller than input gs_launch(stream, nelements)([=](auto ii) {
auto& input_shape = args[0].get_shape(); auto in_idx = desc_output.multi(ii);
auto lens = input_shape.lens(); in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
lens[axis_index] = args[1].get_shape().elements(); out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
migraphx::shape out_comp_shape{output_shape.type(), lens};
visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
gs_launch(stream, nelements)([=](auto ii) {
auto in_idx = desc_output.multi(ii);
in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
});
}); });
} });
}); });
}); });
......
shape-gather-example:O
2value"Constant*
value**B const_tensor constantb
z

B
\ No newline at end of file
...@@ -521,6 +521,15 @@ TEST_CASE(constant_test) ...@@ -521,6 +521,15 @@ TEST_CASE(constant_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(constant_test_scalar)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}, {0}}, {1}});
auto prog = migraphx::parse_onnx("constant_scalar.onnx");
EXPECT(p == prog);
}
TEST_CASE(constant_fill_test) TEST_CASE(constant_fill_test)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment