"src/lib/components/chat/Settings/Connections.svelte" did not exist on "87a19274477a5d9fdc76e819cc4c20a4d7b8cf5a"
Unverified Commit a5c1c7f6 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into mem_color_ordering_fix

parents 462a4920 d516b099
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
......@@ -18,7 +18,7 @@ const std::string& get_type_name()
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe =";
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
......@@ -41,7 +41,7 @@ const std::string& get_type_name(const T&)
return migraphx::get_type_name<T>();
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -5,32 +5,32 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
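A minimal usage sketch for the half-precision trait extensions above (illustrative, not part of this commit; it only assumes the migraphx/type_traits.hpp header shown in this diff, which already pulls in migraphx/half.hpp): the macro forwards every type to the corresponding std trait and specializes it to true for migraphx::half.
#include <migraphx/type_traits.hpp>
static_assert(migraphx::is_floating_point<migraphx::half>{}, "half behaves as a floating-point type");
static_assert(migraphx::is_arithmetic<migraphx::half>{}, "half behaves as an arithmetic type");
static_assert(migraphx::is_signed<float>{}, "other types still defer to the std traits");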
#ifndef MIGRAPH_GUARD_VERIFY_HPP
#define MIGRAPH_GUARD_VERIFY_HPP
#ifndef MIGRAPHX_GUARD_VERIFY_HPP
#define MIGRAPHX_GUARD_VERIFY_HPP
#include <algorithm>
#include <cmath>
......@@ -11,7 +11,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
// Compute the value of a range
template <class R>
......@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
return error <= threshold;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraphx/verify.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
inline bool verify_args(const std::string& name,
const argument& cpu_arg,
......@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
return passed;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,7 +3,7 @@
#include <migraphx/erase.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
......@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool operator==(const instruction& x, const instruction& y)
{
if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments))
if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
return false;
if(x.name() == "@literal")
return x.lit == y.lit;
......@@ -162,26 +162,55 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this);
}
std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
argument instruction::eval() const
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
if(op.name() == "@literal")
{
return this->get_literal().get_argument();
}
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty())
return {};
args.push_back(a);
}
return op.compute(result, args);
}
return {};
}
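// Illustrative sketch (not part of this commit): eval() lets a pass test whether
// an instruction is computable at compile time. For an instruction_ref `ins`
// whose inputs all reduce to literals:
//
//   migraphx::argument folded = ins->eval();
//   if(not folded.empty())
//   {
//       // `folded` holds the computed value; the instruction could now be
//       // replaced by a literal carrying that value
//   }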
instruction_ref instruction::get_output_alias(instruction_ref ins)
void instruction::finalize(context& ctx)
{
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
if(has_finalize(this->op))
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
if(i < 0)
return ins;
if(shallow)
return ins->inputs().at(i);
return get_output_alias(ins->inputs().at(i));
}
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
return op.compute_shape(compute_shapes(args));
return op.compute_shape(to_shapes(args));
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,7 +22,7 @@ rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx migraphx_onnx)
if(MIGRAPH_ENABLE_GPU)
if(MIGRAPHX_ENABLE_GPU)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)
......
......@@ -14,7 +14,10 @@
auto reverse_int(unsigned int i)
{
unsigned char c1, c2, c3, c4;
unsigned char c1;
unsigned char c2;
unsigned char c3;
unsigned char c4;
c1 = i & 255u;
c2 = (i >> 8u) & 255u;
c3 = (i >> 16u) & 255u;
......@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& imag
if(file.is_open())
{
int magic_number = 0, n_rows = 0, n_cols = 0;
int magic_number = 0;
int n_rows = 0;
int n_cols = 0;
file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
magic_number = reverse_int(magic_number);
......
......@@ -15,54 +15,59 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
struct unknown
{
std::string op;
std::string name() const { return "unknown:" + op; }
shape compute_shape(std::vector<shape> input) const
{
if(input.empty())
return {};
else
return input.front();
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
return os;
}
};
inline namespace MIGRAPHX_INLINE_NS {
struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
program prog = program();
bool is_pytorch = false;
std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs;
onnx_parser()
{
add_generic_op("MatMul", op::dot{});
add_generic_op("Relu", op::relu{});
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{});
add_generic_op("Exp", op::exp{});
add_generic_op("Log", op::log{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
add_generic_op("Identity", op::identity{});
add_generic_op("Sin", op::sin{});
add_generic_op("Cos", op::cos{});
add_generic_op("Tan", op::tan{});
add_generic_op("Sinh", op::sinh{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{});
add_broadcastable_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{});
add_broadcastable_binary_op("Mul", op::mul{});
add_broadcastable_binary_op("Sub", op::sub{});
add_broadcastable_binary_op("Sum", op::add{});
add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
......@@ -78,11 +83,38 @@ struct onnx_parser
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
init_actv_func();
}
void init_actv_func()
{
map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
}
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
});
}
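// For example, add_generic_op("Relu", op::relu{}) below goes through this wrapper:
// the registered lambda returns a single instruction_ref and the wrapper packs it
// into the std::vector<instruction_ref> expected by op_func, so single-output and
// multi-output parsers share one ops registry.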
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{
ops.emplace(name, f);
}
......@@ -90,81 +122,101 @@ struct onnx_parser
template <class F>
void add_mem_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_broadcastable_binary_op(std::string name, T x)
void add_binary_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPH_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast"))
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast") and contains(attributes, "axis"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
auto l =
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l);
}
return prog.add_instruction(x, args);
}
else if(args[0]->get_shape() != args[1]->get_shape())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
return prog.add_instruction(x, l0, l1);
}
else
{
return prog.add_instruction(x, args);
return add_broadcastable_binary_op(args[0], args[1], x);
}
});
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
if(arg0->get_shape() != arg1->get_shape())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
std::vector<std::size_t> output_lens(*s1);
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1);
}
else
{
return prog.add_instruction(x, {arg0, arg1});
}
}
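// Worked example for the comment above (illustrative): with s0 = (2,1,1) and
// s1 = (3,2,4,5), offset = 4 - 3 = 1, so output_lens starts as a copy of s1 and
// its last three entries become max(2,2), max(1,4), max(1,5):
//   output_lens = (3,2,4,5)
// Both arguments are then multibroadcast to {3,2,4,5} before the binary op runs.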
template <class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
template <class T>
void add_variadic_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
[this, x](instruction_ref a, instruction_ref b) {
return add_broadcastable_binary_op(a, b, x);
});
});
}
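// Example: a three-input ONNX Sum node registered through add_variadic_op("Sum", op::add{})
// is folded left to right as add(add(a, b), c), with the broadcasting logic above applied
// at every pairwise step.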
instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
......@@ -179,9 +231,30 @@ struct onnx_parser
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::convolution op;
auto l0 = args[0];
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
if(contains(attributes, "auto_pad"))
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
// insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides"))
{
......@@ -191,6 +264,23 @@ struct onnx_parser
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
}
if(contains(attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
}
if(args.size() == 3)
{
uint64_t axis = 1;
......@@ -198,7 +288,7 @@ struct onnx_parser
auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l1, l2);
}
return prog.add_instruction(op, args);
return prog.add_instruction(op, l0, args[1]);
}
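// Worked example for the pads handling above (illustrative): an ONNX Conv with
// pads = [1, 1, 1, 1] is symmetric, so it maps straight onto op.padding = {1, 1};
// pads = [0, 0, 1, 1] is asymmetric, so an explicit op::pad with pads
// {0, 0, 0, 0, 0, 0, 1, 1} is inserted on the NCHW input and the convolution
// itself runs with zero padding.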
instruction_ref parse_pooling(const std::string& name,
......@@ -206,6 +296,7 @@ struct onnx_parser
std::vector<instruction_ref> args)
{
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
auto l0 = args[0];
if(starts_with(name, "Global"))
{
auto lens = args.front()->get_shape().lens();
......@@ -213,7 +304,23 @@ struct onnx_parser
}
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
// insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides"))
{
......@@ -223,7 +330,17 @@ struct onnx_parser
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
return prog.add_instruction(op, std::move(args));
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(s.find("SAME_UPPER") == std::string::npos)
{
MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
}
op.padding_mode = op::padding_mode_t::same;
}
return prog.add_instruction(op, l0);
}
instruction_ref
......@@ -246,7 +363,7 @@ struct onnx_parser
instruction_ref
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 0;
uint64_t axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
......@@ -280,6 +397,18 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
int axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
op::gather op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -312,7 +441,7 @@ struct onnx_parser
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0f;
float beta = 0.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
if(contains(attributes, "alpha"))
......@@ -321,7 +450,7 @@ struct onnx_parser
}
if(contains(attributes, "beta"))
{
alpha = parse_value(attributes.at("beta")).at<float>();
beta = parse_value(attributes.at("beta")).at<float>();
}
if(contains(attributes, "transA"))
{
......@@ -336,10 +465,20 @@ struct onnx_parser
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3)
{
uint64_t axis = 1;
auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l3, l4);
if(beta != 0.f)
{
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3;
if(beta != 1.f)
{
auto beta_val = prog.add_literal(beta);
auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
l4 = prog.add_instruction(op::mul{}, args[2], l5);
}
return add_broadcastable_binary_op(l3, l4, op::add{});
}
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
}
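// Note on the Gemm lowering above: ONNX Gemm computes Y = alpha * A' * B' + beta * C
// (A' and B' optionally transposed). With three inputs and beta != 1, C is scaled by a
// beta literal before the broadcastable add; with beta == 0 the C term is dropped and
// only the op::dot{alpha} product remains.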
......@@ -387,6 +526,37 @@ struct onnx_parser
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0; // default alpha val for elu
if(contains(attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
}
op::elu op{alpha};
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
if(contains(attributes, "alpha"))
alpha = parse_value(attributes.at("alpha")).at<float>();
if(contains(attributes, "beta"))
beta = parse_value(attributes.at("beta")).at<float>();
if(contains(attributes, "bias"))
bias = parse_value(attributes.at("bias")).at<float>();
if(contains(attributes, "size"))
size = parse_value(attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front());
}
instruction_ref parse_imagescaler(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
......@@ -427,6 +597,329 @@ struct onnx_parser
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
}
instruction_ref
parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::vector<int64_t> pads{};
float value = 0.0f;
if(contains(attributes, "pads"))
{
auto&& pad_vals = attributes["pads"].ints();
pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(contains(attributes, "mode"))
{
auto mode = attributes.at("mode").s();
if(mode != "constant")
MIGRAPHX_THROW("migraphx currently only supports constant padding");
}
return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
}
// Use a literal instruction to replace the shape, since the output of the
// shape operator is a literal in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return prog.add_literal(migraphx::literal{s, vec_shape});
}
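// Example (illustrative): for an input whose shape has lens {3, 4, 5}, the code above
// emits a literal of shape {int64_type, {3}} holding the values {3, 4, 5}, which is
// what the ONNX Shape operator would produce at runtime.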
// Use a literal instruction to replace the ConstantFill operator. In RNN, the input shape
// and value are fixed, so there is no need to do the actual computation for the ConstantFill
// operator
instruction_ref parse_constant_fill(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(attributes, "dtype"))
{
dtype = parse_value(attributes.at("dtype")).at<int>();
}
migraphx::shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape"))
{
input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(contains(attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
if(!contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else
{
MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
}
}
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
if(map_actv_funcs.count(fn) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
}
});
// The bidirectional case should have two activation functions:
// one for the forward direction and the other for the reverse direction.
// If only one actv function is provided, it is used in both the
// forward and reverse directions.
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
return map_actv_funcs[fn];
});
// To be added later
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
// if the number of arguments is less than 6, append
// undefined operator to have 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), (6 - args.size()), ins);
}
// first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args));
// second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
}
// need 4 activation functions
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided by the ONNX operator, so we
// use the following convention: if 1 actv function is provided,
// repeat it four times. If 2 actv functions are provided,
// assume forward and reverse use the same pair of actv
// functions. If 3 actv functions are provided,
// assume the 3rd one is repeated once and used by the
// reverse direction.
// This may need to change later
if(vec_names.size() == 1)
{
vec_names.insert(vec_names.end(), 3, vec_names.at(0));
}
else if(vec_names.size() == 2)
{
// repeat the activation functions
vec_names.push_back(vec_names.at(0));
vec_names.push_back(vec_names.at(1));
}
else if(vec_names.size() == 3)
{
vec_names.push_back(vec_names.at(2));
}
}
else
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0)
{
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported");
}
});
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if(contains(attributes, "linear_before_reset"))
{
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
}
// append undefined operator to make 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......@@ -439,7 +932,7 @@ struct onnx_parser
}
else
{
throw std::runtime_error("Failed reading");
MIGRAPHX_THROW("Failed reading onnx file.");
}
}
......@@ -469,14 +962,20 @@ struct onnx_parser
}
for(auto&& p : nodes)
{
this->parse_node(get_name(p.second));
this->parse_node(p.first);
}
}
void parse_undefined(const std::string& name)
{
auto ins = prog.add_instruction(op::undefined{});
instructions[name] = ins;
}
void parse_node(const std::string& name)
{
if(name.empty())
MIGRAPH_THROW("Onnx node must have a name");
MIGRAPHX_THROW("Onnx node must have a name");
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
......@@ -485,23 +984,37 @@ struct onnx_parser
{
if(nodes.count(input) > 0)
{
auto&& iname = get_name(nodes.at(input));
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
assert(name != input);
this->parse_node(input);
}
else
else if(input.empty())
{
args.push_back(instructions.at(input));
this->parse_undefined(input);
}
args.push_back(instructions.at(input));
}
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0)
{
instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
}
else
{
instructions[name] = ops[node.op_type()](get_attributes(node), args);
result = ops[node.op_type()](get_attributes(node), args);
}
// Even nodes with no outputs produce an output in migraphx
if(node.output().empty() and result.size() == 1)
{
instructions[name] = result.front();
}
else
{
assert(node.output().size() >= result.size());
std::transform(result.begin(),
result.end(),
node.output().begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); });
}
}
}
......@@ -516,25 +1029,24 @@ struct onnx_parser
return result;
}
static std::string get_name(const onnx::NodeProto& node)
{
if(node.name().empty())
{
std::string generated = "migraphx_unnamed_node";
return std::accumulate(node.output().begin(),
node.output().end(),
generated,
[](auto x, auto y) { return x + "_" + y; });
}
return node.name();
}
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
std::size_t n = 0;
for(auto&& node : graph.node())
{
result[get_name(node)] = node;
if(node.output().empty())
{
if(node.name().empty())
{
result["migraphx_unamed_node_" + std::to_string(n)] = node;
n++;
}
else
{
result[node.name()] = node;
}
}
for(auto&& output : node.output())
{
result[output] = node;
......@@ -566,12 +1078,17 @@ struct onnx_parser
case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPH_THROW("Invalid attribute type");
MIGRAPHX_THROW("Invalid attribute type");
}
static literal parse_tensor(const onnx::TensorProto& t)
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
{
dims = {1};
}
if(t.has_raw_data())
{
const std::string& s = t.raw_data();
......@@ -594,7 +1111,7 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPH_THROW("Invalid tensor type");
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.data_type())
{
......@@ -625,7 +1142,7 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPH_THROW("Invalid tensor type");
MIGRAPHX_THROW("Invalid tensor type");
}
static shape parse_type(const onnx::TypeProto& t)
......@@ -671,6 +1188,28 @@ struct onnx_parser
});
return {shape_type, dims};
}
shape::type_t get_type(int dtype)
{
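// The integer codes below follow the onnx::TensorProto::DataType enum
// (1 = FLOAT, 6 = INT32, 7 = INT64, 10 = FLOAT16, 11 = DOUBLE, ...);
// codes not listed, such as 8 (STRING) and 9 (BOOL), fall through to the throw.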
switch(dtype)
{
case 1: return shape::float_type;
case 2: return shape::uint8_type;
case 3: return shape::int8_type;
case 4: return shape::uint16_type;
case 5: return shape::int16_type;
case 6: return shape::int32_type;
case 7: return shape::int64_type;
case 10: return shape::half_type;
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
default:
{
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
};
program parse_onnx(const std::string& name)
......@@ -694,5 +1233,5 @@ program parse_onnx(const std::string& name)
return std::move(parser.prog);
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -116,7 +116,7 @@ void verify_reduced_program(F f, double tolerance = 80)
{
migraphx::program p = f();
auto n = std::distance(p.begin(), p.end());
for(int i = 0; i < n; i++)
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(f, i, tolerance);
}
......
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
......@@ -14,17 +14,17 @@
#include <queue>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
//#define MIGRAPH_DEBUG_OPT
//#define MIGRAPHX_DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPH_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
......@@ -2,16 +2,16 @@
#include "memory_coloring_impl.hpp"
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const
{
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&p, allocation_op, verify);
opt.run();
}
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include "memory_coloring_impl.hpp"
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run()
{
MIGRAPH_DEBUG(dump("---Before memory coloring---"));
MIGRAPH_DEBUG(dump_program());
MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPHX_DEBUG(dump_program());
build();
if(num_of_lives != 0)
{
MIGRAPH_DEBUG(dump_intervals());
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
while(!alloc_queue.empty())
{
......@@ -85,7 +85,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
conflict_queue.pop();
}
segment.offset = offset;
MIGRAPH_DEBUG(segment.dump());
MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true;
}
......@@ -218,8 +218,8 @@ void memory_coloring_impl::rewrite()
}
}
}
MIGRAPH_DEBUG(dump("---After rewrite---"));
MIGRAPH_DEBUG(dump_program());
MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPHX_DEBUG(dump_program());
}
void memory_coloring_impl::verify()
......@@ -235,7 +235,7 @@ void memory_coloring_impl::verify()
{
// TODO: This check breaks on the tests
// if(!interval.is_live_on_entry)
// MIGRAPH_THROW("interval is not live on entry");
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
......@@ -253,14 +253,14 @@ void memory_coloring_impl::verify()
if(range->offset == invalid_offset)
continue;
if(!is_disjoin(*range, segment))
MIGRAPH_THROW("range and segment is not disjoined");
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
}
}
}
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
......@@ -334,5 +334,5 @@ void live_interval::dump()
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp"
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
static const int invalid_offset = -1;
......@@ -15,7 +15,7 @@ struct live_range
long long offset; // offset to base pointer of allocated memory trunk.
int vn; // value number that identifies this live_range.
long long size; // size of required memory in bytes
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
};
......@@ -35,7 +35,7 @@ struct live_interval
int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; }
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
......@@ -84,7 +84,7 @@ struct memory_coloring_impl
{
return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output";
}
bool is_allocate(const instruction_ref ins) { return ins->name() == allocation_op; }
bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins)
......@@ -101,7 +101,7 @@ struct memory_coloring_impl
return ((end1 < range2.offset) || (end2 < range1.offset));
}
void verify();
#ifdef MIGRAPH_DEBUG_OPT
#ifdef MIGRAPHX_DEBUG_OPT
void dump(const std::string&);
void dump_program();
void dump_intervals();
......@@ -154,6 +154,6 @@ struct memory_coloring_impl
bool enable_verify;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
......@@ -11,10 +12,10 @@
#include <utility>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE)
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl
{
......@@ -134,6 +135,12 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, op::identity{}, rep);
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
{
......@@ -271,6 +278,8 @@ instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().get_shape(); }
context& program::get_context() const { return impl->ctx; }
instruction_ref program::validate() const
{
return std::find_if(impl->instructions.begin(),
......@@ -282,7 +291,7 @@ void program::compile(const target& t, tracer trace)
{
assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context();
if(enabled(MIGRAPH_TRACE_COMPILE{}))
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
trace = tracer{std::cout};
trace(*this);
trace();
......@@ -297,8 +306,8 @@ void program::compile(const target& t, tracer trace)
if(invalid != impl->instructions.end())
{
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
trace();
#endif
......@@ -307,7 +316,16 @@ void program::compile(const target& t, tracer trace)
if(invalid != impl->instructions.end())
{
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW("Invalid program from compilation at instruction " + std::to_string(index));
MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
}
this->finalize();
}
void program::finalize()
{
for(auto ins : iterator_for(*this))
{
ins->finalize(this->impl->ctx);
}
}
......@@ -334,7 +352,7 @@ argument generic_eval(const program& p,
auto param_name =
any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name))
MIGRAPH_THROW("Parameter not found: " + param_name);
MIGRAPHX_THROW("Parameter not found: " + param_name);
return params.at(param_name);
}));
}
......@@ -361,20 +379,31 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const
{
if(enabled(MIGRAPH_TRACE_EVAL{}))
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto sctx = ctx;
auto check_context = [&](auto f) {
assert(is_shared(ctx, sctx));
auto x = f();
sctx = ctx;
return x;
};
#else
auto check_context = [](auto f) { return f(); };
#endif
if(enabled(MIGRAPHX_TRACE_EVAL{}))
{
auto& ctx = this->impl->ctx;
return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish();
std::cout << "Run instruction: ";
this->debug_print(ins);
return f();
return check_context(f);
});
}
else
{
return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
*this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
}
}
......@@ -428,8 +457,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
overhead_vec.reserve(n);
for(std::size_t i = 0; i < n; i++)
{
overhead_vec.push_back(time<milliseconds>(
[&] { generic_eval(*this, ctx, params, [](auto...) { return argument{}; }); }));
overhead_vec.push_back(time<milliseconds>([&] { dry_run(params); }));
}
double total_time = common_average(total_vec);
......@@ -493,6 +521,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
std::cout << std::endl;
}
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
}
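// Note: dry_run() drives the same generic_eval machinery as eval(), but every
// instruction callback returns an empty argument, so perf_report() above can use it
// to measure per-instruction dispatch overhead without performing any real compute.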
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
......@@ -501,5 +535,5 @@ std::ostream& operator<<(std::ostream& os, const program& p)
return os;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
if(MIGRAPHX_ENABLE_PYTHON)
find_program(DEFAULT_PYTHON_EXE python)
if(DEFAULT_PYTHON_EXE)
set(PYTHON_EXECUTABLE ${DEFAULT_PYTHON_EXE} CACHE PATH "Path to python executable")
endif()
find_package(pybind11 REQUIRED)
pybind11_add_module(migraphx_py migraphx_py.cpp)
set_target_properties(migraphx_py PROPERTIES
OUTPUT_NAME migraphx
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
endif()
endif()
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace py = pybind11;
template <class F>
struct throw_half
{
F f;
template <class A>
void operator()(A a) const
{
f(a);
}
void operator()(migraphx::shape::as<migraphx::half>) const
{
throw std::runtime_error("Half not supported in python yet.");
}
};
template <class F>
struct skip_half
{
F f;
template <class A>
void operator()(A a) const
{
f(a);
}
void operator()(migraphx::shape::as<migraphx::half>) const {}
};
template <class F>
void visit_type(const migraphx::shape& s, F f)
{
s.visit_type(throw_half<F>{f});
}
template <class F>
void visit_types(F f)
{
migraphx::shape::visit_types(skip_half<F>{f});
}
template <class T>
py::buffer_info to_buffer_info(T& x)
{
migraphx::shape s = x.get_shape();
py::buffer_info b;
visit_type(s, [&](auto as) {
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.lens(),
s.strides());
});
return b;
}
migraphx::shape to_shape(const py::buffer_info& info)
{
migraphx::shape::type_t t;
visit_types([&](auto as) {
if(info.format == py::format_descriptor<decltype(as())>::format())
t = as.type_enum();
});
return migraphx::shape{t, info.shape, info.strides};
}
PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
.def(py::init<>())
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar)
.def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__",
[](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request();
new(&x) migraphx::argument(to_shape(info), info.ptr);
})
.def("__eq__", std::equal_to<migraphx::argument>{})
.def("__ne__", std::not_equal_to<migraphx::argument>{})
.def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_shape", &migraphx::program::get_shape)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
.def("run", &migraphx::program::eval)
.def("__eq__", std::equal_to<migraphx::program>{})
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
m.def("parse_onnx", &migraphx::parse_onnx);
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
return migraphx::cpu::target{};
#ifdef HAVE_GPU
if(name == "gpu")
return migraphx::gpu::target{};
#endif
throw std::runtime_error("Target not found: " + name);
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
m.def("from_gpu", &migraphx::gpu::from_gpu);
m.def("gpu_sync", &migraphx::gpu::gpu_sync);
m.def("copy_to_gpu", &migraphx::gpu::copy_to_gpu);
#endif
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif
}
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
{
for(auto ins : iterator_for(prog))
{
if(ins->name() == "rnn")
{
apply_vanilla_rnn(prog, ins);
}
if(ins->name() == "gru")
{
apply_gru(prog, ins);
}
}
}
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Alternatively, a user can pass any number of arguments
// when writing their own program.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process the initial hidden state; it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = vanilla_rnn_cell(true,
prog,
ins,
args[0],
w_forward,
r_forward,
bias_forward,
ih_forward,
actv_funcs.at(0));
auto ret_reverse = vanilla_rnn_cell(false,
prog,
ins,
args[0],
w_reverse,
r_reverse,
bias_reverse,
ih_reverse,
actv_funcs.at(1));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
// input weight matrix
auto w = args[1];
// hidden state weight matrix
auto r = args[2];
// process bias and initial hidden state
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// process initial hidden state
instruction_ref ih;
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret =
vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// search its outputs to find whether there are rnn_last_output operators;
// the while loop handles the case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const
{
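// The loop below implements the ONNX vanilla-RNN recurrence for one direction:
//   Ht = f(Xt * W^T + H(t-1) * R^T + Wb + Rb)
// where f is actv_func, W and R are the input and recurrence weight matrices, and
// the combined bias Wb + Rb (when a bias input exists) is broadcast across the batch.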
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
if(bias != prog.end())
{
long hs = r->get_shape().lens()[2];
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
}
instruction_ref hidden_out = prog.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
instruction_ref ht;
if(bias != prog.end())
{
ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
}
else
{
ht = xt_ht;
}
// apply activation function
ht = prog.insert_instruction(ins, actv_func, ht);
sih = ht;
// add back the dimensions for sequence length (axis 0) and
// num_directions (axis 1)
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
// concatenation with the final last_out is performed in the apply()
// function so that the last instruction rewritten from the operator
// is a concat
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_out =
(seq_index == 0)
? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
}
else
{
hidden_out =
(seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
}
}
}
return {hidden_out, last_out};
}
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
// before rewriting the rnn operator, make sure we have the required
// number of activation functions: two for a bidirectional rnn and one
// otherwise. If fewer are specified, the logic below fills in the rest
// (duplicating the given function or defaulting to tanh).
if(rnn_op.direction == op::rnn_direction::bidirectional)
{
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}, op::tanh{}};
}
else if(rnn_op.actv_funcs.size() == 1)
{
return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
}
else
{
return rnn_op.actv_funcs;
}
}
else
{
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}};
}
else
{
return rnn_op.actv_funcs;
}
}
}
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins);
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is that a user can pass any number of
// arguments when writing their own program.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
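// default initial hidden state: zeros of shape {1, batch_size, hidden_size}
// (one slice per direction), used when the initial_h input is missing or
// undefined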
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
{
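// W, R, B, and the initial hidden state carry num_directions as their
// leading axis, so slice index 0 selects the forward direction and
// index 1 the reverse direction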
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// r weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// initial hidden state
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = gru_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
auto ret_reverse = gru_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
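// when seq_len > 1 the accumulated hidden states (ret_forward[0] and
// ret_reverse[0]) are still missing the final time step, so it is
// appended (prepended for the reverse direction) along the sequence
// axis before the two directions are concatenated along the
// num_directions axis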
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
// weight matrix
auto w = args[1];
auto r = args[2];
// bias
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// initial hidden state
instruction_ref ih{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret = gru_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// replace any corresponding rnn_last_output instruction with
// last_output; the while loop handles the case of multiple
// rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2) const
{
assert(inputs.size() == 5);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
instruction_ref hidden_states = prog.end();
instruction_ref last_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
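// all-ones literal matching a single hidden state; used below to form
// (1 - zt) in the hidden state update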
migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<int> data(s.elements(), 1);
auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrices: W and R stack the z, r, and h gate weights along
// the hidden axis, so each gate is extracted with a slice of length hs
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
instruction_ref brcst_bz{};
instruction_ref brcst_br{};
instruction_ref brcst_wbh{};
instruction_ref brcst_rbh{};
instruction_ref brcst_bh{};
if(bias != prog.end())
{
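// the squeezed bias has layout [Wbz, Wbr, Wbh, Rbz, Rbr, Rbh], each
// piece of length hs, so six consecutive slices recover them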
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
}
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
if(bias != prog.end())
{
xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
}
auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
if(bias != prog.end())
{
xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
}
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
if(linear_before_reset == 0)
{
// equation ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
if(bias != prog.end())
{
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
}
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
}
}
auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
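// as in the vanilla rnn cell, the final step is not appended to
// hidden_states here; apply_gru concatenates it so the last rewritten
// instruction is a concat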
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_states =
(seq_index == 0)
? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
}
else
{
hidden_states =
(seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
}
}
}
return {hidden_states, last_output};
}
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
auto gru_op = any_cast<op::gru>(ins->get_operator());
// before rewriting the gru operator, we need the full set of actv
// funcs (4 for bidirectional, 2 otherwise), even if the user does not
// specify any. If fewer are given, expand them using the same
// algorithm as parse_gru
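// apply_gru consumes these as {f, g} for a single direction and as
// {f_forward, g_forward, f_reverse, g_reverse} for bidirectional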
if(gru_op.direction == op::rnn_direction::bidirectional)
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0)};
else if(gru_op.actv_funcs.size() == 2)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1)};
else if(gru_op.actv_funcs.size() == 3)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1),
gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(0)};
else
return gru_op.actv_funcs;
}
else
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
else
return gru_op.actv_funcs;
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -7,7 +7,7 @@
#include <iostream>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl
{
......@@ -169,12 +169,12 @@ std::string shape::type_string() const
{
switch(this->type())
{
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE(x, t) \
case x: return #x;
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_TYPE_STRING_CASE)
#undef MIGRAPH_SHAPE_TYPE_STRING_CASE
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE
}
MIGRAPH_THROW("Invalid type");
MIGRAPHX_THROW("Invalid type");
}
bool operator==(const shape& x, const shape& y)
......@@ -191,5 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -5,7 +5,7 @@
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct find_add_lit_broadcast
{
......@@ -61,5 +61,5 @@ struct find_add_lit_broadcast
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -7,58 +7,101 @@
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(const std::string& name)
bool is_reshaper(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape",
"transpose",
// "broadcast",
"contiguous"
};
// clang-format on
return contains(names, name);
return contains(names, ins->name());
}
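// is_transpose_output reports whether the single consumer chain of ins
// reaches a transpose, and find_transpose_input walks back to the
// transpose feeding ins; both skip over intervening contiguous
// instructions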
bool is_transpose_output(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
if(ins->outputs().front()->name() == "contiguous")
return is_transpose_output(ins->outputs().front());
return ins->outputs().front()->name() == "transpose";
}
instruction_ref find_transpose_input(instruction_ref ins)
{
if(ins->inputs().size() != 1)
return ins;
if(ins->inputs().front()->name() == "contiguous")
return find_transpose_input(ins->inputs().front());
if(ins->inputs().front()->name() == "transpose")
return ins->inputs().front();
return ins;
}
void simplify_reshapes::apply(program& p) const
{
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if(not is_reshaper(ins->name()))
continue;
if(ins->outputs().size() != 1)
continue;
if(is_reshaper(ins->outputs().front()->name()))
if(ins->outputs().empty() and ins != end)
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->name()))
if(is_reshaper(ins))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
reshapes.push_back(reshapes.back()->inputs().front());
}
if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes))
{
r = std::make_pair(*start, *last);
break;
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
p.replace_instruction(r.first, r.second);
}
}
if(r.first != r.second)
else if(ins->name() == "transpose")
{
p.replace_instruction(r.first, r.second);
if(is_transpose_output(ins))
continue;
auto x = ins;
auto t = ins;
do
{
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
p.replace_instruction(ins, t->inputs().front());
}
}
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx