Commit c2aa86c1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format.

parent 2899f9a5
......@@ -757,16 +757,16 @@ struct gather
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
auto type = inputs[0].type();
lens.erase(lens.begin() + axis_index);
if (!inputs[1].scalar())
if(!inputs[1].scalar())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
}
// for scalar output
if (lens.size() == 0)
if(lens.size() == 0)
{
return {type, {1}, {0}};
}
......@@ -797,22 +797,24 @@ struct gather
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
// max dimension in axis
//std::size_t max_dim = args[0].get_shape().lens()[axis_index];
//std::vector<std::size_t> vec_indices;
//args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
// std::size_t max_dim = args[0].get_shape().lens()[axis_index];
// std::vector<std::size_t> vec_indices;
// args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&] (auto indices) {
if (indices.get_shape().scalar())
args[1].visit([&](auto indices) {
if(indices.get_shape().scalar())
{
if (output_shape.scalar())
if(output_shape.scalar())
{
output[0] = data[indices.front()];
}
else {
else
{
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx.insert(data_idx.begin() + axis_index, indices.front());
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end());
output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end());
});
}
}
......@@ -822,12 +824,13 @@ struct gather
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx;
auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + ind_lens.size();
auto end_it = data_idx.begin() + axis_index + ind_lens.size();
std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it);
data_idx.insert(start_it, indices(ind_idx.begin(), ind_idx.end()));
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end());
});
}
});
});
......
......@@ -433,10 +433,10 @@ struct onnx_parser
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
literal v = parse_value(attributes.at("value"));
migraphx::shape v_shape = v.get_shape();
// for constant containing 1 element, consider it as a scalar
if (v_shape.elements() == 1)
if(v_shape.elements() == 1)
{
migraphx::shape scalar_shape{v_shape.type(), {1}, {0}};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
......@@ -472,7 +472,7 @@ struct onnx_parser
// beginning or end of both args have dimension 1, need to squeeze
// before calling gemm, then doing unsqueeze after getting results
std::size_t num_squeeze = args[0]->get_shape().lens().size();
if (num_squeeze > 2)
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
......@@ -488,7 +488,7 @@ struct onnx_parser
if(beta != 0.f)
{
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
if (num_squeeze > 2)
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
......@@ -509,7 +509,7 @@ struct onnx_parser
}
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
if (num_squeeze > 2)
if(num_squeeze > 2)
{
std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0);
......
......@@ -23,58 +23,58 @@ argument gather(hipStream_t stream,
const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data());
if (output_shape.scalar())
if(output_shape.scalar())
{
gs_launch(stream, 1)([=](auto i) {
outptr[i] = inptr[indices_ptr[0]];
});
gs_launch(stream, 1)([=](auto i) { outptr[i] = inptr[indices_ptr[0]]; });
}
else {
else
{
visit_tensor_size(output_shape.lens().size(), [&](auto n_out_dim) {
visit_tensor_size(args[0].get_shape().lens().size(), [&](auto n_in_dim) {
hip_tensor_descriptor<n_in_dim> desc_input(input.get_shape());
hip_tensor_descriptor<n_out_dim> desc_output(output.get_shape());
if (args[1].get_shape().scalar())
if(args[1].get_shape().scalar())
{
gs_launch(stream, nelements)([=](auto ii) {
auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0);
for (int i = 0; i < axis_index; ++i)
auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0);
for(int i = 0; i < axis_index; ++i)
{
in_idx[i] = out_idx[i];
}
in_idx[axis_index] = indices_ptr[0];
for (int i = axis_index + 1; i < n_in_dim; ++i)
for(int i = axis_index + 1; i < n_in_dim; ++i)
{
in_idx[i] = out_idx[i - 1];
}
outptr[ii] = inptr[desc_input.linear(in_idx)];
outptr[ii] = inptr[desc_input.linear(in_idx)];
});
}
else
else
{
visit_tensor_size(args[1].get_shape().lens().size(), [&](auto n_ind_dim) {
hip_tensor_descriptor<n_ind_dim> desc_ind(args[1].get_shape());
gs_launch(stream, nelements)([=](auto ii) {
auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0);
for (int i = 0; i < axis_index; ++i)
{
in_idx[i] = out_idx[i];
}
auto ind_idx = desc_ind.multi(0);
for (int i = 0; i < n_ind_dim; ++i)
{
ind_idx[i] = out_idx[i + axis_index];
}
in_idx[axis_index] = indices_ptr[desc_ind.linear(ind_idx)];
for (int i = axis_index + 1; i < n_in_dim; ++i)
{
in_idx[i] = out_idx[i + n_ind_dim - 1];
}
outptr[ii] = inptr[desc_input.linear(in_idx)];
visit_tensor_size(
args[1].get_shape().lens().size(), [&](auto n_ind_dim) {
hip_tensor_descriptor<n_ind_dim> desc_ind(args[1].get_shape());
gs_launch(stream, nelements)([=](auto ii) {
auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0);
for(int i = 0; i < axis_index; ++i)
{
in_idx[i] = out_idx[i];
}
auto ind_idx = desc_ind.multi(0);
for(int i = 0; i < n_ind_dim; ++i)
{
ind_idx[i] = out_idx[i + axis_index];
}
in_idx[axis_index] = indices_ptr[desc_ind.linear(ind_idx)];
for(int i = axis_index + 1; i < n_in_dim; ++i)
{
in_idx[i] = out_idx[i + n_ind_dim - 1];
}
outptr[ii] = inptr[desc_input.linear(in_idx)];
});
});
});
}
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment