Commit bc5d7f75 authored by Paul
Browse files

Merge from develop

parents 47c0854d a5b0afa0
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraph/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraph/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include <migraph/verify.hpp> #include <migraphx/verify.hpp>
migraph::program::parameter_map create_param_map(const migraph::program& p, bool gpu = true) migraphx::program::parameter_map create_param_map(const migraphx::program& p, bool gpu = true)
{ {
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
if(gpu) if(gpu)
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second)); m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
else else
m[x.first] = migraph::generate_argument(x.second); m[x.first] = migraphx::generate_argument(x.second);
} }
return m; return m;
} }
...@@ -25,9 +25,9 @@ int main(int argc, char const* argv[]) ...@@ -25,9 +25,9 @@ int main(int argc, char const* argv[])
{ {
std::string file = argv[1]; std::string file = argv[1];
std::size_t n = argc > 2 ? std::stoul(argv[2]) : 50; std::size_t n = argc > 2 ? std::stoul(argv[2]) : 50;
auto p = migraph::parse_onnx(file); auto p = migraphx::parse_onnx(file);
std::cout << "Compiling ... " << std::endl; std::cout << "Compiling ... " << std::endl;
p.compile(migraph::gpu::target{}); p.compile(migraphx::gpu::target{});
std::cout << "Allocating params ... " << std::endl; std::cout << "Allocating params ... " << std::endl;
auto m = create_param_map(p); auto m = create_param_map(p);
std::cout << "Running performance report ... " << std::endl; std::cout << "Running performance report ... " << std::endl;
......
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
if(argc > 1) if(argc > 1)
{ {
std::string file = argv[1]; std::string file = argv[1];
auto prog = migraph::parse_onnx(file); auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl; std::cout << prog << std::endl;
} }
} }
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraph/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraph/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraph/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include <migraph/verify_args.hpp> #include <migraphx/verify_args.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
template <class T> template <class T>
auto get_hash(const T& x) auto get_hash(const T& x)
...@@ -15,14 +15,14 @@ auto get_hash(const T& x) ...@@ -15,14 +15,14 @@ auto get_hash(const T& x)
} }
template <class F> template <class F>
migraph::argument run_cpu(F f) migraphx::argument run_cpu(F f)
{ {
auto p = f(); auto p = f();
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::generate_argument(x.second, get_hash(x.first)); m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
} }
auto out = p.eval(m); auto out = p.eval(m);
std::cout << p << std::endl; std::cout << p << std::endl;
...@@ -30,19 +30,20 @@ migraph::argument run_cpu(F f) ...@@ -30,19 +30,20 @@ migraph::argument run_cpu(F f)
} }
template <class F> template <class F>
migraph::argument run_gpu(F f) migraphx::argument run_gpu(F f)
{ {
auto p = f(); auto p = f();
p.compile(migraph::gpu::target{}); p.compile(migraphx::gpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first))); m[x.first] =
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
} }
auto out = migraph::gpu::from_gpu(p.eval(m)); auto out = migraphx::gpu::from_gpu(p.eval(m));
std::cout << p << std::endl; std::cout << p << std::endl;
return migraph::gpu::from_gpu(out); return migraphx::gpu::from_gpu(out);
} }
template <class F> template <class F>
...@@ -50,12 +51,12 @@ void verify_program(const std::string& name, F f, double tolerance = 100) ...@@ -50,12 +51,12 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
{ {
auto x = run_cpu(f); auto x = run_cpu(f);
auto y = run_gpu(f); auto y = run_gpu(f);
migraph::verify_args(name, x, y, tolerance); migraphx::verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl; // std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl; // std::cout << "gpu: " << y << std::endl;
} }
void verify_instructions(const migraph::program& prog, double tolerance = 80) void verify_instructions(const migraphx::program& prog, double tolerance = 80)
{ {
for(auto&& ins : prog) for(auto&& ins : prog)
{ {
...@@ -68,8 +69,8 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80) ...@@ -68,8 +69,8 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
if(ins.name() == "reshape") if(ins.name() == "reshape")
continue; continue;
auto create_program = [&] { auto create_program = [&] {
migraph::program p; migraphx::program p;
std::vector<migraph::instruction_ref> inputs; std::vector<migraphx::instruction_ref> inputs;
for(auto&& arg : ins.inputs()) for(auto&& arg : ins.inputs())
{ {
if(arg->name() == "@literal") if(arg->name() == "@literal")
...@@ -100,8 +101,8 @@ void verify_reduced(F f, int n, double tolerance = 80) ...@@ -100,8 +101,8 @@ void verify_reduced(F f, int n, double tolerance = 80)
{ {
auto create_program = [&] { auto create_program = [&] {
migraph::program p = f(); migraphx::program p = f();
auto last = std::prev(p.end(), n + 1); auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end()); p.remove_instructions(last, p.end());
return p; return p;
}; };
...@@ -113,9 +114,9 @@ void verify_reduced(F f, int n, double tolerance = 80) ...@@ -113,9 +114,9 @@ void verify_reduced(F f, int n, double tolerance = 80)
template <class F> template <class F>
void verify_reduced_program(F f, double tolerance = 80) void verify_reduced_program(F f, double tolerance = 80)
{ {
migraph::program p = f(); migraphx::program p = f();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
for(int i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(f, i, tolerance); verify_reduced(f, i, tolerance);
} }
...@@ -127,7 +128,7 @@ int main(int argc, char const* argv[]) ...@@ -127,7 +128,7 @@ int main(int argc, char const* argv[])
if(not args.empty()) if(not args.empty())
{ {
std::string file = args.front(); std::string file = args.front();
auto p = migraph::parse_onnx(file); auto p = migraphx::parse_onnx(file);
std::cout << p << std::endl; std::cout << p << std::endl;
if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; })) if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
...@@ -136,11 +137,11 @@ int main(int argc, char const* argv[]) ...@@ -136,11 +137,11 @@ int main(int argc, char const* argv[])
} }
else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; })) else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
{ {
verify_reduced_program([&] { return migraph::parse_onnx(file); }); verify_reduced_program([&] { return migraphx::parse_onnx(file); });
} }
else else
{ {
verify_program(file, [&] { return migraph::parse_onnx(file); }); verify_program(file, [&] { return migraphx::parse_onnx(file); });
} }
} }
} }
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/pass_config.hpp> #include <migraphx/pass_config.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <set> #include <set>
#include <list> #include <list>
#include <vector> #include <vector>
#include <queue> #include <queue>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
//#define MIGRAPH_DEBUG_OPT //#define MIGRAPHX_DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s #define MIGRAPHX_DEBUG(s) s
#else #else
#define MIGRAPH_DEBUG(s) #define MIGRAPHX_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT #endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const void memory_coloring::apply(program& p) const
{ {
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{ {
memory_coloring_impl opt(&p, allocation_op, verify); memory_coloring_impl opt(&p, allocation_op, verify);
opt.run(); opt.run();
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run() void memory_coloring_impl::run()
{ {
MIGRAPH_DEBUG(dump("---Before memory coloring---")); MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPH_DEBUG(dump_program()); MIGRAPHX_DEBUG(dump_program());
build(); build();
if(num_of_lives != 0) if(num_of_lives != 0)
{ {
MIGRAPH_DEBUG(dump_intervals()); MIGRAPHX_DEBUG(dump_intervals());
// Coloring // Coloring
while(!alloc_queue.empty()) while(!alloc_queue.empty())
{ {
...@@ -85,7 +85,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -85,7 +85,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
conflict_queue.pop(); conflict_queue.pop();
} }
segment.offset = offset; segment.offset = offset;
MIGRAPH_DEBUG(segment.dump()); MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size); required_bytes = std::max(required_bytes, offset + segment.size);
return true; return true;
} }
...@@ -118,11 +118,11 @@ void memory_coloring_impl::build() ...@@ -118,11 +118,11 @@ void memory_coloring_impl::build()
live_range& range = def_interval->segment; live_range& range = def_interval->segment;
def_interval->result = iter->get_shape(); def_interval->result = iter->get_shape();
def_interval->is_literal = is_lit; def_interval->is_literal = is_lit;
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(!is_lit || unify_literals) if(!is_lit || unify_literals)
alloc_queue.push(def_interval); alloc_queue.push(def_interval);
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
live_set.erase(range.vn); live_set.erase(range.vn);
} }
} }
...@@ -217,8 +217,8 @@ void memory_coloring_impl::rewrite() ...@@ -217,8 +217,8 @@ void memory_coloring_impl::rewrite()
} }
} }
} }
MIGRAPH_DEBUG(dump("---After rewrite---")); MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPH_DEBUG(dump_program()); MIGRAPHX_DEBUG(dump_program());
} }
void memory_coloring_impl::verify() void memory_coloring_impl::verify()
...@@ -232,9 +232,8 @@ void memory_coloring_impl::verify() ...@@ -232,9 +232,8 @@ void memory_coloring_impl::verify()
if(segment.begin == invalid_offset) if(segment.begin == invalid_offset)
{ {
// TODO: This check breaks on the tests if(!interval.is_live_on_entry)
// if(!interval.is_live_on_entry) MIGRAPHX_THROW("interval is not live on entry");
// MIGRAPH_THROW("interval is not live on entry");
continue; continue;
} }
...@@ -252,14 +251,14 @@ void memory_coloring_impl::verify() ...@@ -252,14 +251,14 @@ void memory_coloring_impl::verify()
if(range->offset == invalid_offset) if(range->offset == invalid_offset)
continue; continue;
if(!is_disjoin(*range, segment)) if(!is_disjoin(*range, segment))
MIGRAPH_THROW("range and segment is not disjoined"); MIGRAPHX_THROW("range and segment is not disjoined");
} }
} }
} }
} }
} }
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; } void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
...@@ -333,5 +332,5 @@ void live_interval::dump() ...@@ -333,5 +332,5 @@ void live_interval::dump()
#endif #endif
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
static const int invalid_offset = -1; static const int invalid_offset = -1;
...@@ -15,7 +15,7 @@ struct live_range ...@@ -15,7 +15,7 @@ struct live_range
long long offset; // offset to base pointer of allocated memory trunk. long long offset; // offset to base pointer of allocated memory trunk.
int vn; // value number that identifies this live_range. int vn; // value number that identifies this live_range.
long long size; // size of required memory in bytes long long size; // size of required memory in bytes
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(); void dump();
#endif #endif
}; };
...@@ -35,7 +35,7 @@ struct live_interval ...@@ -35,7 +35,7 @@ struct live_interval
int get_end() const { return segment.end; } int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; } long long get_offset() const { return segment.offset; }
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(); void dump();
#endif #endif
...@@ -84,7 +84,7 @@ struct memory_coloring_impl ...@@ -84,7 +84,7 @@ struct memory_coloring_impl
{ {
return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output"; return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output";
} }
bool is_allocate(const instruction_ref ins) { return ins->name() == allocation_op; } bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; } static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; } static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins) static bool is_check_context(const instruction_ref ins)
...@@ -101,7 +101,7 @@ struct memory_coloring_impl ...@@ -101,7 +101,7 @@ struct memory_coloring_impl
return ((end1 < range2.offset) || (end2 < range1.offset)); return ((end1 < range2.offset) || (end2 < range1.offset));
} }
void verify(); void verify();
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(const std::string&); void dump(const std::string&);
void dump_program(); void dump_program();
void dump_intervals(); void dump_intervals();
...@@ -154,6 +154,6 @@ struct memory_coloring_impl ...@@ -154,6 +154,6 @@ struct memory_coloring_impl
bool enable_verify; bool enable_verify;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/env.hpp> #include <migraphx/operators.hpp>
#include <migraph/ranges.hpp> #include <migraphx/env.hpp>
#include <migraph/time.hpp> #include <migraphx/ranges.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl struct program_impl
{ {
...@@ -134,6 +135,12 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re ...@@ -134,6 +135,12 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(ins)); assert(has_instruction(ins));
assert(has_instruction(rep)); assert(has_instruction(rep));
assert(ins != rep); assert(ins != rep);
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, op::identity{}, rep);
}
// TODO: Should it be an error if the output is empty? // TODO: Should it be an error if the output is empty?
if(ins->outputs().empty()) if(ins->outputs().empty())
{ {
...@@ -271,6 +278,8 @@ instruction_ref program::end() const { return impl->instructions.end(); } ...@@ -271,6 +278,8 @@ instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().get_shape(); } shape program::get_shape() const { return impl->instructions.back().get_shape(); }
context& program::get_context() const { return impl->ctx; }
instruction_ref program::validate() const instruction_ref program::validate() const
{ {
return std::find_if(impl->instructions.begin(), return std::find_if(impl->instructions.begin(),
...@@ -282,7 +291,7 @@ void program::compile(const target& t, tracer trace) ...@@ -282,7 +291,7 @@ void program::compile(const target& t, tracer trace)
{ {
assert(this->validate() == impl->instructions.end()); assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
if(enabled(MIGRAPH_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
trace = tracer{std::cout}; trace = tracer{std::cout};
trace(*this); trace(*this);
trace(); trace();
...@@ -297,8 +306,8 @@ void program::compile(const target& t, tracer trace) ...@@ -297,8 +306,8 @@ void program::compile(const target& t, tracer trace)
if(invalid != impl->instructions.end()) if(invalid != impl->instructions.end())
{ {
auto index = std::distance(impl->instructions.begin(), invalid); auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " + MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name()); std::to_string(index) + ": " + invalid->name());
} }
trace(); trace();
#endif #endif
...@@ -307,7 +316,16 @@ void program::compile(const target& t, tracer trace) ...@@ -307,7 +316,16 @@ void program::compile(const target& t, tracer trace)
if(invalid != impl->instructions.end()) if(invalid != impl->instructions.end())
{ {
auto index = std::distance(impl->instructions.begin(), invalid); auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW("Invalid program from compilation at instruction " + std::to_string(index)); MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
}
this->finalize();
}
void program::finalize()
{
for(auto ins : iterator_for(*this))
{
ins->finalize(this->impl->ctx);
} }
} }
...@@ -334,7 +352,7 @@ argument generic_eval(const program& p, ...@@ -334,7 +352,7 @@ argument generic_eval(const program& p,
auto param_name = auto param_name =
any_cast<builtin::param>(ins->get_operator()).parameter; any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name)) if(not contains(params, param_name))
MIGRAPH_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
return params.at(param_name); return params.at(param_name);
})); }));
} }
...@@ -361,20 +379,31 @@ argument generic_eval(const program& p, ...@@ -361,20 +379,31 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const argument program::eval(std::unordered_map<std::string, argument> params) const
{ {
if(enabled(MIGRAPH_TRACE_EVAL{})) auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto sctx = ctx;
auto check_context = [&](auto f) {
assert(is_shared(ctx, sctx));
auto x = f();
sctx = ctx;
return x;
};
#else
auto check_context = [](auto f) { return f(); };
#endif
if(enabled(MIGRAPHX_TRACE_EVAL{}))
{ {
auto& ctx = this->impl->ctx; return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish(); ctx.finish();
std::cout << "Run instruction: "; std::cout << "Run instruction: ";
this->debug_print(ins); this->debug_print(ins);
return f(); return check_context(f);
}); });
} }
else else
{ {
return generic_eval( return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); }); *this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
} }
} }
...@@ -428,8 +457,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -428,8 +457,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
overhead_vec.reserve(n); overhead_vec.reserve(n);
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
overhead_vec.push_back(time<milliseconds>( overhead_vec.push_back(time<milliseconds>([&] { dry_run(params); }));
[&] { generic_eval(*this, ctx, params, [](auto...) { return argument{}; }); }));
} }
double total_time = common_average(total_vec); double total_time = common_average(total_vec);
...@@ -493,6 +521,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const ...@@ -493,6 +521,12 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
std::cout << std::endl; std::cout << std::endl;
} }
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
...@@ -501,5 +535,5 @@ std::ostream& operator<<(std::ostream& os, const program& p) ...@@ -501,5 +535,5 @@ std::ostream& operator<<(std::ostream& os, const program& p)
return os; return os;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
# Optional python bindings for migraphx, built with pybind11.
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
if(MIGRAPHX_ENABLE_PYTHON)
    # Prefer the interpreter found on PATH unless the user already cached one.
    find_program(DEFAULT_PYTHON_EXE python)
    if(DEFAULT_PYTHON_EXE)
        set(PYTHON_EXECUTABLE ${DEFAULT_PYTHON_EXE} CACHE PATH "Path to python executable")
    endif()
    find_package(pybind11 REQUIRED)
    # Target is migraphx_py, but the produced module imports as `migraphx`.
    pybind11_add_module(migraphx_py migraphx_py.cpp)
    set_target_properties(migraphx_py PROPERTIES
        OUTPUT_NAME migraphx
        # Hide non-exported symbols, as recommended for pybind11 extensions.
        C_VISIBILITY_PRESET hidden
        CXX_VISIBILITY_PRESET hidden
    )
    target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
    # HAVE_GPU gates the gpu-specific bindings inside migraphx_py.cpp.
    if(MIGRAPHX_ENABLE_GPU)
        target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
        target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
    endif()
endif()
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace py = pybind11;
// Visitor adaptor: forwards every element type to the wrapped functor `f`,
// but raises an error for half-precision floats, which have no python
// buffer format yet.
template <class F>
struct throw_half
{
    F f;

    // Delegate any supported element-type tag to the wrapped functor.
    template <class A>
    void operator()(A type_tag) const
    {
        f(type_tag);
    }

    // Half precision cannot be exposed to python: fail loudly.
    void operator()(migraphx::shape::as<migraphx::half>) const
    {
        throw std::runtime_error("Half not supported in python yet.");
    }
};
// Visitor adaptor: forwards every element type to the wrapped functor `f`,
// silently ignoring half-precision floats instead of failing.
template <class F>
struct skip_half
{
    F f;

    // Delegate any supported element-type tag to the wrapped functor.
    template <class A>
    void operator()(A type_tag) const
    {
        f(type_tag);
    }

    // Half has no python buffer format: visit nothing for it.
    void operator()(migraphx::shape::as<migraphx::half>) const {}
};
// Visit the element type of shape `s` with `f`; half precision is rejected
// with an exception (see throw_half) since python cannot represent it yet.
template <class F>
void visit_type(const migraphx::shape& s, F f)
{
    s.visit_type(throw_half<F>{f});
}
// Invoke `f` once per element type migraphx knows about, skipping half
// precision (see skip_half) — used to match python buffer format strings.
template <class F>
void visit_types(F f)
{
    migraphx::shape::visit_types(skip_half<F>{f});
}
// Expose an object with get_shape()/data() (e.g. migraphx::argument) as a
// python buffer so numpy can view the memory without copying.
// Throws for half-precision data, which has no python format (visit_type).
template <class T>
py::buffer_info to_buffer_info(T& x)
{
    migraphx::shape s = x.get_shape();
    py::buffer_info b;
    visit_type(s, [&](auto as) {
        // `as` is the element-type tag; decltype(as()) is the C++ type,
        // from which pybind11 derives the buffer-protocol format string.
        b = py::buffer_info(x.data(),
                            as.size(),
                            py::format_descriptor<decltype(as())>::format(),
                            s.lens().size(),
                            s.lens(),
                            s.strides());
    });
    return b;
}
// Convert a pybind11 buffer description into a migraphx::shape.
//
// The buffer's format string is compared against every migraphx element
// type (half excluded, see skip_half).  Previously an unsupported format
// (e.g. half or bool buffers) left `t` uninitialized and the read below
// was undefined behavior; now it raises a descriptive error instead.
migraphx::shape to_shape(const py::buffer_info& info)
{
    migraphx::shape::type_t t{};
    bool found = false;
    visit_types([&](auto as) {
        if(info.format == py::format_descriptor<decltype(as())>::format())
        {
            t     = as.type_enum();
            found = true;
        }
    });
    if(!found)
        throw std::runtime_error("Unsupported data type format: " + info.format);
    return migraphx::shape{t, info.shape, info.strides};
}
// Python module definition for `import migraphx`: exposes shape, argument,
// target and program, plus free functions for onnx parsing, argument
// generation, and (with HAVE_GPU) host<->device transfers.
PYBIND11_MODULE(migraphx, m)
{
    // Read-only view of an element type + dimensions + strides.
    py::class_<migraphx::shape>(m, "shape")
        .def(py::init<>())
        .def("type", &migraphx::shape::type)
        .def("lens", &migraphx::shape::lens)
        .def("strides", &migraphx::shape::strides)
        .def("elements", &migraphx::shape::elements)
        .def("bytes", &migraphx::shape::bytes)
        .def("type_size", &migraphx::shape::type_size)
        .def("packed", &migraphx::shape::packed)
        .def("transposed", &migraphx::shape::transposed)
        .def("broadcasted", &migraphx::shape::broadcasted)
        .def("standard", &migraphx::shape::standard)
        .def("scalar", &migraphx::shape::scalar)
        .def("__eq__", std::equal_to<migraphx::shape>{})
        .def("__ne__", std::not_equal_to<migraphx::shape>{})
        .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });

    // Tensor data; buffer protocol enables zero-copy numpy interop.
    py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
        .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
        .def("__init__",
             [](migraphx::argument& x, py::buffer b) {
                 py::buffer_info info = b.request();
                 // Legacy pybind11 __init__ idiom: placement-new into the
                 // storage pybind11 allocated for `x`.
                 new(&x) migraphx::argument(to_shape(info), info.ptr);
             })
        .def("__eq__", std::equal_to<migraphx::argument>{})
        .def("__ne__", std::not_equal_to<migraphx::argument>{})
        .def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });

    // Opaque handle; concrete targets are obtained via get_target below.
    py::class_<migraphx::target>(m, "target");

    py::class_<migraphx::program>(m, "program")
        .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
        .def("get_shape", &migraphx::program::get_shape)
        // Wrapped in a lambda so python callers need not supply a tracer.
        .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
        .def("run", &migraphx::program::eval)
        .def("__eq__", std::equal_to<migraphx::program>{})
        .def("__ne__", std::not_equal_to<migraphx::program>{})
        .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });

    m.def("parse_onnx", &migraphx::parse_onnx);
    // Map a name ("cpu"/"gpu") to a target; "gpu" only exists when built
    // with HAVE_GPU.
    m.def("get_target", [](const std::string& name) -> migraphx::target {
        if(name == "cpu")
            return migraphx::cpu::target{};
#ifdef HAVE_GPU
        if(name == "gpu")
            return migraphx::gpu::target{};
#endif
        throw std::runtime_error("Target not found: " + name);
    });
    m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
#ifdef HAVE_GPU
    m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
    m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
    m.def("from_gpu", &migraphx::gpu::from_gpu);
    m.def("gpu_sync", &migraphx::gpu::gpu_sync);
    m.def("copy_to_gpu", &migraphx::gpu::copy_to_gpu);
#endif
#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Walk the program and lower every recurrent operator ("rnn", "gru")
// into primitive instructions via the matching helper.
void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        // An instruction has exactly one name, so the branches are
        // mutually exclusive.
        const auto& op_name = ins->name();
        if(op_name == "rnn")
            apply_vanilla_rnn(prog, ins);
        else if(op_name == "gru")
            apply_gru(prog, ins);
    }
}
// Lower a single vanilla "rnn" instruction into primitive instructions
// (slice, literal, concat, squeeze plus the cell emitted by
// vanilla_rnn_cell).  Handles forward, reverse and bidirectional
// directions, always replaces `ins` with a concat, and redirects any
// dependent "rnn_last_output" instructions to the computed last hidden
// state.
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have num of arguments
    // when writing their program.
    auto args = ins->inputs();

    // batch is taken from dim 1 of the input and hidden_size from dim 1
    // of the weight — assumes the ONNX RNN layouts
    // X: [seq_len, batch, input_size], W: [num_dirs, hidden, input].
    shape seq_shape = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size = seq_shape.lens()[1];
    shape::type_t type = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    // zeros used as the default initial hidden state when none is given
    std::vector<float> data(ih_shape.elements(), 0);

    // one activation function per direction
    auto actv_funcs = vanilla_rnn_actv_funcs(ins);
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dicrt = rnn_op.direction;
    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // input weight matrix, one slice of dim 0 per direction
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias; prog.end() marks "no bias supplied"
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process initial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            // no initial state provided: start both directions from zeros
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // each cell returns {per-step history (or prog.end()), last state}
        auto ret_forward = vanilla_rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            actv_funcs.at(0));
        auto ret_reverse = vanilla_rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            actv_funcs.at(1));

        // stack both directions' final hidden states along axis 1
        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten from
        // rnn operator is a concat instruction
        // sequence len is 1 (the cell produced no per-step history)
        if(ret_forward[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // append the last step to the history; the reverse direction
            // prepends it instead, which appears to restore sequence
            // order for reversed output — TODO confirm against
            // vanilla_rnn_cell's output ordering
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dicrt == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];
        // hidden state weight matrix
        auto r = args[2];
        // process bias and initial hidden state
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }
        // process initial hidden state
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }
        auto ret =
            vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
        // following logic is to ensure the last instruction is a
        // concat instruction
        // sequence len is 1
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            // reverse direction swaps the operand order when appending
            // the last step to the history
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // search its output to find if there are rnn_last_output operator
    // while loop to handle case of multiple rnn_last_output operators
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "rnn_last_output";
        });
        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }
}
// Expands one direction of a vanilla RNN into primitive instructions,
// implementing per time step:  Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// Every generated instruction is inserted before `ins` so the caller can
// later replace the original rnn instruction.
//
// Parameters:
//   is_forward - true: iterate t = 0..seq_len-1; false: iterate in reverse
//   prog       - program being rewritten
//   ins        - original rnn instruction, used as the insertion point
//   input      - X; axis 0 is the sequence length (see seq_len below)
//   w          - input weight tensor, already sliced to a single direction
//   r          - hidden weight tensor, already sliced to a single direction
//                (lens()[2] is read as the hidden size)
//   bias       - packed bias [Wb, Rb] for this direction, or prog.end()
//                when the model has no bias
//   ih         - initial hidden state for this direction
//   actv_func  - activation f applied at every step
//
// Returns {hidden_out, last_out}:
//   hidden_out - concat (along axis 0) of all steps except the final one;
//                prog.end() when seq_len == 1.  The caller appends the last
//                step itself so the rewritten op ends in a concat.
//   last_out   - final step's hidden state, unsqueezed on axes {0, 1}
//                (axis 0 = sequence, axis 1 = direction)
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
{
    // squeeze the num_directions axis off w and transpose so it can be used
    // as the right operand of dot
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);

    // same for the hidden-state weight r
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state with the direction axis squeezed off
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, hs_is_ih_placeholder_never_emitted(ih)); // NOTE(review): see below
// Normalize the rnn operator's activation-function list so the cell
// expansion always finds one function per direction.  The ONNX importer
// may leave the attribute empty or under-specified; default is tanh, and a
// single user-supplied function is replicated for both directions.
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op       = any_cast<op::rnn>(ins->get_operator());
    const auto& funcs = rnn_op.actv_funcs;

    if(rnn_op.direction == op::rnn_direction::bidirectional)
    {
        // bidirectional needs two functions: forward then reverse
        switch(funcs.size())
        {
        case 0: return {op::tanh{}, op::tanh{}};
        case 1: return {funcs.at(0), funcs.at(0)};
        default: return funcs;
        }
    }

    // single direction needs exactly one function
    if(funcs.empty())
        return {op::tanh{}};
    return funcs;
}
// Rewrites a "gru" instruction into primitive instructions (slice, dot,
// add, activation, concat, ...).  Handles the forward, reverse and
// bidirectional cases; for bidirectional, the packed weights/bias/initial
// state are sliced along axis 0 (the num_directions axis) into one set per
// direction and gru_cell is invoked once per direction.  Also patches any
// dependent "rnn_last_output" instructions to use the computed last state.
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    // activation list normalized to 4 entries (bidirectional) or 2 entries
    // (single direction) by gru_actv_funcs
    const auto actv_funcs = gru_actv_funcs(ins);
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is a user can pass any number of arguments
    // when writing their program, so args.size() is checked below.
    auto args = ins->inputs();

    // derive the shape of the (per-direction) initial hidden state:
    // {1, batch_size, hidden_size}
    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    // zero-filled buffer used as the default initial hidden state
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dicrt = gru_op.direction;
    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // w weight matrix: slice the packed tensor into forward/reverse halves
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // r weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias; prog.end() acts as the "no bias" sentinel understood by gru_cell
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // initial hidden state: argument 6 when supplied, otherwise zeros
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward = gru_cell(true,
                                    prog,
                                    ins,
                                    {args[0], w_forward, r_forward, bias_forward, ih_forward},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(0),
                                    actv_funcs.at(1));

        auto ret_reverse = gru_cell(false,
                                    prog,
                                    ins,
                                    {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(2),
                                    actv_funcs.at(3));

        // last hidden state of both directions, joined along the direction
        // axis, then squeezed off the sequence axis
        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from the gru operator is a concat.
        // ret[0] == prog.end() means sequence length is 1: only the final
        // step exists, so concat the two directions' single step directly.
        if(ret_forward[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // append (forward) / prepend (reverse) the final step so each
            // direction's output covers the full sequence, then join them
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dicrt == op::rnn_direction::forward);
        // weight matrices are used as-is (single direction)
        auto w = args[1];
        auto r = args[2];

        // bias; prog.end() is the "no bias" sentinel
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state: argument 6 when supplied, otherwise zeros
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], w, r, bias, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
        // sequence length 1: replace with a trivial concat so the rewritten
        // op still ends in a concat instruction
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            // reverse direction builds its outputs back-to-front, so the
            // final step goes first there
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // replace the corresponding rnn_last_output instruction
    // with the last_output, if rnn_last_output exists.
    // While loop to handle the case of multiple rnn_last_output operators.
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "rnn_last_output";
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }
}
// Expands one direction of a GRU into primitive instructions, implementing
// the ONNX GRU equations per time step:
//   zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)       (update gate)
//   rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)       (reset gate)
//   ht = g(...)   form depends on linear_before_reset (see inside the loop)
//   Ht = (1 - zt) (.) ht + zt (.) Ht-1
// All generated instructions are inserted before `ins`.
//
// Parameters:
//   is_forward          - true: iterate t = 0..seq_len-1; false: reverse
//   prog, ins           - program being rewritten / insertion point
//   inputs              - exactly 5 refs: {seq (X), w, r, bias, ih}.
//                         bias == prog.end() means "no bias".
//   linear_before_reset - ONNX attribute selecting the candidate-gate form
//   actv_func1          - f, applied to the z and r gates
//   actv_func2          - g, applied to the candidate hidden state
//
// Returns {hidden_states, last_output}:
//   hidden_states - concat (axis 0) of all steps except the final one;
//                   prog.end() when seq_len == 1 (caller appends the final
//                   step so the rewritten op ends in a concat)
//   last_output   - final step's hidden state, unsqueezed on axes {0, 1}
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
    migraphx::shape seq_shape = seq->get_shape();
    migraphx::shape r_shape   = r->get_shape();
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
    long hs                   = static_cast<long>(r_shape.lens()[2]);

    // all-ones literal of shape {batch, hidden}, used to compute (1 - zt)
    migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<int> data(s.elements(), 1);
    auto l1 = prog.add_literal(migraphx::literal{s, data});

    // weight matrix: squeeze off the direction axis, slice the packed
    // [z | r | h] gate weights apart, and transpose each for use in dot
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);

    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);

    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);

    // same treatment for the recurrence weights
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);

    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);

    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // initial hidden state; sih is reused as Ht-1 and updated each step
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias: slice the packed [Wbz Wbr Wbh Rbz Rbr Rbh] vector apart and
    // broadcast each piece to the hidden-state shape.  These refs are only
    // assigned (and only read) when bias != prog.end().
    instruction_ref brcst_bz{};
    instruction_ref brcst_br{};
    instruction_ref brcst_wbh{};
    instruction_ref brcst_rbh{};
    instruction_ref brcst_bh{};
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);

        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);

        // precompute the combined biases used by the z, r and (when
        // linear_before_reset == 0) h gates
        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);

        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);

        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
    }

    for(long i = 0; i < seq_len; i++)
    {
        // walk the sequence axis in the chosen direction
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

        // equation f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
        if(bias != prog.end())
        {
            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
        }
        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
        if(bias != prog.end())
        {
            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
        }
        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);

        instruction_ref xht_h;
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
            if(bias != prog.end())
            {
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
            }
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
            }
        }
        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);

        // add back the sequence (axis 0) and num_directions (axis 1) dims
        last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        // accumulate all but the final step; the final concat is performed
        // by the caller so the rewritten op's last instruction is a concat
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                // reverse direction prepends so the result stays in
                // ascending sequence order
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}
// Normalize the gru operator's activation-function list before rewriting:
// the cell expansion needs 4 functions for bidirectional mode and 2 for a
// single direction, even when the user specified fewer (or none).  Missing
// entries are filled using the same scheme as parse_gru: defaults are
// sigmoid (gates) and tanh (candidate), and shorter lists are replicated.
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op   = any_cast<op::gru>(ins->get_operator());
    const auto& f = gru_op.actv_funcs;

    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        switch(f.size())
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        case 1: return {f.at(0), f.at(0), f.at(0), f.at(0)};
        case 2: return {f.at(0), f.at(1), f.at(0), f.at(1)};
        case 3: return {f.at(0), f.at(1), f.at(2), f.at(0)};
        default: return f;
        }
    }

    switch(f.size())
    {
    case 0: return {op::sigmoid{}, op::tanh{}};
    case 1: return {f.at(0), f.at(0)};
    default: return f;
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl struct shape_impl
{ {
...@@ -169,12 +169,12 @@ std::string shape::type_string() const ...@@ -169,12 +169,12 @@ std::string shape::type_string() const
{ {
switch(this->type()) switch(this->type())
{ {
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE(x, t) \
case x: return #x; case x: return #x;
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_TYPE_STRING_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE)
#undef MIGRAPH_SHAPE_TYPE_STRING_CASE #undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE
} }
MIGRAPH_THROW("Invalid type"); MIGRAPHX_THROW("Invalid type");
} }
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
...@@ -191,5 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x) ...@@ -191,5 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os; return os;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct find_add_lit_broadcast struct find_add_lit_broadcast
{ {
...@@ -61,5 +61,5 @@ struct find_add_lit_broadcast ...@@ -61,5 +61,5 @@ struct find_add_lit_broadcast
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); } void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(const std::string& name) bool is_reshaper(instruction_ref ins)
{ {
// clang-format off // clang-format off
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
"reshape", "reshape",
"transpose",
// "broadcast",
"contiguous" "contiguous"
}; };
// clang-format on // clang-format on
return contains(names, name); return contains(names, ins->name());
}
bool is_transpose_output(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
if(ins->outputs().front()->name() == "contiguous")
return is_transpose_output(ins->outputs().front());
return ins->outputs().front()->name() == "transpose";
}
instruction_ref find_transpose_input(instruction_ref ins)
{
if(ins->inputs().size() != 1)
return ins;
if(ins->inputs().front()->name() == "contiguous")
return find_transpose_input(ins->inputs().front());
if(ins->inputs().front()->name() == "transpose")
return ins->inputs().front();
return ins;
} }
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(program& p) const
{ {
auto end = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(not is_reshaper(ins->name())) if(ins->outputs().empty() and ins != end)
continue;
if(ins->outputs().size() != 1)
continue;
if(is_reshaper(ins->outputs().front()->name()))
continue; continue;
// Gather reshapes if(is_reshaper(ins))
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->name()))
{ {
assert(!reshapes.back()->inputs().empty()); if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
assert(p.has_instruction(reshapes.back()->inputs().front())); continue;
reshapes.push_back(reshapes.back()->inputs().front()); // Gather reshapes
} std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()}; std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes)) for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{ {
r = std::make_pair(*start, *last); auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
break; return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
p.replace_instruction(r.first, r.second);
} }
} }
if(r.first != r.second) else if(ins->name() == "transpose")
{ {
p.replace_instruction(r.first, r.second); if(is_transpose_output(ins))
continue;
auto x = ins;
auto t = ins;
do
{
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
p.replace_instruction(ins, t->inputs().front());
} }
} }
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
add_library(migraph_cpu add_library(migraphx_cpu
target.cpp target.cpp
lowering.cpp lowering.cpp
gemm.cpp gemm.cpp
) )
set_target_properties(migraph_cpu PROPERTIES EXPORT_NAME cpu) set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu)
find_path(BLAZE_INCLUDE blaze/Blaze.h) find_path(BLAZE_INCLUDE blaze/Blaze.h)
find_package(Threads) find_package(Threads)
rocm_clang_tidy_check(migraph_cpu) rocm_clang_tidy_check(migraphx_cpu)
target_link_libraries(migraph_cpu migraph Threads::Threads) target_link_libraries(migraphx_cpu migraphx Threads::Threads)
target_include_directories(migraph_cpu PRIVATE ${BLAZE_INCLUDE}) target_include_directories(migraphx_cpu PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraph_cpu PRIVATE -DBLAZE_USE_CPP_THREADS) target_compile_definitions(migraphx_cpu PRIVATE -DBLAZE_USE_CPP_THREADS)
rocm_install_targets( rocm_install_targets(
TARGETS migraph_cpu TARGETS migraphx_cpu
INCLUDE INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/include
) )
......
#include <migraph/cpu/gemm.hpp> #include <migraphx/cpu/gemm.hpp>
#include <migraph/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraph/requires.hpp> #include <migraphx/requires.hpp>
#include <blaze/math/CustomMatrix.h> #include <blaze/math/CustomMatrix.h>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
template <class T> template <class T>
...@@ -94,5 +94,5 @@ void migemm( ...@@ -94,5 +94,5 @@ void migemm(
} }
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#ifndef MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace cpu {
struct context
{
void finish() const {}
};
} // namespace cpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#include <migraph/program.hpp>
#include <migraph/cpu/context.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace cpu {
struct target
{
std::string name() const;
std::vector<pass> get_passes(migraph::context& ctx) const;
migraph::context get_context() const { return context{}; }
};
} // namespace cpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
struct context
{
void finish() const {}
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP #define MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraph/argument.hpp> #include <migraphx/argument.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
void migemm( void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta); const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment