"src/layout/gemm_layouts.cc" did not exist on "64f17c2f369e612cc297d358f607307a615bbb59"
Commit b0586006 authored by Paul's avatar Paul
Browse files

Formatting

parent 9b32ae0c
......@@ -59,7 +59,9 @@ struct reshape
shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " + std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(inputs.front().elements()));
MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " +
std::to_string(s.elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -31,7 +31,8 @@ void rewrite_pooling::apply(program& prog) const
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape = prog.insert_instruction(ins, op::reshape{{n*c, -1}}, ins->inputs().front());
auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
}
......
......@@ -210,13 +210,13 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si
template <class Op, class T, class Input, class Output>
void reduce_multi_impl(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output,
const shape& reduce_slice)
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output,
const shape& reduce_slice)
{
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
......@@ -239,14 +239,14 @@ void reduce_multi_impl(hipStream_t stream,
template <class Op, class T, class Input, class Output>
void reduce_standard_impl(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output,
std::size_t relements,
std::size_t stride)
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output,
std::size_t relements,
std::size_t stride)
{
hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements();
......@@ -254,8 +254,8 @@ void reduce_standard_impl(hipStream_t stream,
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride;
const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]);
});
......@@ -276,10 +276,25 @@ void reduce(hipStream_t stream,
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
if (input_shape.standard() and output_shape.standard() and output_shape.lens().back() != input_shape.lens().back() and std::equal(output_shape.lens().begin(), std::prev(output_shape.lens().end()), input_shape.lens().begin()))
if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()),
input_shape.lens().begin()))
{
std::size_t stride = std::accumulate(input_shape.strides().begin(), input_shape.strides().end(), 1, std::multiplies<size_t>());
reduce_standard_impl(stream, result, arg, op, init, read_input, read_output, input_shape.lens().back(), stride);
std::size_t stride = std::accumulate(input_shape.strides().begin(),
input_shape.strides().end(),
1,
std::multiplies<size_t>());
reduce_standard_impl(stream,
result,
arg,
op,
init,
read_input,
read_output,
input_shape.lens().back(),
stride);
}
else
{
......@@ -296,7 +311,6 @@ void reduce(hipStream_t stream,
});
shape reduce_slice{output_shape.type(), reduce_lens};
reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment