Commit 087c205e authored by Paul's avatar Paul
Browse files

Merge from develop

parents a3a9e469 e15b8333
......@@ -15,55 +15,59 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
bool is_pytorch = false;
std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs;
onnx_parser()
{
add_generic_op("MatMul", op::dot{});
add_generic_op("Relu", op::relu{});
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{});
add_generic_op("Exp", op::exp{});
add_generic_op("Log", op::log{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
add_generic_op("Identity", op::identity{});
add_generic_op("Sin", op::sin{});
add_generic_op("Cos", op::cos{});
add_generic_op("Tan", op::tan{});
add_generic_op("Sinh", op::sinh{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{});
add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_broadcastable_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{});
add_broadcastable_binary_op("Mul", op::mul{});
add_broadcastable_binary_op("Sub", op::sub{});
add_broadcastable_binary_op("Sum", op::add{});
add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
......@@ -79,11 +83,39 @@ struct onnx_parser
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
init_actv_func();
}
void init_actv_func()
{
map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
}
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
});
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{
ops.emplace(name, f);
}
......@@ -91,81 +123,101 @@ struct onnx_parser
template <class F>
void add_mem_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_broadcastable_binary_op(std::string name, T x)
void add_binary_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast"))
if(contains(attributes, "broadcast") and contains(attributes, "axis"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
auto l =
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l);
}
return prog.add_instruction(x, args);
}
else if(args[0]->get_shape() != args[1]->get_shape())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens = *s1;
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
}
else
{
return prog.add_instruction(x, args);
return add_broadcastable_binary_op(args[0], args[1], x);
}
});
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
if(arg0->get_shape() != arg1->get_shape())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
std::vector<std::size_t> output_lens(*s1);
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1);
}
else
{
return prog.add_instruction(x, {arg0, arg1});
}
}
template <class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
template <class T>
void add_variadic_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
[this, x](instruction_ref a, instruction_ref b) {
return add_broadcastable_binary_op(a, b, x);
});
});
}
instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
......@@ -180,24 +232,30 @@ struct onnx_parser
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::convolution op;
auto l0 = args[0];
if(contains(attributes, "pads"))
{
if(contains(attributes, "auto_pad"))
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
std::vector<std::size_t> padding(4);
copy(attributes["pads"].ints(), padding.begin());
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
// insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
if(contains(attributes, "strides"))
{
......@@ -217,9 +275,13 @@ struct onnx_parser
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::convolution::same;
op.padding_mode = op::padding_mode_t::same;
}
}
if(contains(attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
}
if(args.size() == 3)
{
uint64_t axis = 1;
......@@ -227,7 +289,7 @@ struct onnx_parser
auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l1, l2);
}
return prog.add_instruction(op, args);
return prog.add_instruction(op, l0, args[1]);
}
instruction_ref parse_pooling(const std::string& name,
......@@ -235,6 +297,7 @@ struct onnx_parser
std::vector<instruction_ref> args)
{
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
auto l0 = args[0];
if(starts_with(name, "Global"))
{
auto lens = args.front()->get_shape().lens();
......@@ -242,18 +305,23 @@ struct onnx_parser
}
if(contains(attributes, "pads"))
{
std::vector<std::size_t> padding(4);
copy(attributes["pads"].ints(), padding.begin());
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
// insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
if(contains(attributes, "strides"))
{
......@@ -266,13 +334,14 @@ struct onnx_parser
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(to_upper(s) != "NOTSET")
if(s.find("SAME_UPPER") == std::string::npos)
{
MIGRAPHX_THROW("auto_pad is not supported for pooling");
MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
}
op.padding_mode = op::padding_mode_t::same;
}
return prog.add_instruction(op, std::move(args));
return prog.add_instruction(op, l0);
}
instruction_ref
......@@ -286,7 +355,9 @@ struct onnx_parser
}
if(args.size() == 2)
{
literal s = args[1]->get_literal();
auto s = args[1]->eval();
if(s.empty())
MIGRAPHX_THROW("Dynamic shape is not supported.");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
......@@ -295,7 +366,7 @@ struct onnx_parser
instruction_ref
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 0;
uint64_t axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
......@@ -329,6 +400,18 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
int axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
op::gather op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -353,7 +436,15 @@ struct onnx_parser
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
literal v = parse_value(attributes.at("value"));
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return prog.add_literal(v);
}
......@@ -361,7 +452,7 @@ struct onnx_parser
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0f;
float beta = 0.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
if(contains(attributes, "alpha"))
......@@ -370,7 +461,7 @@ struct onnx_parser
}
if(contains(attributes, "beta"))
{
alpha = parse_value(attributes.at("beta")).at<float>();
beta = parse_value(attributes.at("beta")).at<float>();
}
if(contains(attributes, "transA"))
{
......@@ -380,17 +471,31 @@ struct onnx_parser
{
transb = parse_value(attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm = {1, 0};
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3)
{
uint64_t axis = 1;
auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l3, l4);
if(beta != 0.f)
{
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3;
if(beta != 1.f)
{
auto beta_val = prog.add_literal(beta);
auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
l4 = prog.add_instruction(op::mul{}, args[2], l5);
}
return add_broadcastable_binary_op(l3, l4, op::add{});
}
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto dot_res = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
return dot_res;
}
instruction_ref
......@@ -436,6 +541,37 @@ struct onnx_parser
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0; // default alpha val for elu
if(contains(attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
}
op::elu op{alpha};
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
if(contains(attributes, "alpha"))
alpha = parse_value(attributes.at("alpha")).at<float>();
if(contains(attributes, "beta"))
beta = parse_value(attributes.at("beta")).at<float>();
if(contains(attributes, "bias"))
bias = parse_value(attributes.at("bias")).at<float>();
if(contains(attributes, "size"))
size = parse_value(attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front());
}
instruction_ref parse_imagescaler(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
......@@ -476,6 +612,503 @@ struct onnx_parser
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
}
instruction_ref
parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::vector<int64_t> pads{};
float value = 0.0f;
if(contains(attributes, "pads"))
{
auto&& pad_vals = attributes["pads"].ints();
pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(contains(attributes, "mode"))
{
auto mode = attributes.at("mode").s();
if(mode != "constant")
MIGRAPHX_THROW("migraphx currently only supports constant padding");
}
return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
}
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return prog.add_literal(migraphx::literal{s, vec_shape});
}
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
instruction_ref parse_constant_fill(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(attributes, "dtype"))
{
dtype = parse_value(attributes.at("dtype")).at<int>();
}
migraphx::shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape"))
{
input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(contains(attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
if(!contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else
{
MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
}
}
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
// bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
return map_actv_funcs[fn];
});
// To be added later
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
// if the number of arguments is less than 6, append
// undefined operator to have 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), (6 - args.size()), ins);
}
// first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args));
// second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
}
// need 4 activation functions
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1 four times. If 2 actv functins are provided,
// assume forward and reverse use the same pair of actv
// functions. For the case of 3 actv functions provided,
// assume the 3rd one is repeated once and used by the
// reverse direction.
// This may need change later
if(vec_names.size() == 1)
{
vec_names.insert(vec_names.end(), 3, vec_names.at(0));
}
else if(vec_names.size() == 2)
{
// repeat the activation functions
vec_names.push_back(vec_names.at(0));
vec_names.push_back(vec_names.at(1));
}
else if(vec_names.size() == 3)
{
vec_names.push_back(vec_names.at(2));
}
}
else
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if(contains(attributes, "linear_before_reset"))
{
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
else if(direction == "forward")
{
dirct = op::rnn_direction::forward;
}
else
{
MIGRAPHX_THROW("LSTM: incorrect direction attribute");
}
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
}
// need 6 activation functions for bidirectional directions
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(1),
vec_names.at(0),
vec_names.at(1),
vec_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(0),
vec_names.at(1),
vec_names.at(2)};
break;
case 4:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(3),
vec_names.at(3)};
break;
case 5:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(4),
vec_names.at(4)};
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break;
default: break;
}
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int input_forget = 0;
if(contains(attributes, "input_forget"))
{
input_forget = parse_value(attributes.at("input_forget")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 8)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......@@ -488,7 +1121,7 @@ struct onnx_parser
}
else
{
throw std::runtime_error("Failed reading");
MIGRAPHX_THROW("Failed reading onnx file.");
}
}
......@@ -518,10 +1151,16 @@ struct onnx_parser
}
for(auto&& p : nodes)
{
this->parse_node(get_name(p.second));
this->parse_node(p.first);
}
}
void parse_undefined(const std::string& name)
{
auto ins = prog.add_instruction(op::undefined{});
instructions[name] = ins;
}
void parse_node(const std::string& name)
{
if(name.empty())
......@@ -534,23 +1173,37 @@ struct onnx_parser
{
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
assert(name != input);
this->parse_node(input);
}
else
else if(input.empty())
{
args.push_back(instructions.at(input));
this->parse_undefined(input);
}
args.push_back(instructions.at(input));
}
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0)
{
instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
}
else
{
result = ops[node.op_type()](get_attributes(node), args);
}
// Even no output nodes produce output in migraphx
if(node.output().empty() and result.size() == 1)
{
instructions[name] = result.front();
}
else
{
instructions[name] = ops[node.op_type()](get_attributes(node), args);
assert(node.output().size() >= result.size());
std::transform(result.begin(),
result.end(),
node.output().begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); });
}
}
}
......@@ -565,25 +1218,24 @@ struct onnx_parser
return result;
}
static std::string get_name(const onnx::NodeProto& node)
{
if(node.name().empty())
{
std::string generated = "migraphx_unnamed_node";
return std::accumulate(node.output().begin(),
node.output().end(),
generated,
[](auto x, auto y) { return x + "_" + y; });
}
return node.name();
}
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
std::size_t n = 0;
for(auto&& node : graph.node())
{
result[get_name(node)] = node;
if(node.output().empty())
{
if(node.name().empty())
{
result["migraphx_unamed_node_" + std::to_string(n)] = node;
n++;
}
else
{
result[node.name()] = node;
}
}
for(auto&& output : node.output())
{
result[output] = node;
......@@ -621,6 +1273,11 @@ struct onnx_parser
static literal parse_tensor(const onnx::TensorProto& t)
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
{
dims = {1};
}
if(t.has_raw_data())
{
const std::string& s = t.raw_data();
......@@ -665,7 +1322,15 @@ struct onnx_parser
case onnx::TensorProto::BOOL:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16:
return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()};
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()};
}
case onnx::TensorProto::DOUBLE:
return literal{
{shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
......@@ -720,6 +1385,28 @@ struct onnx_parser
});
return {shape_type, dims};
}
shape::type_t get_type(int dtype)
{
switch(dtype)
{
case 1: return shape::float_type;
case 2: return shape::uint8_type;
case 3: return shape::int8_type;
case 4: return shape::uint16_type;
case 5: return shape::int16_type;
case 6: return shape::int32_type;
case 7: return shape::int64_type;
case 10: return shape::half_type;
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
default:
{
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
};
program parse_onnx(const std::string& name)
......
......@@ -116,7 +116,7 @@ void verify_reduced_program(F f, double tolerance = 80)
{
migraphx::program p = f();
auto n = std::distance(p.begin(), p.end());
for(int i = 0; i < n; i++)
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(f, i, tolerance);
}
......
......@@ -118,11 +118,11 @@ void memory_coloring_impl::build()
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
def_interval->is_literal = is_lit;
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(!is_lit || unify_literals)
alloc_queue.push(def_interval);
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
live_set.erase(range.vn);
}
}
......@@ -203,9 +203,8 @@ void memory_coloring_impl::rewrite()
if(is_allocate(ins))
{
assert(!ins->inputs().empty());
p_program->replace_instruction(
ins, op::load{ins->inputs().at(0)->get_shape(), offset}, scratch_param);
ins, op::load{ins->get_shape(), offset}, scratch_param);
}
else if(is_literal(ins))
{
......@@ -233,9 +232,8 @@ void memory_coloring_impl::verify()
if(segment.begin == invalid_offset)
{
// TODO: This check breaks on the tests
// if(!interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
if(!interval.is_live_on_entry)
MIGRAPHX_THROW("interval is not live on entry");
continue;
}
......
......@@ -84,7 +84,7 @@ struct memory_coloring_impl
{
return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output";
}
bool is_allocate(const instruction_ref ins) { return ins->name() == allocation_op; }
bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins)
......
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
......@@ -134,6 +136,12 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, op::identity{}, rep);
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
{
......@@ -271,6 +279,8 @@ instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().get_shape(); }
context& program::get_context() const { return impl->ctx; }
instruction_ref program::validate() const
{
return std::find_if(impl->instructions.begin(),
......@@ -309,6 +319,15 @@ void program::compile(const target& t, tracer trace)
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
}
this->finalize();
}
void program::finalize()
{
for(auto ins : iterator_for(*this))
{
ins->finalize(this->impl->ctx);
}
}
template <class F>
......@@ -330,13 +349,17 @@ argument generic_eval(const program& p,
}
else if(ins->name() == "@param")
{
results.emplace(ins, trace(ins, [&] {
auto param_name =
any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name);
return params.at(param_name);
}));
results.emplace(
ins, trace(ins, [&] {
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name);
auto param = params.at(param_name);
if(param.get_shape() != ins->get_shape())
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
return param;
}));
}
else if(ins->name() == "@outline")
{
......@@ -361,20 +384,31 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto sctx = ctx;
auto check_context = [&](auto f) {
assert(is_shared(ctx, sctx));
auto x = f();
sctx = ctx;
return x;
};
#else
auto check_context = [](auto f) { return f(); };
#endif
if(enabled(MIGRAPHX_TRACE_EVAL{}))
{
auto& ctx = this->impl->ctx;
return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish();
std::cout << "Run instruction: ";
this->debug_print(ins);
return f();
return check_context(f);
});
}
else
{
return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
*this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
}
}
......@@ -428,8 +462,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
overhead_vec.reserve(n);
for(std::size_t i = 0; i < n; i++)
{
overhead_vec.push_back(time<milliseconds>(
[&] { generic_eval(*this, ctx, params, [](auto...) { return argument{}; }); }));
overhead_vec.push_back(time<milliseconds>([&] { dry_run(params); }));
}
double total_time = common_average(total_vec);
......@@ -493,6 +526,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
std::cout << std::endl;
}
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
......
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
if(MIGRAPHX_ENABLE_PYTHON)
find_program(DEFAULT_PYTHON_EXE python)
if(DEFAULT_PYTHON_EXE)
set(PYTHON_EXECUTABLE ${DEFAULT_PYTHON_EXE} CACHE PATH "Path to python executable")
endif()
find_package(pybind11 REQUIRED)
pybind11_add_module(migraphx_py migraphx_py.cpp)
set_target_properties(migraphx_py PROPERTIES
OUTPUT_NAME migraphx
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
endif()
rocm_install_targets(TARGETS migraphx_py)
endif()
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace py = pybind11;
template <class F>
struct throw_half
{
F f;
template <class A>
void operator()(A a) const
{
f(a);
}
void operator()(migraphx::shape::as<migraphx::half>) const
{
throw std::runtime_error("Half not supported in python yet.");
}
void operator()(migraphx::tensor_view<migraphx::half>) const
{
throw std::runtime_error("Half not supported in python yet.");
}
};
template <class F>
struct skip_half
{
F f;
template <class A>
void operator()(A a) const
{
f(a);
}
void operator()(migraphx::shape::as<migraphx::half>) const {}
void operator()(migraphx::tensor_view<migraphx::half>) const {}
};
template <class F>
void visit_type(const migraphx::shape& s, F f)
{
s.visit_type(throw_half<F>{f});
}
template <class T, class F>
void visit(const migraphx::raw_data<T>& x, F f)
{
x.visit(throw_half<F>{f});
}
template <class F>
void visit_types(F f)
{
migraphx::shape::visit_types(skip_half<F>{f});
}
template <class T>
py::buffer_info to_buffer_info(T& x)
{
migraphx::shape s = x.get_shape();
auto strides = s.strides();
std::transform(
strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
py::buffer_info b;
visit_type(s, [&](auto as) {
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.lens(),
strides);
});
return b;
}
migraphx::shape to_shape(const py::buffer_info& info)
{
migraphx::shape::type_t t;
std::size_t n = 0;
visit_types([&](auto as) {
if(info.format == py::format_descriptor<decltype(as())>::format())
{
t = as.type_enum();
n = sizeof(as());
}
});
auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0;
});
return migraphx::shape{t, info.shape, strides};
}
PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
.def(py::init<>())
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar)
.def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__",
[](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request();
new(&x) migraphx::argument(to_shape(info), info.ptr);
})
.def("get_shape", &migraphx::argument::get_shape)
.def("tolist",
[](migraphx::argument& x) {
py::list l{x.get_shape().elements()};
visit(x, [&](auto data) { l = py::cast(data.to_vector()); });
return l;
})
.def("__eq__", std::equal_to<migraphx::argument>{})
.def("__ne__", std::not_equal_to<migraphx::argument>{})
.def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_shape", &migraphx::program::get_shape)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
.def("run", &migraphx::program::eval)
.def("__eq__", std::equal_to<migraphx::program>{})
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
m.def("parse_onnx", &migraphx::parse_onnx);
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
return migraphx::cpu::target{};
#ifdef HAVE_GPU
if(name == "gpu")
return migraphx::gpu::target{};
#endif
throw std::runtime_error("Target not found: " + name);
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
m.def("from_gpu", &migraphx::gpu::from_gpu);
m.def("gpu_sync", &migraphx::gpu::gpu_sync);
m.def("copy_to_gpu", &migraphx::gpu::copy_to_gpu);
#endif
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif
}
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
{
for(auto ins : iterator_for(prog))
{
if(ins->name() == "rnn")
{
apply_vanilla_rnn(prog, ins);
}
else if(ins->name() == "gru")
{
apply_gru(prog, ins);
}
else if(ins->name() == "lstm")
{
apply_lstm(prog, ins);
}
}
}
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have num of arguments
// when writing their program.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = vanilla_rnn_cell(true,
prog,
ins,
args[0],
w_forward,
r_forward,
bias_forward,
ih_forward,
actv_funcs.at(0));
auto ret_reverse = vanilla_rnn_cell(false,
prog,
ins,
args[0],
w_reverse,
r_reverse,
bias_reverse,
ih_reverse,
actv_funcs.at(1));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
// input weight matrix
auto w = args[1];
// hidden state weight matrix
auto r = args[2];
// process bias and initial hidden state
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// process intial hidden state
instruction_ref ih;
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret =
vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// search its output to find if there are rnn_last_output operator
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const
{
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
if(bias != prog.end())
{
long hs = r->get_shape().lens()[2];
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
}
instruction_ref hidden_out = prog.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
instruction_ref ht;
if(bias != prog.end())
{
ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
}
else
{
ht = xt_ht;
}
// apply activation function
ht = prog.insert_instruction(ins, actv_func, ht);
sih = ht;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
// output inserted
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_out =
(seq_index == 0)
? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
}
else
{
hidden_out =
(seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
}
}
}
return {hidden_out, last_out};
}
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments
// when writing their program.
if(rnn_op.direction == op::rnn_direction::bidirectional)
{
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}, op::tanh{}};
}
else if(rnn_op.actv_funcs.size() == 1)
{
return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
}
else
{
return rnn_op.actv_funcs;
}
}
else
{
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}};
}
else
{
return rnn_op.actv_funcs;
}
}
}
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins);
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have num of arguments
// when writing their program.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// r weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// intial hidden state
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = gru_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
auto ret_reverse = gru_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
// weight matrix
auto w = args[1];
auto r = args[2];
// bias
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// intial hidden state
instruction_ref ih{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret = gru_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// replace the corresponding rnn_last_output instruction
// with the last_output, if rnn_last_output exists
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2) const
{
assert(inputs.size() == 5);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
instruction_ref hidden_states = prog.end();
instruction_ref last_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<int> data(s.elements(), 1);
auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrix
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
instruction_ref brcst_bz{};
instruction_ref brcst_br{};
instruction_ref brcst_wbh{};
instruction_ref brcst_rbh{};
instruction_ref brcst_bh{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
}
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
if(bias != prog.end())
{
xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
}
auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
if(bias != prog.end())
{
xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
}
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
if(bias != prog.end())
{
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
}
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
}
}
auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_states =
(seq_index == 0)
? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
}
else
{
hidden_states =
(seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
}
}
}
return {hidden_states, last_output};
}
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
auto gru_op = any_cast<op::gru>(ins->get_operator());
// before rewrite the gru operator, need to ensure
// we have 4 actv funcs, even though a user does not
// specifiy any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions
if(gru_op.direction == op::rnn_direction::bidirectional)
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0)};
else if(gru_op.actv_funcs.size() == 2)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1)};
else if(gru_op.actv_funcs.size() == 3)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1),
gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(0)};
else
return gru_op.actv_funcs;
}
else
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
else
return gru_op.actv_funcs;
}
}
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
assert(ins->name() == "lstm");
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
std::vector<float> pph_data(pph_shape.elements(), 0.0);
auto actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::rnn_direction dirct = lstm_op.direction;
instruction_ref last_output{};
instruction_ref last_cell_output{};
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process intial hidden state, it is the 6th argument
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process initial cell value
instruction_ref ic_forward{};
instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
}
else
{
ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph_forward = prog.end();
instruction_ref pph_reverse = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
}
auto ret_forward = lstm_cell(
true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
auto ret_reverse = lstm_cell(
false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(5));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// last cell output
last_cell_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]);
// the following logic is to ensure the last instruction is a concat
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dirct == op::rnn_direction::forward);
// weight matrices
auto w = args[1];
auto r = args[2];
// bias
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// initial hidden state
instruction_ref ih{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// initial cell value
instruction_ref ic{};
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic = args[6];
}
else
{
ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph = args[7];
}
auto ret = lstm_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih, ic, pph},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = ret[2];
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// replace the corresponding lstm_last_output instruction
// with the last_output, and the lstm_last_cell_output with
// the last_cell_output. The while loop is to handle the case
// of multiple lstm_last_output and lstm_last_cell_output
// operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
auto last_cell_output_it = ins->outputs().begin();
while(last_cell_output_it != ins->outputs().end())
{
last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "lstm_last_cell_output";
});
if(last_cell_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_cell_output_it, last_cell_output);
last_cell_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const
{
// must have 7 args in the input vector
assert(inputs.size() == 7);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
auto ic = inputs.at(5);
auto pph = inputs.at(6);
instruction_ref hidden_states = prog.end();
instruction_ref last_output{};
instruction_ref last_cell_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
std::vector<int64_t> perm{1, 0};
// w matrix
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
auto wo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
auto wf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
auto wc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
// r matrix
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto ri = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
auto ro = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
auto rf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
auto rc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// initial cell state
auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
auto ic_shape = sic->get_shape();
// bias
instruction_ref bi_brcst{};
instruction_ref bo_brcst{};
instruction_ref bf_brcst{};
instruction_ref bc_brcst{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bi);
auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bo);
auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bf);
auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bc);
}
// peep hole
instruction_ref pphi_brcst{};
instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{};
if(pph != prog.end())
{
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
}
for(long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
}
if(bias != prog.end())
{
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
if(pph != prog.end())
{
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
}
if(bias != prog.end())
{
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
}
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
if(bias != prog.end())
{
ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
}
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct);
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
last_cell_output = cellt;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
if(pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
}
if(bias != prog.end())
{
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
}
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
sic = cellt;
sih = ht;
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
if(i < seq_len - 1)
{
if(i == 0)
{
hidden_states = last_output;
}
else
{
auto concat_arg0 = is_forward ? hidden_states : last_output;
auto concat_arg1 = is_forward ? last_output : hidden_states;
hidden_states =
prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
}
last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output);
return {hidden_states, last_output, last_cell_output};
}
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
// before rewrite the lstm operator, need to ensure
// we have 6 actv funcs, even though a user does not
// specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions
const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::rnn_direction::bidirectional)
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0),
actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(1)};
case 3:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2)};
case 4:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(3),
actv_funcs.at(3)};
case 5:
return {actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2),
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(4)};
default: return actv_funcs;
}
}
else
{
switch(num_actv_funcs)
{
case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default: return actv_funcs;
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -19,7 +19,7 @@ struct shape_impl
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
......@@ -169,10 +169,10 @@ std::string shape::type_string() const
{
switch(this->type())
{
#define MIGRAPHX_SHAPE_TYPE_STRING_CASE(x, t) \
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE(x, t) \
case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_TYPE_STRING_CASE)
#undef MIGRAPHX_SHAPE_TYPE_STRING_CASE
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE
}
MIGRAPHX_THROW("Invalid type");
}
......
......@@ -9,55 +9,98 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(const std::string& name)
bool is_reshaper(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape",
"transpose",
// "broadcast",
"contiguous"
};
// clang-format on
return contains(names, name);
return contains(names, ins->name());
}
bool is_transpose_output(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
if(ins->outputs().front()->name() == "contiguous")
return is_transpose_output(ins->outputs().front());
return ins->outputs().front()->name() == "transpose";
}
instruction_ref find_transpose_input(instruction_ref ins)
{
if(ins->inputs().size() != 1)
return ins;
if(ins->inputs().front()->name() == "contiguous")
return find_transpose_input(ins->inputs().front());
if(ins->inputs().front()->name() == "transpose")
return ins->inputs().front();
return ins;
}
void simplify_reshapes::apply(program& p) const
{
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if(not is_reshaper(ins->name()))
continue;
if(ins->outputs().size() != 1)
continue;
if(is_reshaper(ins->outputs().front()->name()))
if(ins->outputs().empty() and ins != end)
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->name()))
if(is_reshaper(ins))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
reshapes.push_back(reshapes.back()->inputs().front());
}
if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes))
{
r = std::make_pair(*start, *last);
break;
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
p.replace_instruction(r.first, r.second);
}
}
if(r.first != r.second)
else if(ins->name() == "transpose")
{
p.replace_instruction(r.first, r.second);
if(is_transpose_output(ins))
continue;
auto x = ins;
auto t = ins;
do
{
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
p.replace_instruction(ins, t->inputs().front());
}
}
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -7,6 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct pass;
namespace cpu {
struct target
......
......@@ -5,6 +5,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/cpu/gemm.hpp>
#include <unordered_map>
#include <utility>
......@@ -19,6 +20,14 @@ T zero(const T&)
return T(0);
}
template <class T>
typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::enable_if<true, T>>::
type
make_signed(T x)
{
return x;
}
//
// cpu implemenataion of batch norm for inference
//
......@@ -64,7 +73,7 @@ struct cpu_batch_norm_inference
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width)(
par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c) + epsilon) > 0);
result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
......@@ -79,7 +88,7 @@ struct cpu_batch_norm_inference
visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width)(
par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c, h, w) + epsilon) > 0);
result(n, c, h, w) = gamma(c, h, w) *
......@@ -94,6 +103,43 @@ struct cpu_batch_norm_inference
}
};
struct cpu_lrn
{
op::lrn op;
std::string name() const { return "cpu::lrn"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
int n_batch = output_shape.lens()[0];
int channels = output_shape.lens()[1];
int height = output_shape.lens()[2];
int width = output_shape.lens()[3];
float alphaoverarea = op.alpha / op.size;
int radius = (op.size - 1) / 2;
par_dfor(n_batch, height, width)([&](int b, int h, int w) {
float scale = 0;
dfor(channels)([&](int c) {
auto start = (c - radius) < 0 ? 0 : (c - radius);
auto end = (c + radius) > channels ? channels : (c + radius);
for(auto k = start; k < end; ++k)
{
scale += std::pow(input(b, k, h, w), 2);
}
scale *= alphaoverarea;
scale += op.bias;
scale = std::pow(scale, -op.beta);
output(b, c, h, w) = input(b, c, h, w) * scale;
});
});
});
return result;
}
};
struct cpu_convolution
{
op::convolution op;
......@@ -104,28 +150,33 @@ struct cpu_convolution
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_h = input.get_shape().lens()[2];
auto in_w = input.get_shape().lens()[3];
auto wei_c = weights.get_shape().lens()[1];
auto wei_h = weights.get_shape().lens()[2];
auto wei_w = weights.get_shape().lens()[3];
dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
auto in = input.get_shape().lens();
auto in_h = in[2];
auto in_w = in[3];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto wei_h = wei[2];
auto wei_w = wei[3];
par_dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1];
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1];
const int group_id = w / (wei_n / op.group);
double acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
const int in_x = start_x + x;
const int in_y = start_y + y;
const int in_x = start_x + x;
const int in_y = start_y + y;
const int in_ch = group_id * wei_c + k;
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
{
acc += input(o, k, in_x, in_y) * weights(w, k, x, y);
acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y);
}
});
output(o, w, i, j) = acc;
......@@ -158,7 +209,8 @@ struct cpu_im2col
const std::size_t& stride_h = op.stride[0];
const std::size_t& stride_w = op.stride[1];
int kdiv2_h, kdiv2_w;
int kdiv2_h;
int kdiv2_w;
kdiv2_h = kernel_h / 2;
kdiv2_w = kernel_w / 2;
// calculate output sizes
......@@ -231,10 +283,10 @@ struct cpu_pooling
auto in_h = input.get_shape().lens()[2];
auto in_w = input.get_shape().lens()[3];
dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
par_dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x0 = i * op.stride[0] - op.padding[0];
const int start_y0 = j * op.stride[1] - op.padding[1];
......@@ -271,14 +323,33 @@ struct cpu_contiguous
std::string name() const { return "cpu::contiguous"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, std::move(args));
}
};
struct cpu_pad
{
op::pad op;
std::string name() const { return "cpu::contiguous"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
assert(output_shape.standard());
argument result{output_shape};
result.visit([&](auto output) { std::fill(output.begin(), output.end(), op.value); });
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
shape_for_each(input.get_shape(), [&](const auto& idx) {
std::vector<std::size_t> new_idx(idx.size());
std::transform(
idx.begin(), idx.end(), op.pads.begin(), new_idx.begin(), [](auto i, auto j) {
return i + j;
});
output(new_idx.begin(), new_idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
}
};
......@@ -290,24 +361,7 @@ struct cpu_concat
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::vector<std::size_t> coffsets = op.compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
});
}
return result;
return op.compute(output_shape, std::move(args));
}
};
......@@ -325,6 +379,18 @@ struct cpu_gemm
}
};
struct cpu_gather
{
op::gather op;
std::string name() const { return "cpu::gather"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, std::move(args));
}
};
struct identity_op
{
std::string name() const { return "cpu::identity"; }
......@@ -339,7 +405,7 @@ struct abs_op
std::string name() const { return "cpu::abs"; }
auto fcn() const
{
return [](auto x) { return std::abs(x); };
return [](auto x) { return std::abs(make_signed(x)); };
}
};
......@@ -352,6 +418,15 @@ struct exp_op
}
};
struct log_op
{
std::string name() const { return "cpu::log"; }
auto fcn() const
{
return [](auto x) { return std::log(x); };
}
};
struct sin_op
{
std::string name() const { return "cpu::sin"; }
......@@ -406,6 +481,24 @@ struct atan_op
}
};
struct sinh_op
{
std::string name() const { return "cpu::sinh"; }
auto fcn() const
{
return [](auto x) { return std::sinh(x); };
}
};
struct cosh_op
{
std::string name() const { return "cpu::cosh"; }
auto fcn() const
{
return [](auto x) { return std::cosh(x); };
}
};
struct tanh_op
{
std::string name() const { return "cpu::tanh"; }
......@@ -453,6 +546,17 @@ struct leaky_relu_op
}
};
struct elu_op
{
op::elu op;
std::string name() const { return "cpu::elu"; }
auto fcn() const
{
auto& a = op.alpha;
return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
}
};
template <typename Op>
struct cpu_unary
{
......@@ -545,6 +649,24 @@ struct div_op
}
};
struct max_op
{
std::string name() const { return "max"; }
auto fcn() const
{
return [](auto x, auto y) { return std::max(x, y); };
}
};
struct min_op
{
std::string name() const { return "min"; }
auto fcn() const
{
return [](auto x, auto y) { return std::min(x, y); };
}
};
template <typename Op>
struct cpu_binary
{
......@@ -596,22 +718,35 @@ struct cpu_apply
apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["gather"] = extend_op<cpu_gather, op::gather>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["abs"] = simple_op<cpu_unary<abs_op>>();
apply_map["sinh"] = simple_op<cpu_unary<sinh_op>>();
apply_map["cosh"] = simple_op<cpu_unary<cosh_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>();
apply_map["exp"] = simple_op<cpu_unary<exp_op>>();
apply_map["log"] = simple_op<cpu_unary<log_op>>();
apply_map["neg"] = simple_op<cpu_unary<neg_op>>();
apply_map["sin"] = simple_op<cpu_unary<sin_op>>();
apply_map["cos"] = simple_op<cpu_unary<cos_op>>();
apply_map["tan"] = simple_op<cpu_unary<tan_op>>();
apply_map["asin"] = simple_op<cpu_unary<asin_op>>();
apply_map["acos"] = simple_op<cpu_unary<acos_op>>();
apply_map["atan"] = simple_op<cpu_unary<atan_op>>();
apply_map["relu"] = simple_op<cpu_unary<relu_op>>();
apply_map["add"] = simple_op<cpu_binary<add_op>>();
apply_map["sub"] = simple_op<cpu_binary<sub_op>>();
apply_map["mul"] = simple_op<cpu_binary<mul_op>>();
apply_map["div"] = simple_op<cpu_binary<div_op>>();
apply_map["max"] = simple_op<cpu_binary<max_op>>();
apply_map["min"] = simple_op<cpu_binary<min_op>>();
apply_map["softmax"] = simple_op<softmax2d>();
}
......
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -11,7 +14,11 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&) const
{
return {auto_contiguous{}, lowering{}};
return {auto_contiguous{},
rewrite_rnn{},
dead_code_elimination{},
lowering{},
dead_code_elimination{}};
}
} // namespace cpu
......
......@@ -12,11 +12,25 @@ endif()
add_library(migraphx_device
device/add.cpp
device/max.cpp
device/min.cpp
device/exp.cpp
device/log.cpp
device/sin.cpp
device/cos.cpp
device/tan.cpp
device/sinh.cpp
device/cosh.cpp
device/asin.cpp
device/acos.cpp
device/atan.cpp
device/add_relu.cpp
device/contiguous.cpp
device/mul.cpp
device/concat.cpp
device/pad.cpp
device/gather.cpp
device/sub.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
......@@ -38,12 +52,16 @@ add_library(migraphx_gpu
concat.cpp
relu.cpp
leaky_relu.cpp
add.cpp
sin.cpp
mul.cpp
tanh.cpp
batchnorm.cpp
write_literals.cpp
rocblas.cpp
sigmoid.cpp
abs.cpp
elu.cpp
pad.cpp
gather.cpp
lrn.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_abs::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1);
}
argument miopen_abs::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
float alpha = 1;
float beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenActivationForward(ctx.get_stream().get_miopen(),
ad.get(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit());
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/add.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_add::compute_shape(const std::vector<shape>& inputs) const
{
// check_shapes{inputs, *this}.has(3).standard();
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
argument hip_add::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2];
}
shape miopen_add::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3).not_broadcasted();
return inputs.at(0);
}
argument miopen_add::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape);
miopenOpTensor(ctx.get_stream().get_miopen(),
miopenTensorOpAdd,
&alpha,
a_desc.get(),
args[0].implicit(),
&alpha,
b_desc.get(),
args[1].implicit(),
&beta,
c_desc.get(),
args[2].implicit());
return args[2];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/batchnorm.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -22,7 +19,8 @@ argument miopen_batch_norm_inference::compute(context& ctx,
auto y_desc = make_tensor(output_shape);
auto bn_desc = make_tensor(args[3].get_shape());
float alpha = 1.0, beta = 0.0f;
float alpha = 1.0;
float beta = 0.0f;
miopenBatchNormalizationForwardInference(ctx.get_stream().get_miopen(),
miopenBatchNormMode_t(op.bn_mode),
......
#include <migraphx/gpu/concat.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -21,7 +19,8 @@ argument miopen_convolution::compute(context& ctx,
auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
float alpha = 1;
float beta = 0;
miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
......@@ -40,11 +39,11 @@ argument miopen_convolution::compute(context& ctx,
shape miopen_convolution::compile(context& ctx,
const shape& output_shape,
std::vector<instruction_ref> inputs)
std::vector<shape> inputs)
{
shape workspace_shape{};
auto x_desc = make_tensor(inputs[0]->get_shape());
auto w_desc = make_tensor(inputs[1]->get_shape());
auto x_desc = make_tensor(inputs[0]);
auto w_desc = make_tensor(inputs[1]);
auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0;
......@@ -56,31 +55,44 @@ shape miopen_convolution::compile(context& ctx,
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x = to_gpu(generate_argument(inputs[0]->get_shape()));
auto w = to_gpu(generate_argument(inputs[1]->get_shape()));
auto x = to_gpu(generate_argument(inputs[0]));
auto w = to_gpu(generate_argument(inputs[1]));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
x.implicit(),
w_desc.get(),
w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
workspace.implicit(),
workspace_size,
false);
algo = perf.fwd_algo;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
x.implicit(),
w_desc.get(),
w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
workspace.implicit(),
workspace_size,
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Find convolution failed");
handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo;
return shape{shape::int8_type, {perf.memory}};
}
void miopen_convolution::finalize(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
if(handle == ctx.get_stream().get_miopen())
return;
// TODO: Check that workspace hasn't changed
compile(ctx, output_shape, std::move(inputs));
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment