"...lm-evaluation-harness.git" did not exist on "b29ef52157d2d6a92fb69b86dfd02be3cf895a65"
Commit 2899f9a5 authored by Shucai Xiao

Enhance gather to support scalar input indices.

parent c50e8004
@@ -758,27 +758,38 @@ struct gather
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
lens[axis_index] = inputs[1].elements();
return {type, lens};
}
lens.erase(lens.begin() + axis_index);
if (!inputs[1].scalar())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
}
template <class T>
void compute_index(const T& out_idx,
const int axis_index,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
{
in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim)
// for scalar output
if (lens.size() == 0)
{
MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
return {type, {1}, {0}};
}
in_idx[axis_index] = idx;
return {type, lens};
}
// template <class T>
// void compute_index(const T& out_idx,
// const int axis_index,
// const std::vector<std::size_t>& vec_indices,
// const std::size_t max_dim,
// T& in_idx) const
// {
// in_idx = out_idx;
// std::size_t idx = vec_indices.at(out_idx[axis_index]);
// if(idx >= max_dim)
// {
// MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
// }
// in_idx[axis_index] = idx;
// }
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
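The new compute_shape logic above replaces the old rule (output size along the axis equals the number of index elements) with a rank-aware one: the gather axis is dropped from the data dimensions and the index dimensions are spliced in at that position, unless the indices are a scalar; if nothing is left, a one-element shape with stride 0 is returned so the result is itself a scalar. A minimal standalone sketch of that rule (illustrative only, not the MIGraphX source; the real code returns {type, {1}, {0}} for the scalar case):

#include <cstddef>
#include <iostream>
#include <vector>

// Output dims = data dims with the gather axis replaced by the index dims;
// a scalar index contributes no dims, and an empty result collapses to {1}.
std::vector<std::size_t> gather_output_lens(std::vector<std::size_t> data_lens,
                                            const std::vector<std::size_t>& index_lens,
                                            bool scalar_index,
                                            std::size_t axis)
{
    data_lens.erase(data_lens.begin() + axis);
    if(!scalar_index)
        data_lens.insert(data_lens.begin() + axis, index_lens.begin(), index_lens.end());
    if(data_lens.empty())
        data_lens = {1}; // stands in for the {type, {1}, {0}} scalar shape
    return data_lens;
}

int main()
{
    for(auto d : gather_output_lens({3, 4, 5}, {2, 6}, false, 1)) // -> 3 2 6 5
        std::cout << d << ' ';
    std::cout << '\n';
    for(auto d : gather_output_lens({3}, {}, true, 0)) // scalar result -> 1
        std::cout << d << ' ';
    std::cout << '\n';
}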
@@ -786,14 +797,38 @@ struct gather
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
//std::size_t max_dim = args[0].get_shape().lens()[axis_index];
//std::vector<std::size_t> vec_indices;
//args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&] (auto indices) {
if (indices.get_shape().scalar())
{
if (output_shape.scalar())
{
output[0] = data[indices.front()];
}
else {
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx.insert(data_idx.begin() + axis_index, indices.front());
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end());
});
}
}
else
{
auto ind_lens = indices.get_shape().lens();
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx;
auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + ind_lens.size();
std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it);
data_idx.insert(start_it, indices(ind_idx.begin(), ind_idx.end()));
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end());
});
}
});
});
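In the reference compute above, mapping an output coordinate back into the data tensor works as follows: the ind_lens.size() output coordinates starting at the gather axis address the indices tensor, and they are replaced by the single value looked up there (for a scalar index there are zero such coordinates and indices.front() is simply inserted). A small sketch of that remapping, assuming the gathered value has already been read from the indices tensor:

#include <cstddef>
#include <vector>

// out_idx:   coordinate in the output tensor
// ind_rank:  rank of the index tensor (0 for a scalar index)
// gathered:  indices(out_idx[axis], ..., out_idx[axis + ind_rank - 1]),
//            i.e. the position to read along the gather axis of the data
std::vector<std::size_t> data_coordinate(std::vector<std::size_t> out_idx,
                                         std::size_t axis,
                                         std::size_t ind_rank,
                                         std::size_t gathered)
{
    auto first = out_idx.begin() + axis;
    out_idx.erase(first, first + ind_rank);           // drop the index-tensor coordinates
    out_idx.insert(out_idx.begin() + axis, gathered); // splice in the gathered position
    return out_idx;
}

For example, with data of shape {2, 3}, axis 1 and index tensor {0, 2}, the output has shape {2, 2}, and output coordinate {1, 0} reads the data at {1, indices[0]} = {1, 0}.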
@@ -434,6 +434,14 @@ struct onnx_parser
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
migraphx::shape v_shape = v.get_shape();
// for a constant containing a single element, treat it as a scalar
if (v_shape.elements() == 1)
{
migraphx::shape scalar_shape{v_shape.type(), {1}, {0}};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return prog.add_literal(v);
}
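The Constant-parsing change above makes any one-element initializer come through as a MIGraphX scalar by rebuilding its literal with lens {1} and strides {0}, the zero-stride form that the scalar() checks in the new gather code presumably key off. A minimal sketch of constructing such a literal, assuming the literal(shape, std::vector<T>) constructor:

#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/shape.hpp>

// Illustrative only: wrap a single float in a scalar literal by giving the
// one-element shape a zero stride, mirroring the {type, {1}, {0}} shape above.
migraphx::literal make_scalar_literal(float value)
{
    migraphx::shape scalar_shape{migraphx::shape::float_type, {1}, {0}};
    return migraphx::literal{scalar_shape, std::vector<float>{value}};
}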
@@ -460,6 +468,18 @@ struct onnx_parser
{
transb = parse_value(attributes.at("transB")).at<bool>();
}
// if the leading dimensions of both args are 1, squeeze them away
// before calling gemm, then unsqueeze the result to restore the rank
std::size_t num_squeeze = args[0]->get_shape().lens().size();
if (num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
args[0] = prog.add_instruction(op::squeeze{vec_axises}, args[0]);
args[1] = prog.add_instruction(op::squeeze{vec_axises}, args[1]);
}
std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
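The Gemm change above lowers higher-rank inputs to the 2-D form gemm expects: when the operands have rank n > 2, the first n - 2 axes (which, per the comment, are expected to be size 1) are squeezed away before the dot, and the same axes are unsqueezed back onto every result path below. A minimal sketch of the axis computation (illustrative only):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Leading axes 0 .. rank-3 that are squeezed before gemm and unsqueezed after.
std::vector<int64_t> leading_axes(std::size_t rank)
{
    std::vector<int64_t> axes(rank > 2 ? rank - 2 : 0);
    std::iota(axes.begin(), axes.end(), 0);
    return axes;
}

int main()
{
    for(auto a : leading_axes(4)) // rank-4 operand -> axes 0 1
        std::cout << a << ' ';
    std::cout << '\n';
}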
@@ -468,6 +488,13 @@ struct onnx_parser
if(beta != 0.f)
{
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
if (num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
l3 = prog.add_instruction(op::unsqueeze{vec_axises}, l3);
}
auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3;
@@ -480,7 +507,16 @@ struct onnx_parser
return add_broadcastable_binary_op(l3, l4, op::add{});
}
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
if (num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
dot_res = prog.add_instruction(op::unsqueeze{vec_axises}, dot_res);
}
return dot_res;
}
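As a worked example of the path above (assuming the usual convention that only size-1 axes can be squeezed): A of shape {1, 1, 4, 5} and B of shape {1, 1, 5, 3} are squeezed over axes {0, 1} to {4, 5} and {5, 3}, dot produces {4, 3}, and the unsqueeze on each return path restores the original rank, giving {1, 1, 4, 3}.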
instruction_ref
@@ -20,18 +20,65 @@ argument gather(hipStream_t stream,
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)];
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
if (output_shape.scalar())
{
gs_launch(stream, 1)([=](auto i) {
outptr[i] = inptr[indices_ptr[0]];
});
});
}
else {
visit_tensor_size(output_shape.lens().size(), [&](auto n_out_dim) {
visit_tensor_size(args[0].get_shape().lens().size(), [&](auto n_in_dim) {
hip_tensor_descriptor<n_in_dim> desc_input(input.get_shape());
hip_tensor_descriptor<n_out_dim> desc_output(output.get_shape());
if (args[1].get_shape().scalar())
{
gs_launch(stream, nelements)([=](auto ii) {
auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0);
for (int i = 0; i < axis_index; ++i)
{
in_idx[i] = out_idx[i];
}
in_idx[axis_index] = indices_ptr[0];
for (int i = axis_index + 1; i < n_in_dim; ++i)
{
in_idx[i] = out_idx[i - 1];
}
outptr[ii] = inptr[desc_input.linear(in_idx)];
});
}
else
{
visit_tensor_size(args[1].get_shape().lens().size(), [&](auto n_ind_dim) {
hip_tensor_descriptor<n_ind_dim> desc_ind(args[1].get_shape());
gs_launch(stream, nelements)([=](auto ii) {
auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0);
for (int i = 0; i < axis_index; ++i)
{
in_idx[i] = out_idx[i];
}
auto ind_idx = desc_ind.multi(0);
for (int i = 0; i < n_ind_dim; ++i)
{
ind_idx[i] = out_idx[i + axis_index];
}
in_idx[axis_index] = indices_ptr[desc_ind.linear(ind_idx)];
for (int i = axis_index + 1; i < n_in_dim; ++i)
{
in_idx[i] = out_idx[i + n_ind_dim - 1];
}
outptr[ii] = inptr[desc_input.linear(in_idx)];
});
});
}
});
});
}
});
});
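The GPU kernel above mirrors the reference op in three cases: a scalar output is filled by a single-thread launch, a scalar index splices indices_ptr[0] into the data coordinate at the gather axis, and an index tensor is addressed with the output coordinates starting at that axis. The per-element coordinate arithmetic is the loop form of the erase/insert remapping used on the CPU; a host-side sketch (illustrative, without the hip_tensor_descriptor linear/multi conversions):

#include <cstddef>
#include <vector>

// Rebuild the data coordinate the way the kernel does: copy the dims before
// the axis, place the gathered position at the axis, then copy the remaining
// output dims shifted past the ind_rank index coordinates.
std::vector<std::size_t> kernel_data_coordinate(const std::vector<std::size_t>& out_idx,
                                                std::size_t axis,
                                                std::size_t data_rank,
                                                std::size_t ind_rank, // 0 for a scalar index
                                                std::size_t gathered) // value read from indices
{
    std::vector<std::size_t> in_idx(data_rank);
    for(std::size_t i = 0; i < axis; ++i)
        in_idx[i] = out_idx[i];
    in_idx[axis] = gathered;
    for(std::size_t i = axis + 1; i < data_rank; ++i)
        in_idx[i] = out_idx[i + ind_rank - 1];
    return in_idx;
}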