"...git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "f52160107bafe7f6248e71fa6876576cc3425282"
Commit c2aa86c1 authored by Shucai Xiao
Browse files

clang format.

parent 2899f9a5
...@@ -757,16 +757,16 @@ struct gather ...@@ -757,16 +757,16 @@ struct gather
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis; int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type(); auto type = inputs[0].type();
lens.erase(lens.begin() + axis_index); lens.erase(lens.begin() + axis_index);
if (!inputs[1].scalar()) if(!inputs[1].scalar())
{ {
auto ind_lens = inputs[1].lens(); auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end()); lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
} }
// for scalar output // for scalar output
if (lens.size() == 0) if(lens.size() == 0)
{ {
return {type, {1}, {0}}; return {type, {1}, {0}};
} }
...@@ -797,22 +797,24 @@ struct gather ...@@ -797,22 +797,24 @@ struct gather
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis; int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
//std::size_t max_dim = args[0].get_shape().lens()[axis_index]; // std::size_t max_dim = args[0].get_shape().lens()[axis_index];
//std::vector<std::size_t> vec_indices; // std::vector<std::size_t> vec_indices;
//args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); }); // args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&] (auto indices) { args[1].visit([&](auto indices) {
if (indices.get_shape().scalar()) if(indices.get_shape().scalar())
{ {
if (output_shape.scalar()) if(output_shape.scalar())
{ {
output[0] = data[indices.front()]; output[0] = data[indices.front()];
} }
else { else
{
shape_for_each(output.get_shape(), [&](const auto& out_idx) { shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx; auto data_idx = out_idx;
data_idx.insert(data_idx.begin() + axis_index, indices.front()); data_idx.insert(data_idx.begin() + axis_index, indices.front());
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end()); output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end());
}); });
} }
} }
...@@ -822,12 +824,13 @@ struct gather ...@@ -822,12 +824,13 @@ struct gather
shape_for_each(output.get_shape(), [&](const auto& out_idx) { shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx; auto data_idx = out_idx;
auto start_it = data_idx.begin() + axis_index; auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + ind_lens.size(); auto end_it = data_idx.begin() + axis_index + ind_lens.size();
std::vector<std::size_t> ind_idx(start_it, end_it); std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it); data_idx.erase(start_it, end_it);
data_idx.insert(start_it, indices(ind_idx.begin(), ind_idx.end())); data_idx.insert(start_it, indices(ind_idx.begin(), ind_idx.end()));
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end()); output(out_idx.begin(), out_idx.end()) =
}); data(data_idx.begin(), data_idx.end());
});
} }
}); });
}); });
......
...@@ -433,10 +433,10 @@ struct onnx_parser ...@@ -433,10 +433,10 @@ struct onnx_parser
attribute_map attributes, attribute_map attributes,
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
{ {
literal v = parse_value(attributes.at("value")); literal v = parse_value(attributes.at("value"));
migraphx::shape v_shape = v.get_shape(); migraphx::shape v_shape = v.get_shape();
// for constant containing 1 element, consider it as a scalar // for constant containing 1 element, consider it as a scalar
if (v_shape.elements() == 1) if(v_shape.elements() == 1)
{ {
migraphx::shape scalar_shape{v_shape.type(), {1}, {0}}; migraphx::shape scalar_shape{v_shape.type(), {1}, {0}};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()}); return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
...@@ -472,7 +472,7 @@ struct onnx_parser ...@@ -472,7 +472,7 @@ struct onnx_parser
// beginning or end of both args have dimension 1, need to squeeze // beginning or end of both args have dimension 1, need to squeeze
// before calling gemm, then doing unsqueeze after getting results // before calling gemm, then doing unsqueeze after getting results
std::size_t num_squeeze = args[0]->get_shape().lens().size(); std::size_t num_squeeze = args[0]->get_shape().lens().size();
if (num_squeeze > 2) if(num_squeeze > 2)
{ {
std::vector<int64_t> vec_axises(num_squeeze - 2); std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0); std::iota(vec_axises.begin(), vec_axises.end(), 0);
...@@ -488,7 +488,7 @@ struct onnx_parser ...@@ -488,7 +488,7 @@ struct onnx_parser
if(beta != 0.f) if(beta != 0.f)
{ {
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2); auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
if (num_squeeze > 2) if(num_squeeze > 2)
{ {
std::vector<int64_t> vec_axises(num_squeeze - 2); std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0); std::iota(vec_axises.begin(), vec_axises.end(), 0);
...@@ -509,7 +509,7 @@ struct onnx_parser ...@@ -509,7 +509,7 @@ struct onnx_parser
} }
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2); auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
if (num_squeeze > 2) if(num_squeeze > 2)
{ {
std::vector<int64_t> vec_axises(num_squeeze - 2); std::vector<int64_t> vec_axises(num_squeeze - 2);
std::iota(vec_axises.begin(), vec_axises.end(), 0); std::iota(vec_axises.begin(), vec_axises.end(), 0);
......
...@@ -23,58 +23,58 @@ argument gather(hipStream_t stream, ...@@ -23,58 +23,58 @@ argument gather(hipStream_t stream,
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data()); auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data()); const auto* inptr = device_cast(input.data());
if (output_shape.scalar()) if(output_shape.scalar())
{ {
gs_launch(stream, 1)([=](auto i) { gs_launch(stream, 1)([=](auto i) { outptr[i] = inptr[indices_ptr[0]]; });
outptr[i] = inptr[indices_ptr[0]];
});
} }
else { else
{
visit_tensor_size(output_shape.lens().size(), [&](auto n_out_dim) { visit_tensor_size(output_shape.lens().size(), [&](auto n_out_dim) {
visit_tensor_size(args[0].get_shape().lens().size(), [&](auto n_in_dim) { visit_tensor_size(args[0].get_shape().lens().size(), [&](auto n_in_dim) {
hip_tensor_descriptor<n_in_dim> desc_input(input.get_shape()); hip_tensor_descriptor<n_in_dim> desc_input(input.get_shape());
hip_tensor_descriptor<n_out_dim> desc_output(output.get_shape()); hip_tensor_descriptor<n_out_dim> desc_output(output.get_shape());
if (args[1].get_shape().scalar()) if(args[1].get_shape().scalar())
{ {
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements)([=](auto ii) {
auto out_idx = desc_output.multi(ii); auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0); auto in_idx = desc_input.multi(0);
for (int i = 0; i < axis_index; ++i) for(int i = 0; i < axis_index; ++i)
{ {
in_idx[i] = out_idx[i]; in_idx[i] = out_idx[i];
} }
in_idx[axis_index] = indices_ptr[0]; in_idx[axis_index] = indices_ptr[0];
for (int i = axis_index + 1; i < n_in_dim; ++i) for(int i = axis_index + 1; i < n_in_dim; ++i)
{ {
in_idx[i] = out_idx[i - 1]; in_idx[i] = out_idx[i - 1];
} }
outptr[ii] = inptr[desc_input.linear(in_idx)]; outptr[ii] = inptr[desc_input.linear(in_idx)];
}); });
} }
else else
{ {
visit_tensor_size(args[1].get_shape().lens().size(), [&](auto n_ind_dim) { visit_tensor_size(
hip_tensor_descriptor<n_ind_dim> desc_ind(args[1].get_shape()); args[1].get_shape().lens().size(), [&](auto n_ind_dim) {
gs_launch(stream, nelements)([=](auto ii) { hip_tensor_descriptor<n_ind_dim> desc_ind(args[1].get_shape());
auto out_idx = desc_output.multi(ii); gs_launch(stream, nelements)([=](auto ii) {
auto in_idx = desc_input.multi(0); auto out_idx = desc_output.multi(ii);
for (int i = 0; i < axis_index; ++i) auto in_idx = desc_input.multi(0);
{ for(int i = 0; i < axis_index; ++i)
in_idx[i] = out_idx[i]; {
} in_idx[i] = out_idx[i];
auto ind_idx = desc_ind.multi(0); }
for (int i = 0; i < n_ind_dim; ++i) auto ind_idx = desc_ind.multi(0);
{ for(int i = 0; i < n_ind_dim; ++i)
ind_idx[i] = out_idx[i + axis_index]; {
} ind_idx[i] = out_idx[i + axis_index];
in_idx[axis_index] = indices_ptr[desc_ind.linear(ind_idx)]; }
for (int i = axis_index + 1; i < n_in_dim; ++i) in_idx[axis_index] = indices_ptr[desc_ind.linear(ind_idx)];
{ for(int i = axis_index + 1; i < n_in_dim; ++i)
in_idx[i] = out_idx[i + n_ind_dim - 1]; {
} in_idx[i] = out_idx[i + n_ind_dim - 1];
outptr[ii] = inptr[desc_input.linear(in_idx)]; }
outptr[ii] = inptr[desc_input.linear(in_idx)];
});
}); });
});
} }
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment