Commit 1599d553 authored by Paul

Merge branch 'develop' into device-refactor

parents 2220bd25 15eb1987
@@ -40,6 +40,7 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
 set(PACKAGE_DEPENDS)

+add_subdirectory(driver)
 add_subdirectory(onnx)
 add_subdirectory(tf)
......
add_executable(driver main.cpp verify.cpp perf.cpp)
rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_cpu migraphx_onnx migraphx_tf)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(driver migraphx_gpu)
target_compile_definitions(driver PRIVATE -DHAVE_GPU)
endif()
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#include <algorithm>
#include <functional>
#include <iostream>
#include <set>
#include <string>
#include <sstream>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef MIGRAPHX_USE_CLANG_TIDY
#define MIGRAPHX_DRIVER_STATIC
#else
#define MIGRAPHX_DRIVER_STATIC static
#endif
template <class T>
struct value_parser
{
template <MIGRAPHX_REQUIRES(not std::is_enum<T>{})>
static T apply(const std::string& x)
{
T result;
std::stringstream ss;
ss.str(x);
ss >> result;
if(ss.fail())
throw std::runtime_error("Failed to parse: " + x);
return result;
}
template <MIGRAPHX_REQUIRES(std::is_enum<T>{})>
static T apply(const std::string& x)
{
std::ptrdiff_t i;
std::stringstream ss;
ss.str(x);
ss >> i;
if(ss.fail())
throw std::runtime_error("Failed to parse: " + x);
return static_cast<T>(i);
}
};
struct argument_parser
{
struct argument
{
std::vector<std::string> flags;
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{};
std::string type = "";
std::string help = "";
std::string metavar = "";
std::string default_value = "";
unsigned nargs = 1;
};
template <class T, class... Fs>
void operator()(T& x, const std::vector<std::string>& flags, Fs... fs)
{
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty())
throw std::runtime_error("Flag with no value.");
x = value_parser<T>::apply(params.back());
return false;
}});
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
arg.default_value = to_string(x);
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
}
template <class... Fs>
void operator()(std::nullptr_t x, std::vector<std::string> flags, Fs... fs)
{
arguments.push_back({std::move(flags)});
argument& arg = arguments.back();
arg.type = "";
arg.nargs = 0;
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
}
MIGRAPHX_DRIVER_STATIC auto nargs(unsigned n = 1)
{
return [=](auto&&, auto& arg) { arg.nargs = n; };
}
template <class F>
MIGRAPHX_DRIVER_STATIC auto write_action(F f)
{
return [=](auto& x, auto& arg) {
arg.action = [&, f](auto& self, const std::vector<std::string>& params) {
f(self, x, params);
return false;
};
};
}
template <class F>
MIGRAPHX_DRIVER_STATIC auto do_action(F f)
{
return [=](auto&, auto& arg) {
arg.nargs = 0;
arg.action = [&, f](auto& self, const std::vector<std::string>&) {
f(self);
return true;
};
};
}
MIGRAPHX_DRIVER_STATIC auto append()
{
return write_action([](auto&, auto& x, auto& params) {
using type = typename decltype(params)::value_type;
std::transform(params.begin(),
params.end(),
std::inserter(x, x.end()),
[](std::string y) { return value_parser<type>::apply(y); });
});
}
MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "")
{
return do_action([=](auto& self) {
for(auto&& arg : self.arguments)
{
std::cout << std::endl;
std::string prefix = " ";
if(arg.flags.empty())
{
std::cout << prefix;
std::cout << arg.metavar;
}
for(const std::string& a : arg.flags)
{
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
if(not arg.type.empty())
{
std::cout << " [" << arg.type << "]";
if(not arg.default_value.empty())
std::cout << " (Default: " << arg.default_value << ")";
}
std::cout << std::endl;
std::cout << " " << arg.help << std::endl;
}
std::cout << std::endl;
if(not msg.empty())
std::cout << msg << std::endl;
});
}
MIGRAPHX_DRIVER_STATIC auto help(const std::string& help)
{
return [=](auto&, auto& arg) { arg.help = help; };
}
MIGRAPHX_DRIVER_STATIC auto metavar(const std::string& metavar)
{
return [=](auto&, auto& arg) { arg.metavar = metavar; };
}
template <class T>
MIGRAPHX_DRIVER_STATIC auto set_value(T value)
{
return [=](auto& x, auto& arg) {
arg.nargs = 0;
arg.type = "";
arg.action = [&, value](auto&, const std::vector<std::string>&) {
x = value;
return false;
};
};
}
bool parse(std::vector<std::string> args)
{
std::unordered_map<std::string, unsigned> keywords;
for(auto&& arg : arguments)
{
for(auto&& flag : arg.flags)
keywords[flag] = arg.nargs + 1;
}
auto arg_map =
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
for(auto&& arg : arguments)
{
auto flags = arg.flags;
if(flags.empty())
flags = {""};
for(auto&& flag : flags)
{
if(arg_map.count(flag) > 0)
{
if(arg.action(*this, arg_map[flag]))
return true;
}
}
}
return false;
}
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class IsKeyword>
static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
{
string_map result;
std::string flag;
bool clear = false;
for(auto&& x : as)
{
auto k = is_keyword(x);
if(k > 0)
{
flag = x;
result[flag]; // Ensure the flag exists
if(k == 1)
flag = "";
else if(k == 2)
clear = true;
else
clear = false;
}
else
{
result[flag].push_back(x);
if(clear)
flag = "";
clear = false;
}
}
return result;
}
private:
std::vector<argument> arguments;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
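// A minimal usage sketch (not part of the commit) showing how the driver's
// commands below drive argument_parser; the struct name and flag values here
// are hypothetical, the API calls mirror loader::parse and run_command.
#include "argument_parser.hpp"
#include <iostream>
#include <string>
#include <vector>

struct example_opts
{
    std::string file;
    double tolerance = 80;
    bool verbose     = false;

    void parse(migraphx::driver::argument_parser& ap)
    {
        // positional value: empty flag list plus a metavar for the help text
        ap(file, {}, ap.metavar("<input file>"));
        // typed flag parsed through value_parser<double>
        ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
        // boolean switch: set_value() makes it take no arguments
        ap(verbose, {"-v", "--verbose"}, ap.help("Verbose output"), ap.set_value(true));
    }
};

int main(int argc, const char* argv[])
{
    example_opts opts;
    migraphx::driver::argument_parser ap;
    opts.parse(ap);
    ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
    std::vector<std::string> args(argv + 1, argv + argc);
    if(ap.parse(std::move(args)))
        return 0; // a do_action() handler such as --help requested an early exit
    std::cout << opts.file << " " << opts.tolerance << std::endl;
}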
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
#include "argument_parser.hpp"
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_map>
#include <utility>
#include <vector>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands()
{
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m;
return m;
}
template <class T>
std::string compute_command_name()
{
static const std::string& tname = get_type_name<T>();
auto name = tname.substr(tname.rfind("::") + 2);
if(ends_with(name, "_command"))
name = name.substr(0, name.size() - 8);
if(ends_with(name, "_cmd"))
name = name.substr(0, name.size() - 4);
return name;
}
template <class T>
const std::string& command_name()
{
static const std::string& name = compute_command_name<T>();
return name;
}
template <class T>
void run_command(std::vector<std::string> args, bool add_help = false)
{
T x;
argument_parser ap;
if(add_help)
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap);
if(ap.parse(std::move(args)))
return;
x.run();
}
template <class T>
int auto_register_command()
{
auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); };
return 0;
}
template <class T>
struct command
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
template <class T>
int command<T>::static_register = auto_register_command<T>(); // NOLINT
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
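// A hypothetical sketch (not in this commit) of how a new driver command would
// self-register through command<T>: deriving from command<dump> instantiates
// static_register, whose initializer calls auto_register_command<dump>(), so the
// command shows up in get_commands() under the name "dump" with no extra wiring.
#include "command.hpp"
#include <iostream>
#include <string>

namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

struct dump : command<dump> // registered under the name "dump"
{
    std::string file;
    void parse(argument_parser& ap) { ap(file, {}, ap.metavar("<input file>")); }
    void run() { std::cout << "would dump " << file << std::endl; }
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx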
#include "argument_parser.hpp"
#include "command.hpp"
#include "verify.hpp"
#include "perf.hpp"
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
struct loader
{
std::string file;
std::string file_type;
bool is_nhwc = true;
unsigned trim = 0;
void parse(argument_parser& ap)
{
ap(file, {}, ap.metavar("<input file>"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
}
program load()
{
program p;
if(file_type.empty())
{
if(ends_with(file, ".onnx"))
file_type = "onnx";
else if(ends_with(file, ".pb"))
file_type = "tf";
}
std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx")
p = parse_onnx(file);
else if(file_type == "tf")
p = parse_tf(file, is_nhwc);
if(trim > 0)
{
auto last = std::prev(p.end(), trim);
p.remove_instructions(last, p.end());
}
return p;
}
};
struct compiler
{
loader l;
bool gpu = true;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true));
ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false));
}
program compile()
{
auto p = l.load();
compile_program(p, gpu);
return p;
}
auto params(const program& p) { return create_param_map(p, gpu); }
};
struct read : command<read>
{
loader l;
void parse(argument_parser& ap) { l.parse(ap); }
void run()
{
auto p = l.load();
std::cout << p << std::endl;
}
};
struct verify : command<verify>
{
loader l;
double tolerance = 80;
bool per_instruction = false;
bool reduce = false;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(per_instruction,
{"-i", "--per-instruction"},
ap.help("Verify each instruction"),
ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
}
void run()
{
auto p = l.load();
std::cout << p << std::endl;
if(per_instruction)
{
verify_instructions(p, tolerance);
}
else if(reduce)
{
verify_reduced_program(p, tolerance);
}
else
{
verify_program(l.file, p, tolerance);
}
}
};
struct compile : command<compile>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << p << std::endl;
}
};
struct run_cmd : command<run_cmd>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
p.eval(m);
std::cout << p << std::endl;
}
};
struct perf : command<perf>
{
compiler c;
unsigned n = 100;
void parse(argument_parser& ap)
{
c.parse(ap);
ap(n, {"--iterations", "-n"}, ap.help("Number of iterations to run for perf report"));
}
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m);
}
};
struct main_command
{
static std::string get_command_help()
{
std::string result = "Commands:\n";
return std::accumulate(get_commands().begin(),
get_commands().end(),
result,
[](auto r, auto&& p) { return r + " " + p.first + "\n"; });
}
void parse(argument_parser& ap)
{
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
}
void run() {}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
using namespace migraphx::driver; // NOLINT
int main(int argc, const char* argv[])
{
std::vector<std::string> args(argv + 1, argv + argc);
if(args.empty())
return 0;
auto&& m = get_commands();
auto cmd = args.front();
if(m.count(cmd) > 0)
{
m.at(cmd)({args.begin() + 1, args.end()});
}
else
{
run_command<main_command>(args);
}
return 0;
}
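// With the commands above registered, main() dispatches on the first argument.
// Hypothetical invocations (the model file names are placeholders; the flags are
// the ones declared by loader, compiler, verify and perf above):
//
//   driver read model.onnx
//   driver verify model.onnx --tolerance 100 -i
//   driver perf model.pb --nhwc --gpu -n 50
//   driver --help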
#include "perf.hpp"
#include <migraphx/cpu/target.hpp>
#include <migraphx/generate.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map create_param_map(const program& p, bool gpu)
{
program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
#ifdef HAVE_GPU
if(gpu)
m[x.first] = gpu::to_gpu(generate_argument(x.second));
else
#else
(void)gpu;
#endif
m[x.first] = generate_argument(x.second);
}
return m;
}
void compile_program(program& p, bool gpu)
{
if(gpu)
{
#ifdef HAVE_GPU
p.compile(gpu::target{});
#else
MIGRAPHX_THROW("Gpu not supported.");
#endif
}
else
{
p.compile(cpu::target{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_PERF_HPP
#define MIGRAPHX_GUARD_RTGLIB_PERF_HPP
#include <migraphx/program.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map create_param_map(const program& p, bool gpu = true);
void compile_program(program& p, bool gpu = true);
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
#include "verify.hpp"
#include <migraphx/cpu/target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
auto get_hash(const T& x)
{
return std::hash<T>{}(x);
}
argument run_cpu(program p)
{
p.compile(cpu::target{});
program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = generate_argument(x.second, get_hash(x.first));
}
auto out = p.eval(m);
std::cout << p << std::endl;
return out;
}
argument run_gpu(program p)
{
#ifdef HAVE_GPU
p.compile(gpu::target{});
program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = gpu::to_gpu(generate_argument(x.second, get_hash(x.first)));
}
auto out = gpu::from_gpu(p.eval(m));
std::cout << p << std::endl;
return gpu::from_gpu(out);
#else
(void)p;
MIGRAPHX_THROW("Gpu unsupported!");
#endif
}
void verify_program(const std::string& name, const program& p, double tolerance)
{
auto x = run_cpu(p);
auto y = run_gpu(p);
verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void verify_instructions(const program& prog, double tolerance)
{
for(auto&& ins : prog)
{
if(ins.name().front() == '@')
continue;
if(ins.name() == "broadcast")
continue;
if(ins.name() == "transpose")
continue;
if(ins.name() == "reshape")
continue;
program p;
std::vector<instruction_ref> inputs;
for(auto&& arg : ins.inputs())
{
if(arg->name() == "@literal")
inputs.push_back(p.add_literal(arg->get_literal()));
else
inputs.push_back(p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
}
p.add_instruction(ins.get_operator(), inputs);
try
{
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl;
verify_program(ins.name(), p, tolerance);
}
catch(...)
{
std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
throw;
}
}
}
void verify_reduced(program p, int n, double tolerance)
{
auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end());
std::cout << "Verify: " << std::endl;
std::cout << p << std::endl;
verify_program(std::to_string(n), p, tolerance);
}
void verify_reduced_program(const program& p, double tolerance)
{
auto n = std::distance(p.begin(), p.end());
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(p, i, tolerance);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#include <migraphx/program.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
argument run_cpu(program p);
argument run_gpu(program p);
void verify_program(const std::string& name, const program& p, double tolerance = 100);
void verify_instructions(const program& prog, double tolerance = 80);
void verify_reduced_program(const program& p, double tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
@@ -44,8 +44,6 @@ void eliminate_pad::update_op(T,
     std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]), static_cast<size_t>(pads[3])};

     T op = any_cast<T>(ins->get_operator());
-    if(op.padding_mode != op::padding_mode_t::default_)
-        return;
     op.padding = new_pads;

     std::vector<instruction_ref> new_inputs{ins->inputs()};
......
@@ -28,8 +28,10 @@ struct binary : op_name<Derived>
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        auto s1 = args[0].get_shape();
+        auto s2 = args[1].get_shape();
         visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
+            if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed())
             {
                 std::transform(input1.begin(),
                                input1.end(),
......
@@ -44,8 +44,7 @@ struct convolution
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
         auto t               = input.type();
-        if(padding_mode == default_)
-        {
-            return {t,
+        return {t,
                 {
                     input.lens()[0],
@@ -64,32 +63,6 @@ struct convolution
                         1)),
                 }};
-        }
-        else if(padding_mode == same)
-        {
-            return {t,
-                    {input.lens()[0],
-                     weights.lens()[0],
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
-                     static_cast<std::size_t>(
-                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
-        }
-        else if(padding_mode == valid)
-        {
-            return {
-                t,
-                {input.lens()[0],
-                 weights.lens()[0],
-                 static_cast<std::size_t>(std::ceil(
-                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
-                 static_cast<std::size_t>(std::ceil(
-                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
-        }
-        else
-        {
-            MIGRAPHX_THROW("Invalid padding mode");
-        }
     }
 };

 } // namespace op
......
@@ -48,52 +48,22 @@ struct pooling
         assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
         assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));

-        if(padding_mode == default_)
-        {
-            return {t,
-                    {
-                        input.lens()[0],
-                        input.lens()[1],
-                        std::size_t(std::max<std::ptrdiff_t>(
-                            1,
-                            floor_divide<std::ptrdiff_t>(
-                                input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) +
-                                1)),
-                        std::size_t(std::max<std::ptrdiff_t>(
-                            1,
-                            floor_divide<std::ptrdiff_t>(
-                                input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) +
-                                1)),
-                    }};
-        }
-        else if(padding_mode == same)
-        {
-            return {t,
-                    {input.lens()[0],
-                     input.lens()[1],
-                     ceil_divide<std::size_t>(input.lens()[2], stride[0]),
-                     ceil_divide<std::size_t>(input.lens()[3], stride[1])}};
-        }
-        else if(padding_mode == valid)
-        {
-            return {
-                t,
-                {
-                    input.lens()[0],
-                    input.lens()[1],
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
-                }};
-        }
-        else
-        {
-            MIGRAPHX_THROW("Invalid padding mode");
-        }
+        return {t,
+                {
+                    input.lens()[0],
+                    input.lens()[1],
+                    std::size_t(std::max<std::ptrdiff_t>(
+                        1,
+                        floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
+                                                     stride[0]) +
+                            1)),
+                    std::size_t(std::max<std::ptrdiff_t>(
+                        1,
+                        floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
+                                                     stride[1]) +
+                            1)),
+                }};
     }
 };

 } // namespace op
......
@@ -2,13 +2,24 @@
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP

 #include <utility>
+#include <cstdint>
+#include <vector>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-inline std::size_t calculate_padding(std::size_t weight_dim, std::size_t dilation)
+inline void calculate_padding(int64_t idx,
+                              std::vector<int64_t>& pads,
+                              int64_t input_dim,
+                              int64_t stride,
+                              int64_t dilation,
+                              int64_t weight_dim)
 {
-    return (dilation * (weight_dim - 1)) / 2;
+    int64_t output_dim = input_dim / stride;
+    int64_t pad        = std::max(static_cast<int64_t>(0),
+                           (output_dim - 1) * stride + dilation * weight_dim - input_dim);
+    pads[idx]     = pad / 2;
+    pads[idx + 2] = pad - pad / 2;
 }

 } // namespace MIGRAPHX_INLINE_NS
......
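// A small worked example (not part of the commit) of the new calculate_padding
// signature above. For a 224-wide input with stride 2, dilation 1 and a 7-wide
// kernel it computes output_dim = 224 / 2 = 112 and
// pad = max(0, (112 - 1) * 2 + 1 * 7 - 224) = 5, split asymmetrically into 2
// before and 3 after. The header path and the idx convention (0 = height,
// 1 = width, so pads = {top, left, bottom, right}) are assumptions here.
#include <cassert>
#include <cstdint>
#include <vector>
#include <migraphx/op/pad_calc.hpp> // assumed header location for calculate_padding

int main()
{
    std::vector<int64_t> pads(4, 0);                    // {top, left, bottom, right}
    migraphx::calculate_padding(0, pads, 224, 2, 1, 7); // height axis
    migraphx::calculate_padding(1, pads, 224, 2, 1, 7); // width axis
    assert(pads[0] == 2 and pads[2] == 3);
    assert(pads[1] == 2 and pads[3] == 3);
}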
@@ -15,35 +15,18 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
 template <bool B>
 using bool_c = std::integral_constant<bool, B>;

-template <int N>
-struct requires_enum
-{
-    enum e
-    {
-        a = 0
-    };
-};
-
-#define MIGRAPHX_REQUIRES_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
+
+#define MIGRAPHX_REQUIRES_VAR() MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__)

 #ifdef CPPCHECK
 #define MIGRAPHX_REQUIRES(...) class = void
 #else
-#if 0
-// TODO: This currently crashed on clang
-#define MIGRAPHX_REQUIRES(...)                                                            \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT(                  \
-        PrivateRequires,                                                                  \
-        __LINE__) = migraphx::requires_enum<__LINE__>::a,                                 \
-    class = typename std::enable_if<and_<__VA_ARGS__,                                     \
-                                         MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__) == \
-                                             migraphx::requires_enum<__LINE__>::a>{}>::type
-#else
-#define MIGRAPHX_REQUIRES(...)                                                            \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT(                  \
-        PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a,                \
-    class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
-#endif
+#define MIGRAPHX_REQUIRES(...)                                                                 \
+    bool MIGRAPHX_REQUIRES_VAR() = true,                                                       \
+         typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \
+                                 int>::type = 0
 #endif

 } // namespace MIGRAPHX_INLINE_NS
......
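// A usage sketch of the reworked MIGRAPHX_REQUIRES above; it mirrors how
// value_parser in argument_parser.hpp constrains its overloads. The
// print_integral function itself is hypothetical.
#include <iostream>
#include <type_traits>
#include <migraphx/requires.hpp>

template <class T, MIGRAPHX_REQUIRES(std::is_integral<T>{})>
void print_integral(T x)
{
    // Selected only when std::is_integral<T> holds: the macro now expands to a
    // defaulted bool non-type template parameter plus an enable_if'd int
    // parameter instead of the old requires_enum trick.
    std::cout << x << std::endl;
}

int main() { print_integral(42); }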
@@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)
 inline std::string to_upper(std::string s) { return transform_string(std::move(s), ::toupper); }

+inline std::string to_lower(std::string s) { return transform_string(std::move(s), ::tolower); }
+
 inline bool starts_with(const std::string& value, const std::string& prefix)
 {
     if(prefix.size() > value.size())
......
@@ -19,7 +19,7 @@ rocm_install_targets(
 add_executable(read_onnx read_onnx.cpp)
 rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraphx_onnx)
+target_link_libraries(read_onnx migraphx_cpu migraphx_onnx)

 if(MIGRAPHX_ENABLE_GPU)
......
@@ -100,6 +100,7 @@ struct onnx_parser
     void init_actv_func()
     {
+        // Support name format of all lower case or the first letter capital
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
@@ -352,7 +353,8 @@ struct onnx_parser
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
-               l0      = prog.add_instruction(op::pad{padding}, l0);
+               l0      = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
+                                          l0);
            }
            else
            {
@@ -870,7 +872,9 @@ struct onnx_parser
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
-           std::copy(names.begin(), names.end(), vec_names.begin());
+           std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+               return to_lower(name);
+           });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
@@ -961,7 +965,9 @@ struct onnx_parser
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
-           std::copy(names.begin(), names.end(), vec_names.begin());
+           std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+               return to_lower(name);
+           });
        }

        // need 4 activation functions
@@ -1088,7 +1094,9 @@ struct onnx_parser
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
-           std::copy(names.begin(), names.end(), vec_names.begin());
+           std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+               return to_lower(name);
+           });
        }

        // need 6 activation functions for bidirectional directions
......
@@ -8,6 +8,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/tf.hpp>
 #include <migraphx/onnx.hpp>
+#include <migraphx/type_name.hpp>

 #ifdef HAVE_GPU
 #include <migraphx/gpu/target.hpp>
@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
            t = as.type_enum();
            n = sizeof(as());
        }
    });
+
+    if(n == 0)
+    {
+        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
+    }
+
    auto strides = info.strides;
    std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
        return n > 0 ? i / n : 0;
......
@@ -205,16 +205,18 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    auto sih_lens = sih->get_shape().lens();

     // bias
+    instruction_ref bb{};
     if(bias != prog.end())
     {
-        long hs    = r->get_shape().lens()[2];
+        long hs    = static_cast<long>(r->get_shape().lens()[2]);
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
         auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
-        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
+        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
     }

     instruction_ref hidden_out = prog.end();
@@ -228,19 +230,14 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
         xt         = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
         auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
         auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
-        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        instruction_ref ht;
         if(bias != prog.end())
         {
-            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
-        }
-        else
-        {
-            ht = xt_ht;
+            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
         }
+        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);

         // apply activation function
-        ht  = prog.insert_instruction(ins, actv_func, ht);
+        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
         sih = ht;

         // add the dimensions of sequence length (axis 0 for sequence length,
@@ -485,62 +482,41 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     long hs = static_cast<long>(r_shape.lens()[2]);

     migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
-    std::vector<int> data(s.elements(), 1);
+    std::vector<float> data(s.elements(), 1.0f);
     auto l1 = prog.add_literal(migraphx::literal{s, data});

-    // weight matrix
+    // w matrix squeeze to 2-dim and do a transpose
     std::vector<int64_t> perm{1, 0};
     auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
-    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
-    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
+    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);

+    // r slide to two part, zr and h
     auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
-    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
+    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
+    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
     auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
+    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);

     // initial states
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    size_t bs = ih->get_shape().lens()[1];

     // bias
-    instruction_ref brcst_bz{};
-    instruction_ref brcst_br{};
-    instruction_ref brcst_wbh{};
-    instruction_ref brcst_rbh{};
-    instruction_ref brcst_bh{};
+    instruction_ref bwb{};
+    instruction_ref brb_zr{};
+    instruction_ref brb_h{};
     if(bias != prog.end())
     {
-        auto broadcast_lens = sih->get_shape().lens();
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
-
-        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
-
-        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
-        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
-
-        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-        brcst_br = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
-
-        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
-        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
+        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
+        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
+
+        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
+        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
+        brb_zr     = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
+        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
     }

     for(long i = 0; i < seq_len; i++)
@@ -549,56 +525,58 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

-        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
-        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
-        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
-        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
+        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
+        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
         if(bias != prog.end())
         {
-            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
+            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
+            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
         }
-        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);

-        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
-        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
-        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
-        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
-        if(bias != prog.end())
-        {
-            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
-        }
-        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
+        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
+        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
+        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
+
+        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
+        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
+
+        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
+        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
+
+        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
+        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);

-        instruction_ref xht_h;
+        instruction_ref hr_h{};
         if(linear_before_reset == 0)
         {
             // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
             auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
-            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
-            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
             if(bias != prog.end())
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
+            }
+            else
+            {
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
             }
         }
         else
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
-            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
+            instruction_ref ht1_rh{};
             if(bias != prog.end())
             {
-                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
             }
-            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
-            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
-            if(bias != prog.end())
+            else
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
             }
+            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
         }
-        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
+
+        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
+        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);

         // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
         auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
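For reference, the GRU cell rewritten above computes the equations quoted in the code comments; restated in LaTeX (with \odot the elementwise product and f, g the two activation functions), the refactor now evaluates the gate products X_t W^T and H_{t-1}[R_z; R_r]^T as single dot operations and slices out the per-gate pieces:

\begin{aligned}
z_t &= f\left(X_t W_z^{T} + H_{t-1} R_z^{T} + Wb_z + Rb_z\right)\\
r_t &= f\left(X_t W_r^{T} + H_{t-1} R_r^{T} + Wb_r + Rb_r\right)\\
h_t &= g\left(X_t W_h^{T} + (r_t \odot H_{t-1}) R_h^{T} + Rb_h + Wb_h\right) \quad (\text{linear\_before\_reset} = 0)\\
h_t &= g\left(X_t W_h^{T} + r_t \odot (H_{t-1} R_h^{T} + Rb_h) + Wb_h\right) \quad (\text{otherwise})\\
H_t &= (1 - z_t) \odot h_t + z_t \odot H_{t-1}
\end{aligned}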
@@ -913,35 +891,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     migraphx::shape r_shape = r->get_shape();
     long seq_len            = static_cast<long>(seq_shape.lens()[0]);
     long hs                 = static_cast<long>(r_shape.lens()[2]);
+    auto bs                 = ih->get_shape().lens()[1];

     std::vector<int64_t> perm{1, 0};
-    // w matrix
+    // w matrix, squeeze and transpose
     auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wi      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
-    auto wo      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
-    auto wf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
-    auto wc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
-    auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
+    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);

-    // r matrix
+    // r matrix, squeeze and transpose
     auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto ri      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
-    auto ro      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
-    auto rf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
-    auto rc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
-    auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
+    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);

     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
@@ -951,40 +910,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     auto ic_lens = sic->get_shape().lens();

     // bias
-    instruction_ref bi_brcst{};
-    instruction_ref bo_brcst{};
-    instruction_ref bf_brcst{};
-    instruction_ref bc_brcst{};
+    instruction_ref wrb{};
     if(bias != prog.end())
     {
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto bxi   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto bhi   = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto bi    = prog.insert_instruction(ins, op::add{}, bxi, bhi);
-        bi_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
-
-        auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        auto bo  = prog.insert_instruction(ins, op::add{}, bxo, bho);
-        bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
-
-        auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
-        auto bf  = prog.insert_instruction(ins, op::add{}, bxf, bhf);
-        bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
-
-        auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
-        auto bc  = prog.insert_instruction(ins, op::add{}, bxc, bhc);
-        bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
+        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
+        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
+        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
+
+        wrb = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
     }

     // peep hole
     instruction_ref pphi_brcst{};
     instruction_ref ppho_brcst{};
     instruction_ref pphf_brcst{};
     if(pph != prog.end())
     {
         auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
@@ -1004,44 +946,31 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

-        // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
-        auto xt_wi          = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
-        auto ht_ri          = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
-        auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        if(pph != prog.end())
-        {
-            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
-        }
+        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
+        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
+        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
         if(bias != prog.end())
         {
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
+            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
         }
-        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);

-        // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
-        auto xt_wf          = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
-        auto ht_rf          = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
-        auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
+        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
+        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
+        auto ft_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
+        auto ct_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
         if(pph != prog.end())
         {
+            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
+            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
             auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
             ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
         }
-        if(bias != prog.end())
-        {
-            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
-        }
+        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
         auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);

-        // equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
-        auto xt_wc          = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
-        auto ht_rc          = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
-        auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
-        if(bias != prog.end())
-        {
-            ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
-        }
         auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);

         // equation Ct = ft (.) Ct-1 + it (.) ct
@@ -1050,19 +979,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
         last_cell_output = cellt;

-        // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
-        auto xt_wo          = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
-        auto ht_ro          = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
-        auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
         if(pph != prog.end())
         {
             auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
             ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
         }
-        if(bias != prog.end())
-        {
-            ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
-        }
         auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

         // Ht = ot (.) h(Ct)
......