"vscode:/vscode.git/clone" did not exist on "09740556c529d65bbee562c1d736ee3060acd6aa"
Commit b0586006 authored by Paul's avatar Paul
Browse files

Formatting

parent 9b32ae0c
......@@ -59,7 +59,9 @@ struct reshape
shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " + std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(inputs.front().elements()));
MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " +
std::to_string(s.elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -31,7 +31,8 @@ void rewrite_pooling::apply(program& prog) const
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape = prog.insert_instruction(ins, op::reshape{{n*c, -1}}, ins->inputs().front());
auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
}
......
......@@ -276,10 +276,25 @@ void reduce(hipStream_t stream,
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
if (input_shape.standard() and output_shape.standard() and output_shape.lens().back() != input_shape.lens().back() and std::equal(output_shape.lens().begin(), std::prev(output_shape.lens().end()), input_shape.lens().begin()))
if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()),
input_shape.lens().begin()))
{
std::size_t stride = std::accumulate(input_shape.strides().begin(), input_shape.strides().end(), 1, std::multiplies<size_t>());
reduce_standard_impl(stream, result, arg, op, init, read_input, read_output, input_shape.lens().back(), stride);
std::size_t stride = std::accumulate(input_shape.strides().begin(),
input_shape.strides().end(),
1,
std::multiplies<size_t>());
reduce_standard_impl(stream,
result,
arg,
op,
init,
read_input,
read_output,
input_shape.lens().back(),
stride);
}
else
{
......@@ -296,7 +311,6 @@ void reduce(hipStream_t stream,
});
shape reduce_slice{output_shape.type(), reduce_lens};
reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment