Commit c21a1b28 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 892d1100
...@@ -433,7 +433,7 @@ struct onnx_parser ...@@ -433,7 +433,7 @@ struct onnx_parser
attribute_map attributes, attribute_map attributes,
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
{ {
literal v = parse_value(attributes.at("value")); literal v = parse_value(attributes.at("value"));
auto dim_size = attributes.at("value").t().dims_size(); auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(dim_size == 0)
......
...@@ -21,27 +21,27 @@ argument gather(hipStream_t stream, ...@@ -21,27 +21,27 @@ argument gather(hipStream_t stream,
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* out_ptr = device_cast(output.data()); auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data()); const auto* in_ptr = device_cast(input.data());
if(output_shape.scalar()) if(output_shape.scalar())
{ {
gs_launch(stream, gs_launch(stream, 1)(
1)([=](auto i) { out_ptr[i] = in_ptr[static_cast<int>(indices_ptr[0])]; }); [=](auto i) { out_ptr[i] = in_ptr[static_cast<int>(indices_ptr[0])]; });
} }
else else
{ {
// if indices are a scalar, output has one dim smaller than input // if indices are a scalar, output has one dim smaller than input
auto& input_shape = args[0].get_shape(); auto& input_shape = args[0].get_shape();
auto lens = input_shape.lens(); auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements(); lens[axis_index] = args[1].get_shape().elements();
migraphx::shape out_comp_shape{output_shape.type(), lens}; migraphx::shape out_comp_shape{output_shape.type(), lens};
visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) { visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
hip_tensor_descriptor<n_out_dim> desc_input(input_shape); hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape); hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
gs_launch(stream, nelements)([=](auto ii) { gs_launch(stream, nelements)([=](auto ii) {
auto in_idx = desc_output.multi(ii); auto in_idx = desc_output.multi(ii);
in_idx[axis_index] = indices_ptr[in_idx[axis_index]]; in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)]; out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
}); });
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment