"test/ref_ops_test.cpp" did not exist on "cf984059287280de8c01b4af3f9fd441718d308c"
Commit dd26f1aa authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rnn_optimization

parents 4e3d06ab 4a3e493c
......@@ -10,6 +10,8 @@ add_library(migraphx
eliminate_allocation.cpp
eliminate_contiguous.cpp
eliminate_concat.cpp
eliminate_identity.cpp
eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp
env.cpp
......
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
......
......@@ -4,6 +4,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -50,7 +51,8 @@ void dead_code_elimination::apply(program& p) const
assert(p.has_instruction(leaf));
if(leaf->outputs().empty())
{
auto args = leaf->inputs();
std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end());
leaf->clear_arguments();
assert(bidistance(p, last, leaf) < 0);
assert(leaf != ins);
......
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......
......@@ -2,7 +2,8 @@
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
......
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(program& p) const
{
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
{
// Skip the first instruction, since we always process the previous
// instruction
if(ins == p.begin())
continue;
const auto i = std::prev(ins);
if(i->name() == "identity")
{
p.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end());
}
if(ins == last)
{
if(ins->name() == "identity")
{
const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1)
{
p.move_instruction(identity_input, i);
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
last = std::prev(last);
}
}
break;
}
}
p.remove_instructions(std::next(last), p.end());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_pad::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
const std::string& op_name = ins->name();
if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
continue;
auto input = ins->inputs().front();
if(input->name() != "pad")
continue;
if(op_name == "convolution")
update_op(op::convolution{}, input, ins, p);
else if(op_name == "im2col")
update_op(op::im2col{}, input, ins, p);
else if(op_name == "pooling")
update_op(op::pooling{}, input, ins, p);
}
}
template <class T>
void eliminate_pad::update_op(T,
const instruction_ref& input,
const instruction_ref& ins,
program& p) const
{
auto pad_op = any_cast<op::pad>(input->get_operator());
if(!pad_op.symmetric())
return;
std::vector<int64_t> pads = pad_op.pads;
std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]), static_cast<size_t>(pads[3])};
T op = any_cast<T>(ins->get_operator());
if(op.padding_mode != op::padding_mode_t::default_)
return;
op.padding = new_pads;
std::vector<instruction_ref> new_inputs{ins->inputs()};
new_inputs.front() = input->inputs().front();
p.replace_instruction(ins, op, new_inputs);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
......@@ -14,32 +17,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
if(ins->name() != "batch_norm_inference")
continue;
if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
return arg->name() == "@literal";
}))
// Get scale, bias, mean, variance from inputs
auto gamma = ins->inputs()[1]->eval();
auto bias = ins->inputs()[2]->eval();
auto mean = ins->inputs()[3]->eval();
auto variance = ins->inputs()[4]->eval();
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue;
auto conv_ins = ins->inputs()[0];
if(conv_ins->name() != "convolution")
continue;
if(conv_ins->inputs()[1]->name() != "@literal")
// Get convolution weights
auto weights = conv_ins->inputs()[1]->eval();
if(weights.empty())
continue;
// Get scale, bias, mean, variance from instruction_ref
const auto& gamma = ins->inputs()[1]->get_literal();
const auto& bias = ins->inputs()[2]->get_literal();
const auto& mean = ins->inputs()[3]->get_literal();
const auto& variance = ins->inputs()[4]->get_literal();
// Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
// Get convolution weights
const auto& weights = conv_ins->inputs()[1]->get_literal();
// Get convolution op
auto conv_op = conv_ins->get_operator();
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()};
argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2,
auto gamma2,
......@@ -51,11 +52,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
});
dfor(new_bias.get_shape().elements())([&](std::size_t c) {
new_bias2(c) =
bias2(c) - (gamma2(c) * mean2(c) / std::sqrt(variance2(c) + epsilon));
new_bias2[c] =
bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
});
});
// Replace convolution instruction with updated weights
......
......@@ -18,6 +18,11 @@ struct check_shapes
{
}
template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name())
{
}
check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
template <class Op>
......@@ -119,6 +124,13 @@ struct check_shapes
return *this;
}
const check_shapes& elements(std::size_t n) const
{
if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
MIGRAPHX_THROW(prefix() + "Wrong number of elements");
return *this;
}
template <class F>
bool same(F f) const
{
......
......@@ -9,7 +9,7 @@
#include <utility>
#include <migraphx/operation.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Remove identity instructions. Currently when used as the last pass, it will
* preserve the semantics of previous program state, therefore dead code elimination
* should not be used afterwards.
*/
struct eliminate_identity
{
std::string name() const { return "eliminate_identity"; }
void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
#include <string>
#include <vector>
#include <array>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Remove pads if they can be written as an
* attribute to another op (im2col, convolution, pooling)
*/
struct eliminate_pad
{
std::string name() const { return "eliminate_pad"; }
void apply(program& p) const;
template <class T>
void update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct not_computable
{
argument compute(const shape&, const std::vector<argument>&) const
{
MIGRAPHX_THROW("not computable");
}
};
struct undefined
{
std::string name() const { return "undefined"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(0);
return {};
}
argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ABS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ABS_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct abs : unary
{
std::string name() const { return "abs"; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOS_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOS_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct acos : unary
{
std::string name() const { return "acos"; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_ADD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct add : binary
{
std::string name() const { return "add"; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct as_shape
{
shape s;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"));
}
std::string name() const { return "as_shape"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
assert(inputs.front().elements() == s.elements());
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ASIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASIN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct asin : unary
{
std::string name() const { return "asin"; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ATAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATAN_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct atan : unary
{
std::string name() const { return "atan"; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment