Commit c21a1b28 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 892d1100
......@@ -433,7 +433,7 @@ struct onnx_parser
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
literal v = parse_value(attributes.at("value"));
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
......
......@@ -21,27 +21,27 @@ argument gather(hipStream_t stream,
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data());
auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data());
if(output_shape.scalar())
{
gs_launch(stream,
1)([=](auto i) { out_ptr[i] = in_ptr[static_cast<int>(indices_ptr[0])]; });
gs_launch(stream, 1)(
[=](auto i) { out_ptr[i] = in_ptr[static_cast<int>(indices_ptr[0])]; });
}
else
{
// if indices are a scalar, output has one dim smaller than input
auto& input_shape = args[0].get_shape();
auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements();
auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements();
migraphx::shape out_comp_shape{output_shape.type(), lens};
visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
gs_launch(stream, nelements)([=](auto ii) {
auto in_idx = desc_output.multi(ii);
auto in_idx = desc_output.multi(ii);
in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
});
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment