Commit 4d2516e1 authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into capture_op_more_changes
parents 8efe1eeb 0628e570
@@ -14,6 +14,7 @@ add_library(migraphx
     eliminate_pad.cpp
     fwd_conv_batchnorm_rewrite.cpp
     rewrite_rnn.cpp
+    rewrite_pooling.cpp
     env.cpp
     generate.cpp
     instruction.cpp
......
@@ -59,7 +59,9 @@ struct reshape
         shape s{inputs.front().type(), rdims};
         if(s.elements() != inputs.front().elements())
-            MIGRAPHX_THROW("Wrong number of elements for reshape");
+            MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " +
+                           std::to_string(s.elements()) + " elements whereas the input has " +
+                           std::to_string(inputs.front().elements()));
         return s;
     }

     argument compute(shape output_shape, std::vector<argument> args) const
......
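
The new diagnostic reports both element counts, which makes shape mismatches much faster to track down. A minimal standalone sketch of the same check (check_reshape is a hypothetical helper written for illustration, not MIGraphX code):

    #include <functional>
    #include <numeric>
    #include <stdexcept>
    #include <string>
    #include <vector>

    void check_reshape(const std::vector<std::size_t>& in, const std::vector<std::size_t>& out)
    {
        // Total element count is the product of the dimensions.
        auto count = [](const std::vector<std::size_t>& dims) {
            return std::accumulate(dims.begin(), dims.end(), std::size_t{1}, std::multiplies<>{});
        };
        if(count(out) != count(in))
            throw std::runtime_error("Wrong number of elements for reshape: reshape has " +
                                     std::to_string(count(out)) + " elements whereas the input has " +
                                     std::to_string(count(in)));
    }

    // check_reshape({2, 3}, {4}) now reports both counts (4 vs 6) instead of
    // the bare "Wrong number of elements for reshape".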
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP

#include <string>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Rewrite pooling to reduce_mean
 */
struct rewrite_pooling
{
    std::string name() const { return "rewrite_pooling"; }
    void apply(program& prog) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
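
The new header follows the usual MIGraphX pass shape: a name() and an apply(program&). A minimal sketch of invoking such a pass directly (run_pass is a hypothetical wrapper; in this commit the pass is actually scheduled through target::get_passes further below):

    #include <migraphx/program.hpp>
    #include <migraphx/rewrite_pooling.hpp>

    // Hypothetical direct invocation; the GPU target registers the pass instead.
    void run_pass(migraphx::program& prog) { migraphx::rewrite_pooling{}.apply(prog); }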
@@ -132,7 +132,11 @@ struct tensor_view
         return m_data + this->size();
     }

-    std::vector<T> to_vector() const { return std::vector<T>(this->begin(), this->end()); }
+    template <class U = T>
+    std::vector<U> to_vector() const
+    {
+        return std::vector<U>(this->begin(), this->end());
+    }

     friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
     {
......
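
Templating to_vector on the destination element type, defaulted to T, keeps every existing to_vector() call compiling unchanged while allowing widening copies such as int32 to int64, which the tf_parser change below relies on. A minimal sketch with a stand-in view type (the view struct here is illustrative, not the real tensor_view):

    #include <cstdint>
    #include <vector>

    template <class T>
    struct view
    {
        const T* first;
        const T* last;

        // Defaulted template parameter: to_vector() behaves exactly as before,
        // to_vector<U>() converts element-wise to U.
        template <class U = T>
        std::vector<U> to_vector() const
        {
            return std::vector<U>(first, last);
        }
    };

    int main()
    {
        std::int32_t raw[] = {2, 3};
        view<std::int32_t> v{raw, raw + 2};
        std::vector<std::int32_t> same  = v.to_vector();               // default U = T, as before
        std::vector<std::int64_t> wider = v.to_vector<std::int64_t>(); // widening copy
        return same.size() == wider.size() ? 0 : 1;
    }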
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/program.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_pooling::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() != "pooling")
            continue;
        if(ins->get_shape().lens().size() != 4)
            continue;
        if(ins->inputs().empty())
            continue;
        auto&& s  = ins->inputs().front()->get_shape();
        auto&& op = any_cast<op::pooling>(ins->get_operator());
        if(op.mode != "average")
            continue;
        // Only a zero-padding, unit-stride window covering the whole spatial
        // extent (i.e. global average pooling) can be rewritten, so skip the
        // instruction if any of these conditions fails.
        if(op.padding[0] != 0 or op.padding[1] != 0)
            continue;
        if(op.stride[0] != 1 or op.stride[1] != 1)
            continue;
        if(s.lens()[2] != op.lengths[0] or s.lens()[3] != op.lengths[1])
            continue;
        std::int64_t n = s.lens()[0];
        std::int64_t c = s.lens()[1];
        auto reshape =
            prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
        auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
        prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
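
The guards above restrict the rewrite to global average pooling: zero padding, unit stride, and a window covering the full spatial extent. In that case an NCHW pooling is exactly a per-(n, c) mean, which is what the reshape{n*c, -1} / reduce_mean{1} / reshape{n, c, 1, 1} sequence computes. A plain C++ sketch of the equivalence (hypothetical, independent of MIGraphX):

    #include <cassert>
    #include <vector>

    int main()
    {
        const std::size_t n = 2, c = 3, h = 4, w = 4;
        std::vector<float> in(n * c * h * w);
        for(std::size_t i = 0; i < in.size(); i++)
            in[i] = static_cast<float>(i % 7);

        // Global average pooling: one mean per (n, c) slice.
        std::vector<float> pooled(n * c, 0.0f);
        for(std::size_t nc = 0; nc < n * c; nc++)
        {
            for(std::size_t hw = 0; hw < h * w; hw++)
                pooled[nc] += in[nc * h * w + hw];   // the reshape{n*c, -1} view
            pooled[nc] /= static_cast<float>(h * w); // reduce_mean over axis 1
        }

        // pooled holds the {n, c, 1, 1} result the pass produces.
        assert(pooled.size() == n * c);
        return 0;
    }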
...@@ -16,6 +16,12 @@ struct hip_array ...@@ -16,6 +16,12 @@ struct hip_array
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; } MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; } MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR T& back() { return d[N - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& back() const { return d[N - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; } MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; } MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
......
@@ -209,28 +209,15 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
 }

 template <class Op, class T, class Input, class Output>
-void reduce(hipStream_t stream,
-            const argument& result,
-            const argument& arg,
-            Op op,
-            T init,
-            Input read_input,
-            Output read_output)
+void reduce_multi_impl(hipStream_t stream,
+                       const argument& result,
+                       const argument& arg,
+                       Op op,
+                       T init,
+                       Input read_input,
+                       Output read_output,
+                       const shape& reduce_slice)
 {
-    auto&& output_shape = result.get_shape();
-    auto&& input_shape  = arg.get_shape();
-    std::vector<std::size_t> reduce_lens;
-    std::transform(output_shape.lens().begin(),
-                   output_shape.lens().end(),
-                   input_shape.lens().begin(),
-                   std::back_inserter(reduce_lens),
-                   [](auto x, auto y) -> std::size_t {
-                       if(x == y)
-                           return 1;
-                       else
-                           return y;
-                   });
-    shape reduce_slice{output_shape.type(), reduce_lens};
     hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
         auto nelements = result.get_shape().elements();
         auto relements = reduce_slice.elements();
@@ -250,6 +237,83 @@ void reduce(hipStream_t stream,
     });
 }

+template <class Op, class T, class Input, class Output>
+void reduce_standard_impl(hipStream_t stream,
+                          const argument& result,
+                          const argument& arg,
+                          Op op,
+                          T init,
+                          Input read_input,
+                          Output read_output,
+                          std::size_t relements,
+                          std::size_t stride)
+{
+    hip_visit_all(result, arg)([&](auto output, auto input) {
+        auto nelements = result.get_shape().elements();
+
+        const std::size_t max_block_size = 256;
+        const std::size_t block_size     = compute_block_size(relements, max_block_size);
+        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
+            const auto out_idx  = i / block_size;
+            const auto base_idx = out_idx * stride;
+            auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
+                return read_input(input.data()[base_idx + j]);
+            });
+            if(idx.local == 0)
+                output.data()[out_idx] = read_output(r);
+        });
+    });
+}
+
+template <class Op, class T, class Input, class Output>
+void reduce(hipStream_t stream,
+            const argument& result,
+            const argument& arg,
+            Op op,
+            T init,
+            Input read_input,
+            Output read_output)
+{
+    auto&& output_shape = result.get_shape();
+    auto&& input_shape  = arg.get_shape();
+    // Fast path: both tensors are standard (contiguous) and only the
+    // innermost axis is reduced.
+    if(input_shape.standard() and output_shape.standard() and
+       output_shape.lens().back() != input_shape.lens().back() and
+       std::equal(output_shape.lens().begin(),
+                  std::prev(output_shape.lens().end()),
+                  input_shape.lens().begin()))
+    {
+        // Consecutive output elements then correspond to contiguous blocks
+        // of lens().back() input elements.
+        std::size_t stride = input_shape.lens().back();
+        reduce_standard_impl(stream,
+                             result,
+                             arg,
+                             op,
+                             init,
+                             read_input,
+                             read_output,
+                             input_shape.lens().back(),
+                             stride);
+    }
+    else
+    {
+        std::vector<std::size_t> reduce_lens;
+        std::transform(output_shape.lens().begin(),
+                       output_shape.lens().end(),
+                       input_shape.lens().begin(),
+                       std::back_inserter(reduce_lens),
+                       [](auto x, auto y) -> std::size_t {
+                           if(x == y)
+                               return 1;
+                           else
+                               return y;
+                       });
+        shape reduce_slice{output_shape.type(), reduce_lens};
+        reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
+    }
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
......
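
The refactor splits reduce into two paths: reduce_multi_impl keeps the old shape-driven traversal, while reduce_standard_impl handles the common case of a contiguous tensor reduced over its innermost axis, where each output element owns one contiguous input block starting at out_idx * stride. A CPU sketch of that indexing (reduce_last_axis_sum is hypothetical; a sum stands in for op/read_input/read_output, and the block_reduce thread-block machinery is omitted):

    #include <vector>

    std::vector<float> reduce_last_axis_sum(const std::vector<float>& input,
                                            std::size_t nelements, // number of output elements
                                            std::size_t relements) // innermost (reduced) axis length
    {
        std::vector<float> output(nelements, 0.0f);
        const std::size_t stride = relements; // contiguous: one block per output element
        for(std::size_t out_idx = 0; out_idx < nelements; out_idx++)
        {
            const std::size_t base_idx = out_idx * stride;
            for(std::size_t j = 0; j < relements; j++)
                output[out_idx] += input[base_idx + j]; // the GPU version does this via block_reduce
        }
        return output;
    }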
@@ -134,8 +134,6 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
     auto conv = any_cast<miopen_convolution>(ins->get_operator());
     if(conv.op.group > 1)
         return false;
-    if(conv.op.padding_mode != op::padding_mode_t::default_)
-        return false;
     if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
         return false;
     auto op = conv.op;
......
@@ -16,6 +16,7 @@
 #include <migraphx/common_subexpression_elimination.hpp>
 #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
 #include <migraphx/rewrite_rnn.hpp>
+#include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/eliminate_concat.hpp>
 #include <migraphx/eliminate_identity.hpp>
 #include <migraphx/gpu/concat_gpu_opt.hpp>
@@ -46,6 +47,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         fwd_conv_batchnorm_rewrite{},
         dead_code_elimination{},
         rewrite_rnn{},
+        rewrite_pooling{},
         dead_code_elimination{},
         // common_subexpression_elimination{},
         // dead_code_elimination{},
......
@@ -574,23 +574,18 @@ struct tf_parser
     parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         bool keep_dims = attributes.at("keep_dims").b();
-        std::vector<int32_t> hw_axes{2, 3};
-        // check if conditions for GlobalAvgPool are met
-        auto lens = args[0]->get_shape().lens();
-        auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
-        if(axes == hw_axes and lens.size() == 4)
+        auto lens      = args[0]->get_shape().lens();
+        auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector<int64_t>(), lens.size());
+        if(keep_dims)
         {
-            op::pooling op{"average"};
-            op.lengths[0] = lens[2];
-            op.lengths[1] = lens[3];
-            auto l0 = prog.add_instruction(op, args.front());
-            if(keep_dims)
-                return l0;
-            return prog.add_instruction(
-                op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
+            return prog.add_instruction(op::reduce_mean{axes}, args[0]);
+        }
+        else
+        {
+            auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]);
+            return prog.add_instruction(op::squeeze{axes}, ins);
         }
-        MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
     }

     instruction_ref parse_pack(const std::string&,
......
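
Mean now lowers directly to reduce_mean over the parsed axes, so it is no longer limited to the GlobalAvgPool pattern: with keep_dims the reduced axes stay as size-1 dimensions, otherwise a squeeze{axes} drops them. A sketch of the resulting output shapes (mean_output_lens is a hypothetical helper):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    std::vector<std::size_t> mean_output_lens(const std::vector<std::size_t>& lens,
                                              const std::vector<std::int64_t>& axes,
                                              bool keep_dims)
    {
        std::vector<std::size_t> out;
        for(std::size_t i = 0; i < lens.size(); i++)
        {
            bool reduced = std::find(axes.begin(), axes.end(),
                                     static_cast<std::int64_t>(i)) != axes.end();
            if(not reduced)
                out.push_back(lens[i]); // untouched axis
            else if(keep_dims)
                out.push_back(1);       // reduce_mean keeps the axis as size 1
            // else: squeeze{axes} drops the axis entirely
        }
        return out;
    }

    // mean_output_lens({1, 3, 16, 16}, {2, 3}, true)  -> {1, 3, 1, 1}
    // mean_output_lens({1, 3, 16, 16}, {2, 3}, false) -> {1, 3}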
@@ -257,8 +257,7 @@ TEST_CASE(mean_test)
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_literal(l);
     p.add_literal(l);
-    migraphx::op::pooling op;
-    op.lengths = {16, 16};
+    migraphx::op::reduce_mean op{{2, 3}};
     p.add_instruction(op, l0);
     auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
@@ -272,9 +271,8 @@ TEST_CASE(mean_test_nhwc)
     migraphx::program p;
     migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
-    migraphx::op::pooling op;
-    op.lengths = {16, 16};
+    migraphx::op::reduce_mean op{{2, 3}};
     auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
     auto prog = optimize_tf("mean_test_nhwc.pb", true);
......