Unverified commit 8f08607e authored by mvermeulen, committed by GitHub

Merge branch 'develop' into rm_identity

parents e2cf822d 15ba8a36
@@ -3,6 +3,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/operators.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/ranges.hpp>
 #include <migraphx/dfor.hpp>
 
 namespace migraphx {
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
     {
         if(ins->name() != "batch_norm_inference")
             continue;
-        if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
-               return arg->name() == "@literal";
-           }))
+        // Get scale, bias, mean, variance from inputs
+        auto gamma    = ins->inputs()[1]->eval();
+        auto bias     = ins->inputs()[2]->eval();
+        auto mean     = ins->inputs()[3]->eval();
+        auto variance = ins->inputs()[4]->eval();
+        if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
             continue;
         auto conv_ins = ins->inputs()[0];
         if(conv_ins->name() != "convolution")
             continue;
-        if(conv_ins->inputs()[1]->name() != "@literal")
+        // Get convolution weights
+        auto weights = conv_ins->inputs()[1]->eval();
+        if(weights.empty())
             continue;
-        // Get scale, bias, mean, variance from instruction_ref
-        const auto& gamma    = ins->inputs()[1]->get_literal();
-        const auto& bias     = ins->inputs()[2]->get_literal();
-        const auto& mean     = ins->inputs()[3]->get_literal();
-        const auto& variance = ins->inputs()[4]->get_literal();
         // Get epsilon
         auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
         auto epsilon = bn_op.epsilon;
-        // Get convolution weights
-        const auto& weights = conv_ins->inputs()[1]->get_literal();
         // Get convolution op
         auto conv_op      = conv_ins->get_operator();
         auto weights_lens = weights.get_shape().lens();
         auto conv_lens    = conv_ins->get_shape().lens();
         argument new_weights{weights.get_shape()};
-        argument new_bias{bias.get_shape()};
+        argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
         visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
             [&](auto weights2,
                 auto gamma2,
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
                 dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
                     [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
                         new_weights2(k, c, h, w) =
-                            gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
+                            gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
                     });
                 dfor(new_bias.get_shape().elements())([&](std::size_t c) {
-                    new_bias2(c) =
-                        bias2(c) - (gamma2(c) * mean2(c) / std::sqrt(variance2(c) + epsilon));
+                    new_bias2[c] =
+                        bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
                 });
             });
         // Replace convolution instruction with updated weights
...
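Note: the rewrite above implements the standard fold of an inference-mode batch norm into the preceding convolution. With per-channel parameters \(\gamma, \beta, \mu, \sigma^2\), the batch norm computes

\[ y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \qquad x = \operatorname{conv}(I, W), \]

and since convolution is linear in its weights, this equals a single convolution with rescaled weights and a per-channel bias:

\[ W'_k = \frac{\gamma_k}{\sqrt{\sigma_k^2 + \epsilon}}\, W_k, \qquad b'_k = \beta_k - \frac{\gamma_k \mu_k}{\sqrt{\sigma_k^2 + \epsilon}}, \]

which is exactly what the two dfor loops write into new_weights and new_bias. Building new_bias from the flattened {type, {elements}} shape means a bias literal that arrives reshaped to {1, C, 1, 1} still produces a 1-D bias vector.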
@@ -18,6 +18,11 @@ struct check_shapes
     {
     }
 
+    template <class Op>
+    check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name())
+    {
+    }
+
     check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
 
     template <class Op>
@@ -119,6 +124,13 @@ struct check_shapes
         return *this;
     }
 
+    const check_shapes& elements(std::size_t n) const
+    {
+        if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
+            MIGRAPHX_THROW(prefix() + "Wrong number of elements");
+        return *this;
+    }
+
     template <class F>
     bool same(F f) const
     {
...
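Note: elements(n) only constrains the total element count, not the rank. A minimal sketch with illustrative shapes (not from this commit), assuming the usual check_shapes and shape headers:

    // Both shapes hold exactly 4 elements, so elements(4) passes even
    // though their ranks differ; elements(5) would throw
    // "Wrong number of elements".
    migraphx::shape flat{migraphx::shape::float_type, {4}};
    migraphx::shape nchw{migraphx::shape::float_type, {1, 4, 1, 1}};
    std::vector<migraphx::shape> v{flat, nchw};
    migraphx::check_shapes{v}.elements(4);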
@@ -56,6 +56,9 @@ struct batch_norm_inference
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(5);
+        check_shapes{inputs.data(), inputs.data() + 1, *this}.only_dims(4);
+        check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape().elements(
+            inputs.front().lens()[1]);
         return inputs.front();
     }
 };
...
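Note: these two checks require the data input to be 4-D and the four per-channel inputs (scale, bias, mean, variance) to share one shape whose element count equals the channel dimension. For a data input of {1, 3, 8, 8}, per-channel inputs of shape {3} or {1, 3, 1, 1} both validate, since only the element count is constrained; this is what lets the reshaped literals in the literal_reshape test below pass shape checking.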
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::all_of(c.begin(), c.end(), p);
 }
 
+template <class C, class Predicate>
+bool any_of(const C& c, const Predicate& p)
+{
+    return std::any_of(c.begin(), c.end(), p);
+}
+
+template <class T, class Predicate>
+bool any_of(const std::initializer_list<T>& c, const Predicate& p)
+{
+    return std::any_of(c.begin(), c.end(), p);
+}
+
+template <class C, class Predicate>
+bool none_of(const C& c, const Predicate& p)
+{
+    return std::none_of(c.begin(), c.end(), p);
+}
+
+template <class T, class Predicate>
+bool none_of(const std::initializer_list<T>& c, const Predicate& p)
+{
+    return std::none_of(c.begin(), c.end(), p);
+}
+
 template <class Range, class Iterator>
 void copy(Range&& r, Iterator it)
 {
...
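Note: the initializer_list overloads exist because a braced list has no type of its own, so the generic const C& parameter cannot be deduced from it; they are what permit the braced-list call in fwd_conv_batchnorm_rewrite above. A minimal self-contained sketch:

    #include <migraphx/ranges.hpp>

    int main()
    {
        int a = 0, b = 2, c = 0;
        // Picks the initializer_list overload: the template parameter C
        // of the generic overload cannot be deduced from {a, b, c}.
        bool r = migraphx::any_of({a, b, c}, [](int x) { return x != 0; });
        return r ? 0 : 1;
    }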
@@ -17,6 +17,7 @@ struct program;
 struct schedule
 {
     schedule_model model{};
+    bool enable = true;
     std::string name() const { return "schedule"; }
     void apply(program& p) const;
 };
...
@@ -341,6 +341,8 @@ struct stream_info
 void schedule::apply(program& p) const
 {
+    if(not enable)
+        return;
     stream_info si;
     auto last = std::prev(p.end());
     si.accumulate_weights(last, model);
...
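Note: since schedule remains an aggregate, the flag is set directly at construction. A minimal sketch (the model value and program p are placeholders):

    // The second initializer sets `enable`. With enable == false,
    // apply() returns before any stream analysis, leaving the
    // program untouched.
    migraphx::schedule s{migraphx::schedule_model{}, false};
    s.apply(p); // no-op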
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
                 par_dfor(num_batch, num_channels, image_height, image_width)(
                     [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
-                        assert((variance(c) + epsilon) > 0);
-                        result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
-                                                 std::sqrt(variance(c) + epsilon) +
-                                             bias(c);
+                        assert((variance[c] + epsilon) > 0);
+                        result(n, c, h, w) = gamma[c] * (buffer(n, c, h, w) - mean[c]) /
+                                                 std::sqrt(variance[c] + epsilon) +
+                                             bias[c];
                     });
             });
     }
...
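Note: the switch from multi-index calls like variance(c) to flat [c] indexing here (and in the rewrite above) appears to treat the per-channel tensors as one-dimensional buffers, so they index correctly whether the inputs arrive as {C} or as a reshaped {1, C, 1, 1}, consistent with the relaxed elements() check in compute_shape.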
@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
     auto conv = any_cast<miopen_convolution>(ins->get_operator());
     if(conv.op.group > 1)
         return false;
+    if(conv.op.padding_mode != op::padding_mode_t::default_)
+        return false;
     if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
         return false;
     auto op = conv.op;
@@ -251,6 +253,12 @@ struct miopen_conv_bias
     fusion::op_t conv;
     fusion::op_t bias;
 
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return op::convolution::reflect(self.op, f);
+    }
+
     miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b)
         : op(c), f(input)
     {
@@ -288,6 +296,12 @@ struct miopen_conv_bias_relu
     fusion::op_t bias;
     fusion::op_t relu;
 
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return op::convolution::reflect(self.op, f);
+    }
+
     miopen_conv_bias_relu(op::convolution c,
                           const shape& input,
                           const shape& weights,
...
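Note: these reflect functions forward the wrapped op::convolution's fields, so generic utilities that walk reflected fields (printing, comparison, serialization) can see into the fused operators. A self-contained sketch of the idiom itself, not MIGraphX's exact machinery:

    #include <iostream>
    #include <string>

    struct conv_like
    {
        int padding = 0;
        int stride  = 1;

        // Each field is handed to the visitor together with its name.
        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            f(self.padding, "padding");
            return f(self.stride, "stride");
        }
    };

    int main()
    {
        conv_like op;
        conv_like::reflect(op, [](auto& field, const std::string& name) {
            std::cout << name << " = " << field << "\n";
            return field;
        });
    }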
@@ -26,6 +26,8 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_SCHEDULE_PASS)
+
 std::vector<pass> target::get_passes(migraphx::context& gctx) const
 {
     auto& ctx = any_cast<context>(gctx);
@@ -55,7 +57,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         fuse_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
-        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}},
+        schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, enabled(MIGRAPHX_ENABLE_SCHEDULE_PASS{})},
         memory_coloring{"hip::allocate"},
         dead_code_elimination{},
         eliminate_workspace{},
...
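Note: combined with the enable flag above, this makes the GPU scheduler pass opt-in. Assuming enabled(...) follows MIGraphX's usual environment-variable convention of returning true only when the variable is set, the pass now runs only when MIGRAPHX_ENABLE_SCHEDULE_PASS is set in the environment.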
@@ -3,9 +3,13 @@
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/operators.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/ranges.hpp>
 #include <test.hpp>
 #include <migraphx/verify.hpp>
 
+bool is_batch_norm(migraphx::instruction& ins) { return ins.name() == "batch_norm_inference"; }
+
 TEST_CASE(fwd_conv_batchnorm_rewrite_test)
 {
     std::vector<float> xdata = {
@@ -65,4 +69,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
     EXPECT(migraphx::verify_range(results_vector1, results_vector2));
 }
 
+TEST_CASE(non_literal)
+{
+    migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
+    migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
+    migraphx::shape vars{migraphx::shape::float_type, {4}};
+    auto create_program = [&]() {
+        migraphx::program p;
+        auto x        = p.add_parameter("x", xs);
+        auto w        = p.add_parameter("w", ws);
+        auto conv     = p.add_instruction(migraphx::op::convolution{}, x, w);
+        auto scale    = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
+        auto bias     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
+        auto mean     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
+        auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
+        p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
+        return p;
+    };
+
+    migraphx::program p1 = create_program();
+    migraphx::program p2 = create_program();
+
+    migraphx::fwd_conv_batchnorm_rewrite opt;
+    opt.apply(p2);
+    EXPECT(any_of(p1, &is_batch_norm));
+    EXPECT(any_of(p2, &is_batch_norm));
+}
+
+TEST_CASE(as_literal)
+{
+    migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
+    migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
+    migraphx::shape vars{migraphx::shape::float_type, {4}};
+    auto create_program = [&]() {
+        migraphx::program p;
+        auto x        = p.add_literal(migraphx::generate_literal(xs, 1));
+        auto w        = p.add_literal(migraphx::generate_literal(ws, 1));
+        auto conv     = p.add_instruction(migraphx::op::convolution{}, x, w);
+        auto scale    = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
+        auto bias     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
+        auto mean     = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
+        auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
+        p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
+        return p;
+    };
+
+    migraphx::program p1 = create_program();
+    migraphx::program p2 = create_program();
+
+    migraphx::fwd_conv_batchnorm_rewrite opt;
+    opt.apply(p2);
+    EXPECT(any_of(p1, &is_batch_norm));
+    EXPECT(none_of(p2, &is_batch_norm));
+
+    p1.compile(migraphx::cpu::target{});
+    p2.compile(migraphx::cpu::target{});
+
+    auto result1 = p1.eval({});
+    auto result2 = p2.eval({});
+    visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
+}
+
+TEST_CASE(literal_reshape)
+{
+    migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
+    migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
+    migraphx::shape vars{migraphx::shape::float_type, {4}};
+    auto create_program = [&]() {
+        migraphx::program p;
+        auto reshape = [&](auto ins) {
+            return p.add_instruction(migraphx::op::reshape{{1, 4, 1, 1}}, ins);
+        };
+        auto x        = p.add_literal(migraphx::generate_literal(xs, 1));
+        auto w        = p.add_literal(migraphx::generate_literal(ws, 1));
+        auto conv     = p.add_instruction(migraphx::op::convolution{}, x, w);
+        auto scale    = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))));
+        auto bias     = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))));
+        auto mean     = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))));
+        auto variance = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))));
+        p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
+        return p;
+    };
+
+    migraphx::program p1 = create_program();
+    migraphx::program p2 = create_program();
+
+    migraphx::fwd_conv_batchnorm_rewrite opt;
+    opt.apply(p2);
+    EXPECT(any_of(p1, &is_batch_norm));
+    EXPECT(none_of(p2, &is_batch_norm));
+
+    p1.compile(migraphx::cpu::target{});
+    p2.compile(migraphx::cpu::target{});
+
+    auto result1 = p1.eval({});
+    auto result2 = p2.eval({});
+    visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
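Note: the three new cases cover the rewrite's requirements. non_literal feeds the convolution a parameter for its weights, so weights->eval() is empty and the batch norm must survive in both programs. as_literal uses literals throughout, so the batch norm is folded away and the two programs must still agree numerically. literal_reshape additionally reshapes the per-channel literals to {1, 4, 1, 1}; because the rewrite now evaluates inputs with eval() rather than requiring direct @literal arguments, the fold still applies, and the elements() check accepts the 4-D per-channel shape. migraphx::abs keeps the generated variance non-negative so sqrt(variance + epsilon) stays well-defined.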