Commit f21a7346 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into propogate-constant

parents 68858a5b 369cb3a5
...@@ -20,6 +20,7 @@ add_library(migraphx ...@@ -20,6 +20,7 @@ add_library(migraphx
program.cpp program.cpp
shape.cpp shape.cpp
schedule.cpp schedule.cpp
pass_manager.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
......
...@@ -63,7 +63,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -63,7 +63,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights}); auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape()}, l_bias); auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
p.replace_instruction(ins, op::add{}, {c, b}); p.replace_instruction(ins, op::add{}, {c, b});
} }
} }
......
...@@ -27,38 +27,36 @@ namespace op { ...@@ -27,38 +27,36 @@ namespace op {
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
} }
shape broadcast_shape;
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) { if(std::all_of(
return x == 1; broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
}))
{ {
if(axis != 0) if(axis != 0)
MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
else else
{ {
assert(broadcast_shape.lens().size() - axis >= input.lens().size()); assert(broadcast_lens.size() - axis >= input.lens().size());
if(!std::equal( if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis)) MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
MIGRAPHX_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -31,6 +31,8 @@ enum class rnn_direction ...@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct gather struct gather
{ {
int axis = 0; int axis = 0;
// Reflection hook for the gather operator: passes each configurable field to
// the visitor `f` together with a string label and returns the visitor's
// results bundled by pack(). Only `axis` is exposed.
// NOTE(review): presumably consumed by generic operator printing/comparison
// code -- confirm against the reflection utilities.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -27,6 +27,16 @@ struct gru ...@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f; float clip = 0.0f;
int linear_before_reset = 0; int linear_before_reset = 0;
// Reflection hook for the gru operator: passes each configurable field to the
// visitor `f` together with a string label and returns the visitor's results
// bundled by pack(). Exposed fields: hidden_size, actv_funcs (labelled
// "actv_func"), direction, clip and linear_before_reset.
// NOTE(review): presumably consumed by generic operator printing/comparison
// code -- confirm against the reflection utilities.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; } std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct logsoftmax struct logsoftmax
{ {
int axis = 1; int axis = 1;
// Reflection hook for the logsoftmax operator: passes each configurable field
// to the visitor `f` together with a string label and returns the visitor's
// results bundled by pack(). Only `axis` is exposed.
// NOTE(review): presumably consumed by generic operator printing/comparison
// code -- confirm against the reflection utilities.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; } std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct lstm ...@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f; float clip = 0.0f;
int input_forget = 0; int input_forget = 0;
// Reflection hook for the lstm operator: passes each configurable field to the
// visitor `f` together with a string label and returns the visitor's results
// bundled by pack().
//
// `clip` is included so that every data member of the operator is visible to
// the visitor, matching the gru and rnn operators; leaving it out would make
// two lstm operators that differ only in `clip` indistinguishable to any code
// that relies on the reflected fields.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
    return pack(f(self.hidden_size, "hidden_size"),
                f(self.actv_funcs, "actv_func"),
                f(self.direction, "direction"),
                f(self.clip, "clip"),
                f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; } std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct rnn ...@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
// Reflection hook for the rnn operator: passes each configurable field to the
// visitor `f` together with a string label and returns the visitor's results
// bundled by pack(). Exposed fields: hidden_size, actv_funcs (labelled
// "actv_func"), direction and clip.
// NOTE(review): presumably consumed by generic operator printing/comparison
// code -- confirm against the reflection utilities.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; } std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -18,7 +18,13 @@ namespace op { ...@@ -18,7 +18,13 @@ namespace op {
struct scalar struct scalar
{ {
shape scalar_bcast; std::vector<std::size_t> scalar_bcast_lens;
// Reflection hook for the scalar operator: passes the target dimensions
// (`scalar_bcast_lens`) to the visitor `f` under the label "scalar_bcst_dims"
// and returns the visitor's result bundled by pack().
// NOTE(review): presumably consumed by generic operator printing/comparison
// code -- confirm against the reflection utilities.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.scalar_bcast_lens, "scalar_bcst_dims"));
}
std::string name() const { return "scalar"; } std::string name() const { return "scalar"; }
...@@ -26,8 +32,8 @@ struct scalar ...@@ -26,8 +32,8 @@ struct scalar
{ {
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1); assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0); std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0);
return {t, scalar_bcast.lens(), strides}; return {t, scalar_bcast_lens, strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Applies every pass in `passes` to `prog`, in the order given.
///
/// @param prog   program that each pass is applied to in turn
/// @param passes passes to apply, first to last
/// @param trace  tracer that receives the pass names and the program after
///               each pass; defaults to a default-constructed tracer
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -141,8 +141,8 @@ struct onnx_parser ...@@ -141,8 +141,8 @@ struct onnx_parser
if(broadcasted != 0) if(broadcasted != 0)
{ {
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>(); uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
auto l = auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
} }
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
...@@ -306,7 +306,7 @@ struct onnx_parser ...@@ -306,7 +306,7 @@ struct onnx_parser
{ {
uint64_t axis = 1; uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]); auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]); auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
return prog.add_instruction(op::add{}, l1, l2); return prog.add_instruction(op::add{}, l1, l2);
} }
return prog.add_instruction(op, l0, args[1]); return prog.add_instruction(op, l0, args[1]);
...@@ -671,15 +671,15 @@ struct onnx_parser ...@@ -671,15 +671,15 @@ struct onnx_parser
auto&& bias_floats = attributes["bias"].floats(); auto&& bias_floats = attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end()); bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
} }
auto input_shape = args.front()->get_shape(); auto input_lens = args.front()->get_shape().lens();
auto scale_val = prog.add_literal(scale); auto scale_val = prog.add_literal(scale);
auto bias_vals = prog.add_literal( auto bias_vals = prog.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val); auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor); auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_shape}, bias_vals); auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
} }
......
#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Applies each pass in `passes` to `prog` in sequence.
///
/// For every pass the tracer receives the pass name before it runs and the
/// resulting program afterwards. In builds without NDEBUG the program is also
/// validated after each pass, and an exception naming the offending pass and
/// instruction is raised if validation reports an invalid instruction.
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
    for(const auto& current_pass : passes)
    {
        trace("Pass: ", current_pass.name());
        current_pass.apply(prog);
        trace(prog);

#ifndef NDEBUG
        // Debug-only check so the pass that broke the program is identified
        // immediately rather than at the end of compilation.
        trace("Validate ...");
        const auto first_invalid = prog.validate();
        if(first_invalid != prog.end())
        {
            const auto position = std::distance(prog.begin(), first_invalid);
            const std::string message = current_pass.name() +
                                        " pass produces invalid program at instruction " +
                                        std::to_string(position) + ": " + first_invalid->name();
            MIGRAPHX_THROW(message);
        }
        trace();
#endif
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/time.hpp> #include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
...@@ -291,23 +292,7 @@ void program::compile(const target& t, tracer trace) ...@@ -291,23 +292,7 @@ void program::compile(const target& t, tracer trace)
trace = tracer{std::cout}; trace = tracer{std::cout};
trace(*this); trace(*this);
trace(); trace();
for(auto&& p : t.get_passes(this->impl->ctx)) run_passes(*this, t.get_passes(this->impl->ctx), trace);
{
trace("Pass: ", p.name());
p.apply(*this);
trace(*this);
#ifndef NDEBUG
trace("Validate ...");
auto invalid = this->validate();
if(invalid != impl->instructions.end())
{
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
trace();
#endif
}
auto invalid = this->validate(); auto invalid = this->validate();
if(invalid != impl->instructions.end()) if(invalid != impl->instructions.end())
{ {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -213,7 +214,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -213,7 +214,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb); auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b); bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
} }
instruction_ref hidden_out = prog.end(); instruction_ref hidden_out = prog.end();
...@@ -520,25 +521,26 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -520,25 +521,26 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
instruction_ref brcst_bh{}; instruction_ref brcst_bh{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto broadcast_lens = sih->get_shape().lens();
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh); auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias); auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh); brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz); auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz); brcst_bz = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr); auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br); brcst_br = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh); auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh); brcst_bh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
} }
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
...@@ -945,8 +947,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -945,8 +947,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// initial cell state // initial cell state
auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic); auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
auto ic_shape = sic->get_shape(); auto ic_lens = sic->get_shape().lens();
// bias // bias
instruction_ref bi_brcst{}; instruction_ref bi_brcst{};
...@@ -955,26 +957,27 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -955,26 +957,27 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref bc_brcst{}; instruction_ref bc_brcst{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi); auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bi); bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho); auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bo); bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias); auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf); auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bf); bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias); auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias); auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc); auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bc); bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
} }
// peep hole // peep hole
...@@ -986,13 +989,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -986,13 +989,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{ {
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
} }
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
...@@ -1166,5 +1169,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1166,5 +1169,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
} }
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -110,6 +110,7 @@ struct tf_parser ...@@ -110,6 +110,7 @@ struct tf_parser
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
...@@ -117,6 +118,7 @@ struct tf_parser ...@@ -117,6 +118,7 @@ struct tf_parser
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack); add_mem_op("Pack", &tf_parser::parse_pack);
...@@ -124,6 +126,7 @@ struct tf_parser ...@@ -124,6 +126,7 @@ struct tf_parser
add_mem_op("Reshape", &tf_parser::parse_reshape); add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze); add_mem_op("Squeeze", &tf_parser::parse_squeeze);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
} }
template <class F> template <class F>
...@@ -235,7 +238,7 @@ struct tf_parser ...@@ -235,7 +238,7 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0); return prog.add_instruction(op::add{}, args[0], l0);
} }
...@@ -336,6 +339,32 @@ struct tf_parser ...@@ -336,6 +339,32 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights}); return prog.add_instruction(op, {args[0], weights});
} }
/// Handler registered for the "MatMul" node: emits a dot instruction over the
/// first two arguments, transposing either operand first when requested.
///
/// Recognised attributes (both optional, default false):
///   - "transpose_a": swap the two innermost dimensions of args[0]
///   - "transpose_b": swap the two innermost dimensions of args[1]
///
/// @param attributes node attributes
/// @param args       args[0] and args[1] are the two operands
/// @return the dot instruction added to the program
instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
    bool transa = false;
    bool transb = false;
    if(contains(attributes, "transpose_a"))
    {
        transa = attributes.at("transpose_a").b();
    }
    if(contains(attributes, "transpose_b"))
    {
        // Must read "transpose_b" here; reading "transpose_a" would make the
        // second operand follow the first operand's flag.
        transb = attributes.at("transpose_b").b();
    }

    // Adds a transpose that swaps the two innermost dimensions of `arg`. The
    // permutation is sized from the operand itself so operands of different
    // rank each get a permutation of the right length.
    auto swap_last_two_dims = [&](instruction_ref arg) {
        std::vector<int64_t> perm(arg->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // With fewer than two dimensions there is nothing to swap, and
        // perm.end() - 2 would step before begin().
        if(perm.size() >= 2)
            std::iter_swap(perm.end() - 1, perm.end() - 2);
        return prog.add_instruction(op::transpose{perm}, arg);
    };

    auto l1 = transa ? swap_last_two_dims(args[0]) : args[0];
    auto l2 = transb ? swap_last_two_dims(args[1]) : args[1];

    return prog.add_instruction(op::dot{}, l1, l2);
}
instruction_ref instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
...@@ -508,6 +537,46 @@ struct tf_parser ...@@ -508,6 +537,46 @@ struct tf_parser
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
/// Handler registered for the "StridedSlice" node: emits a slice over every
/// axis of args[0], followed by a squeeze of the axes selected by the
/// optional "shrink_axis_mask" attribute.
///
/// @param attributes node attributes; "shrink_axis_mask" is read when present
/// @param args       args[0] is the data, args[1] the start indices and
///                   args[2] the end indices (both evaluated as int32 values)
/// @return the squeeze instruction, or the slice instruction itself when no
///         axis is selected for squeezing
instruction_ref parse_stridedslice(const std::string&,
                                   const attribute_map& attributes,
                                   std::vector<instruction_ref> args)
{
    op::slice op;
    auto starts     = args[1]->eval().get<int32_t>().to_vector();
    auto ends       = args[2]->eval().get<int32_t>().to_vector();
    size_t num_axes = args[0]->get_shape().lens().size();
    // NOTE(review): inputs with four or more axes go through reorder_data /
    // parse_axes, presumably to convert between layouts -- confirm against
    // those helpers.
    if(num_axes >= 4)
    {
        reorder_data(starts);
        reorder_data(ends);
    }

    op.starts = std::vector<int64_t>(starts.begin(), starts.end());
    op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
    op.axes   = std::vector<int64_t>(num_axes);
    std::iota(op.axes.begin(), op.axes.end(), 0);

    uint32_t shrink_axis_mask = 0;
    if(contains(attributes, "shrink_axis_mask"))
        shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());

    // The mask holds exactly 32 bits; shifting a uint32_t by 32 or more is
    // undefined, so axes beyond that can never be selected.
    constexpr size_t mask_bits = 32;
    std::vector<int64_t> squeeze_axes;
    for(size_t i = 0; i < num_axes && i < mask_bits; i++)
    {
        // the LSB corresponds to axis 0 when determining which axes to squeeze
        if(((shrink_axis_mask >> i) & 1u) == 1u)
            squeeze_axes.push_back(static_cast<int64_t>(i));
    }
    if(num_axes >= 4)
    {
        squeeze_axes = parse_axes(squeeze_axes);
    }

    auto l0 = prog.add_instruction(op, args[0]);
    // No axis selected: return the slice as is.
    // NOTE(review): op::squeeze with an empty axis list is assumed not to be
    // an identity (it appears to drop every unit dimension) -- confirm against
    // the squeeze operator.
    if(squeeze_axes.empty())
        return l0;
    return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
}
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
{ {
nodes = get_nodes(graph, input_nodes); nodes = get_nodes(graph, input_nodes);
......
...@@ -60,7 +60,7 @@ TEST_CASE(after_literal_broadcast) ...@@ -60,7 +60,7 @@ TEST_CASE(after_literal_broadcast)
auto l2 = p.add_literal(get_2()); auto l2 = p.add_literal(get_2());
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
...@@ -91,7 +91,7 @@ TEST_CASE(after_param_broadcast) ...@@ -91,7 +91,7 @@ TEST_CASE(after_param_broadcast)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape()}, l2); auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2);
p.add_instruction(pass_op{}, b); p.add_instruction(pass_op{}, b);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().broadcasted()); EXPECT(p.get_shape().broadcasted());
......
...@@ -351,7 +351,7 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -351,7 +351,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
float beta = 0.41; float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2); auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta); auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta); auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape().lens()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3); auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha); p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
......
...@@ -651,7 +651,7 @@ TEST_CASE(broadcast_test) ...@@ -651,7 +651,7 @@ TEST_CASE(broadcast_test)
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2); p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
auto output = result.get<int32_t>(); auto output = result.get<int32_t>();
...@@ -671,7 +671,7 @@ TEST_CASE(add_broadcast_test) ...@@ -671,7 +671,7 @@ TEST_CASE(add_broadcast_test)
uint64_t axis = 0; uint64_t axis = 0;
auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); auto l1 = p.add_literal(migraphx::literal{a_shape, a_data});
auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); auto l2 = p.add_literal(migraphx::literal{b_shape, b_data});
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2); auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l1, l3); p.add_instruction(migraphx::op::add{}, l1, l3);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -809,11 +809,11 @@ TEST_CASE(imagescaler_test) ...@@ -809,11 +809,11 @@ TEST_CASE(imagescaler_test)
0.35, 0.35,
0.45}}); 0.45}});
auto scale_val = p.add_literal(2.f); auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -369,7 +369,7 @@ struct test_scale : verify_program<test_scale> ...@@ -369,7 +369,7 @@ struct test_scale : verify_program<test_scale>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraphx::shape::float_type); auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraphx::op::scalar{s}, y); auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y);
p.add_instruction(migraphx::op::mul{}, x, scale); p.add_instruction(migraphx::op::mul{}, x, scale);
return p; return p;
} }
...@@ -415,7 +415,7 @@ struct test_triadd2 : verify_program<test_triadd2> ...@@ -415,7 +415,7 @@ struct test_triadd2 : verify_program<test_triadd2>
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z); auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto sum = p.add_instruction(migraphx::op::add{}, x, y); auto sum = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::add{}, sum, zb); p.add_instruction(migraphx::op::add{}, sum, zb);
return p; return p;
...@@ -430,7 +430,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast> ...@@ -430,7 +430,7 @@ struct test_add_broadcast : verify_program<test_add_broadcast>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -444,7 +444,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2> ...@@ -444,7 +444,7 @@ struct test_add_broadcast2 : verify_program<test_add_broadcast2>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -458,7 +458,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3> ...@@ -458,7 +458,7 @@ struct test_add_broadcast3 : verify_program<test_add_broadcast3>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -472,7 +472,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4> ...@@ -472,7 +472,7 @@ struct test_add_broadcast4 : verify_program<test_add_broadcast4>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -486,7 +486,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5> ...@@ -486,7 +486,7 @@ struct test_add_broadcast5 : verify_program<test_add_broadcast5>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}});
auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y);
p.add_instruction(migraphx::op::add{}, x, by); p.add_instruction(migraphx::op::add{}, x, by);
return p; return p;
} }
...@@ -501,7 +501,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast> ...@@ -501,7 +501,7 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}}); auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape()}, y); auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y);
auto sum = p.add_instruction(migraphx::op::add{}, x, by); auto sum = p.add_instruction(migraphx::op::add{}, x, by);
p.add_instruction(migraphx::op::add{}, sum, z); p.add_instruction(migraphx::op::add{}, sum, z);
return p; return p;
...@@ -533,7 +533,7 @@ struct test_sub2 : verify_program<test_sub2> ...@@ -533,7 +533,7 @@ struct test_sub2 : verify_program<test_sub2>
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z); auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::sub{}, x, y); auto diff = p.add_instruction(migraphx::op::sub{}, x, y);
p.add_instruction(migraphx::op::sub{}, diff, zb); p.add_instruction(migraphx::op::sub{}, diff, zb);
return p; return p;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment