"vscode:/vscode.git/clone" did not exist on "a3dd38d9d3bf70c9bf270ccdb7e83c6710ff296e"
Commit d2549384 authored by Khalique's avatar Khalique
Browse files

manual merge

parents 67048d04 ab6cd9d3
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
int main(int argc, char const* argv[]) int main(int argc, char const* argv[])
{ {
if(argc > 1) if(argc > 1)
{ {
std::string file = argv[1]; std::string file = argv[1];
auto prog = migraph::parse_onnx(file); auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl; std::cout << prog << std::endl;
} }
} }
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraph/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraph/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraph/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include <migraph/verify_args.hpp> #include <migraphx/verify_args.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
template <class T> template <class T>
auto get_hash(const T& x) auto get_hash(const T& x)
...@@ -15,14 +15,14 @@ auto get_hash(const T& x) ...@@ -15,14 +15,14 @@ auto get_hash(const T& x)
} }
template <class F> template <class F>
migraph::argument run_cpu(F f) migraphx::argument run_cpu(F f)
{ {
auto p = f(); auto p = f();
p.compile(migraph::cpu::target{}); p.compile(migraphx::cpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::generate_argument(x.second, get_hash(x.first)); m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
} }
auto out = p.eval(m); auto out = p.eval(m);
std::cout << p << std::endl; std::cout << p << std::endl;
...@@ -30,19 +30,20 @@ migraph::argument run_cpu(F f) ...@@ -30,19 +30,20 @@ migraph::argument run_cpu(F f)
} }
template <class F> template <class F>
migraph::argument run_gpu(F f) migraphx::argument run_gpu(F f)
{ {
auto p = f(); auto p = f();
p.compile(migraph::gpu::target{}); p.compile(migraphx::gpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first))); m[x.first] =
migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
} }
auto out = migraph::gpu::from_gpu(p.eval(m)); auto out = migraphx::gpu::from_gpu(p.eval(m));
std::cout << p << std::endl; std::cout << p << std::endl;
return migraph::gpu::from_gpu(out); return migraphx::gpu::from_gpu(out);
} }
template <class F> template <class F>
...@@ -50,12 +51,12 @@ void verify_program(const std::string& name, F f, double tolerance = 100) ...@@ -50,12 +51,12 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
{ {
auto x = run_cpu(f); auto x = run_cpu(f);
auto y = run_gpu(f); auto y = run_gpu(f);
migraph::verify_args(name, x, y, tolerance); migraphx::verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl; // std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl; // std::cout << "gpu: " << y << std::endl;
} }
void verify_instructions(const migraph::program& prog, double tolerance = 80) void verify_instructions(const migraphx::program& prog, double tolerance = 80)
{ {
for(auto&& ins : prog) for(auto&& ins : prog)
{ {
...@@ -68,8 +69,8 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80) ...@@ -68,8 +69,8 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
if(ins.name() == "reshape") if(ins.name() == "reshape")
continue; continue;
auto create_program = [&] { auto create_program = [&] {
migraph::program p; migraphx::program p;
std::vector<migraph::instruction_ref> inputs; std::vector<migraphx::instruction_ref> inputs;
for(auto&& arg : ins.inputs()) for(auto&& arg : ins.inputs())
{ {
if(arg->name() == "@literal") if(arg->name() == "@literal")
...@@ -100,8 +101,8 @@ void verify_reduced(F f, int n, double tolerance = 80) ...@@ -100,8 +101,8 @@ void verify_reduced(F f, int n, double tolerance = 80)
{ {
auto create_program = [&] { auto create_program = [&] {
migraph::program p = f(); migraphx::program p = f();
auto last = std::prev(p.end(), n + 1); auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end()); p.remove_instructions(last, p.end());
return p; return p;
}; };
...@@ -113,9 +114,9 @@ void verify_reduced(F f, int n, double tolerance = 80) ...@@ -113,9 +114,9 @@ void verify_reduced(F f, int n, double tolerance = 80)
template <class F> template <class F>
void verify_reduced_program(F f, double tolerance = 80) void verify_reduced_program(F f, double tolerance = 80)
{ {
migraph::program p = f(); migraphx::program p = f();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
for(int i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(f, i, tolerance); verify_reduced(f, i, tolerance);
} }
...@@ -127,7 +128,7 @@ int main(int argc, char const* argv[]) ...@@ -127,7 +128,7 @@ int main(int argc, char const* argv[])
if(not args.empty()) if(not args.empty())
{ {
std::string file = args.front(); std::string file = args.front();
auto p = migraph::parse_onnx(file); auto p = migraphx::parse_onnx(file);
std::cout << p << std::endl; std::cout << p << std::endl;
if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; })) if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
...@@ -136,11 +137,11 @@ int main(int argc, char const* argv[]) ...@@ -136,11 +137,11 @@ int main(int argc, char const* argv[])
} }
else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; })) else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
{ {
verify_reduced_program([&] { return migraph::parse_onnx(file); }); verify_reduced_program([&] { return migraphx::parse_onnx(file); });
} }
else else
{ {
verify_program(file, [&] { return migraph::parse_onnx(file); }); verify_program(file, [&] { return migraphx::parse_onnx(file); });
} }
} }
} }
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/pass_config.hpp> #include <migraphx/pass_config.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <set> #include <set>
#include <list> #include <list>
#include <vector> #include <vector>
#include <queue> #include <queue>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
//#define MIGRAPH_DEBUG_OPT //#define MIGRAPHX_DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s #define MIGRAPHX_DEBUG(s) s
#else #else
#define MIGRAPH_DEBUG(s) #define MIGRAPHX_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT #endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const void memory_coloring::apply(program& p) const
{ {
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{ {
memory_coloring_impl opt(&p, allocation_op, verify); memory_coloring_impl opt(&p, allocation_op, verify);
opt.run(); opt.run();
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run() void memory_coloring_impl::run()
{ {
MIGRAPH_DEBUG(dump("---Before memory coloring---")); MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPH_DEBUG(dump_program()); MIGRAPHX_DEBUG(dump_program());
build(); build();
if(num_of_lives != 0) if(num_of_lives != 0)
{ {
MIGRAPH_DEBUG(dump_intervals()); MIGRAPHX_DEBUG(dump_intervals());
// Coloring // Coloring
while(!alloc_queue.empty()) while(!alloc_queue.empty())
{ {
...@@ -85,7 +85,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -85,7 +85,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
conflict_queue.pop(); conflict_queue.pop();
} }
segment.offset = offset; segment.offset = offset;
MIGRAPH_DEBUG(segment.dump()); MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size); required_bytes = std::max(required_bytes, offset + segment.size);
return true; return true;
} }
...@@ -218,8 +218,8 @@ void memory_coloring_impl::rewrite() ...@@ -218,8 +218,8 @@ void memory_coloring_impl::rewrite()
} }
} }
} }
MIGRAPH_DEBUG(dump("---After rewrite---")); MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPH_DEBUG(dump_program()); MIGRAPHX_DEBUG(dump_program());
} }
void memory_coloring_impl::verify() void memory_coloring_impl::verify()
...@@ -235,7 +235,7 @@ void memory_coloring_impl::verify() ...@@ -235,7 +235,7 @@ void memory_coloring_impl::verify()
{ {
// TODO: This check breaks on the tests // TODO: This check breaks on the tests
// if(!interval.is_live_on_entry) // if(!interval.is_live_on_entry)
// MIGRAPH_THROW("interval is not live on entry"); // MIGRAPHX_THROW("interval is not live on entry");
continue; continue;
} }
...@@ -253,14 +253,14 @@ void memory_coloring_impl::verify() ...@@ -253,14 +253,14 @@ void memory_coloring_impl::verify()
if(range->offset == invalid_offset) if(range->offset == invalid_offset)
continue; continue;
if(!is_disjoin(*range, segment)) if(!is_disjoin(*range, segment))
MIGRAPH_THROW("range and segment is not disjoined"); MIGRAPHX_THROW("range and segment is not disjoined");
} }
} }
} }
} }
} }
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; } void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
...@@ -334,5 +334,5 @@ void live_interval::dump() ...@@ -334,5 +334,5 @@ void live_interval::dump()
#endif #endif
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP #define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
static const int invalid_offset = -1; static const int invalid_offset = -1;
...@@ -15,7 +15,7 @@ struct live_range ...@@ -15,7 +15,7 @@ struct live_range
long long offset; // offset to base pointer of allocated memory trunk. long long offset; // offset to base pointer of allocated memory trunk.
int vn; // value number that identifies this live_range. int vn; // value number that identifies this live_range.
long long size; // size of required memory in bytes long long size; // size of required memory in bytes
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(); void dump();
#endif #endif
}; };
...@@ -35,7 +35,7 @@ struct live_interval ...@@ -35,7 +35,7 @@ struct live_interval
int get_end() const { return segment.end; } int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; } long long get_offset() const { return segment.offset; }
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(); void dump();
#endif #endif
...@@ -84,7 +84,7 @@ struct memory_coloring_impl ...@@ -84,7 +84,7 @@ struct memory_coloring_impl
{ {
return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output"; return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output";
} }
bool is_allocate(const instruction_ref ins) { return ins->name() == allocation_op; } bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; } static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; } static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins) static bool is_check_context(const instruction_ref ins)
...@@ -101,7 +101,7 @@ struct memory_coloring_impl ...@@ -101,7 +101,7 @@ struct memory_coloring_impl
return ((end1 < range2.offset) || (end2 < range1.offset)); return ((end1 < range2.offset) || (end2 < range1.offset));
} }
void verify(); void verify();
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(const std::string&); void dump(const std::string&);
void dump_program(); void dump_program();
void dump_intervals(); void dump_intervals();
...@@ -154,6 +154,6 @@ struct memory_coloring_impl ...@@ -154,6 +154,6 @@ struct memory_coloring_impl
bool enable_verify; bool enable_verify;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/env.hpp> #include <migraphx/env.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraph/time.hpp> #include <migraphx/time.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl struct program_impl
{ {
...@@ -271,6 +271,8 @@ instruction_ref program::end() const { return impl->instructions.end(); } ...@@ -271,6 +271,8 @@ instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().get_shape(); } shape program::get_shape() const { return impl->instructions.back().get_shape(); }
context& program::get_context() const { return impl->ctx; }
instruction_ref program::validate() const instruction_ref program::validate() const
{ {
return std::find_if(impl->instructions.begin(), return std::find_if(impl->instructions.begin(),
...@@ -282,7 +284,7 @@ void program::compile(const target& t, tracer trace) ...@@ -282,7 +284,7 @@ void program::compile(const target& t, tracer trace)
{ {
assert(this->validate() == impl->instructions.end()); assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
if(enabled(MIGRAPH_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
trace = tracer{std::cout}; trace = tracer{std::cout};
trace(*this); trace(*this);
trace(); trace();
...@@ -297,8 +299,8 @@ void program::compile(const target& t, tracer trace) ...@@ -297,8 +299,8 @@ void program::compile(const target& t, tracer trace)
if(invalid != impl->instructions.end()) if(invalid != impl->instructions.end())
{ {
auto index = std::distance(impl->instructions.begin(), invalid); auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " + MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name()); std::to_string(index) + ": " + invalid->name());
} }
trace(); trace();
#endif #endif
...@@ -307,7 +309,16 @@ void program::compile(const target& t, tracer trace) ...@@ -307,7 +309,16 @@ void program::compile(const target& t, tracer trace)
if(invalid != impl->instructions.end()) if(invalid != impl->instructions.end())
{ {
auto index = std::distance(impl->instructions.begin(), invalid); auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW("Invalid program from compilation at instruction " + std::to_string(index)); MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
}
this->finalize();
}
void program::finalize()
{
for(auto ins : iterator_for(*this))
{
ins->finalize(this->impl->ctx);
} }
} }
...@@ -334,7 +345,7 @@ argument generic_eval(const program& p, ...@@ -334,7 +345,7 @@ argument generic_eval(const program& p,
auto param_name = auto param_name =
any_cast<builtin::param>(ins->get_operator()).parameter; any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name)) if(not contains(params, param_name))
MIGRAPH_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
return params.at(param_name); return params.at(param_name);
})); }));
} }
...@@ -361,7 +372,7 @@ argument generic_eval(const program& p, ...@@ -361,7 +372,7 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const argument program::eval(std::unordered_map<std::string, argument> params) const
{ {
if(enabled(MIGRAPH_TRACE_EVAL{})) if(enabled(MIGRAPHX_TRACE_EVAL{}))
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) { return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
...@@ -501,5 +512,5 @@ std::ostream& operator<<(std::ostream& os, const program& p) ...@@ -501,5 +512,5 @@ std::ostream& operator<<(std::ostream& os, const program& p)
return os; return os;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl struct shape_impl
{ {
...@@ -169,12 +169,12 @@ std::string shape::type_string() const ...@@ -169,12 +169,12 @@ std::string shape::type_string() const
{ {
switch(this->type()) switch(this->type())
{ {
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE(x, t) \
case x: return #x; case x: return #x;
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_TYPE_STRING_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE)
#undef MIGRAPH_SHAPE_TYPE_STRING_CASE #undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE
} }
MIGRAPH_THROW("Invalid type"); MIGRAPHX_THROW("Invalid type");
} }
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
...@@ -191,5 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x) ...@@ -191,5 +191,5 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os; return os;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct find_add_lit_broadcast struct find_add_lit_broadcast
{ {
...@@ -61,5 +61,5 @@ struct find_add_lit_broadcast ...@@ -61,5 +61,5 @@ struct find_add_lit_broadcast
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); } void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/ranges.hpp> #include <migraphx/ranges.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(const std::string& name) // Reshapers that can't handle nonstandard input shapes
bool is_nonstandard_reshaper(instruction_ref ins)
{
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape"
};
// clang-format on
return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous";
}
bool is_reshaper(instruction_ref ins)
{ {
// clang-format off // clang-format off
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
...@@ -19,26 +30,27 @@ bool is_reshaper(const std::string& name) ...@@ -19,26 +30,27 @@ bool is_reshaper(const std::string& name)
"contiguous" "contiguous"
}; };
// clang-format on // clang-format on
return contains(names, name); return contains(names, ins->name()) and not is_nonstandard_reshaper(ins);
} }
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(not is_reshaper(ins->name())) if(not is_reshaper(ins))
continue; continue;
if(ins->outputs().size() != 1) if(ins->outputs().size() != 1)
continue; continue;
if(is_reshaper(ins->outputs().front()->name())) if(is_reshaper(ins->outputs().front()))
continue; continue;
// Gather reshapes // Gather reshapes
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->name())) while(is_reshaper(reshapes.back()))
{ {
assert(!reshapes.back()->inputs().empty()); assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front())); assert(p.has_instruction(reshapes.back()->inputs().front()));
reshapes.push_back(reshapes.back()->inputs().front()); auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
} }
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()}; std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
...@@ -58,7 +70,14 @@ void simplify_reshapes::apply(program& p) const ...@@ -58,7 +70,14 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(r.first, r.second); p.replace_instruction(r.first, r.second);
} }
} }
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
add_library(migraph_cpu add_library(migraphx_cpu
target.cpp target.cpp
lowering.cpp lowering.cpp
gemm.cpp gemm.cpp
) )
set_target_properties(migraph_cpu PROPERTIES EXPORT_NAME cpu) set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu)
find_path(BLAZE_INCLUDE blaze/Blaze.h) find_path(BLAZE_INCLUDE blaze/Blaze.h)
find_package(Threads) find_package(Threads)
rocm_clang_tidy_check(migraph_cpu) rocm_clang_tidy_check(migraphx_cpu)
target_link_libraries(migraph_cpu migraph Threads::Threads) target_link_libraries(migraphx_cpu migraphx Threads::Threads)
target_include_directories(migraph_cpu PRIVATE ${BLAZE_INCLUDE}) target_include_directories(migraphx_cpu PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraph_cpu PRIVATE -DBLAZE_USE_CPP_THREADS) target_compile_definitions(migraphx_cpu PRIVATE -DBLAZE_USE_CPP_THREADS)
rocm_install_targets( rocm_install_targets(
TARGETS migraph_cpu TARGETS migraphx_cpu
INCLUDE INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/include
) )
......
#include <migraph/cpu/gemm.hpp> #include <migraphx/cpu/gemm.hpp>
#include <migraph/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraph/requires.hpp> #include <migraphx/requires.hpp>
#include <blaze/math/CustomMatrix.h> #include <blaze/math/CustomMatrix.h>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
template <class T> template <class T>
...@@ -94,5 +94,5 @@ void migemm( ...@@ -94,5 +94,5 @@ void migemm(
} }
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#ifndef MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace cpu {
struct context
{
void finish() const {}
};
} // namespace cpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#include <migraph/program.hpp>
#include <migraph/cpu/context.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace cpu {
struct target
{
std::string name() const;
std::vector<pass> get_passes(migraph::context& ctx) const;
migraph::context get_context() const { return context{}; }
};
} // namespace cpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
struct context
{
void finish() const {}
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#define MIGRAPH_GUARD_RTGLIB_CPU_GEMM_HPP #define MIGRAPHX_GUARD_RTGLIB_CPU_GEMM_HPP
#include <migraph/argument.hpp> #include <migraphx/argument.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
void migemm( void migemm(
const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta); const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta);
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_CPU_LOWERING_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#define MIGRAPH_GUARD_RTGLIB_CPU_LOWERING_HPP #define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
struct lowering struct lowering
...@@ -15,7 +15,7 @@ struct lowering ...@@ -15,7 +15,7 @@ struct lowering
}; };
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_CPU_TARGET_HPP
#include <migraphx/program.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
struct target
{
std::string name() const;
std::vector<pass> get_passes(migraphx::context& ctx) const;
migraphx::context get_context() const { return context{}; }
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraph/cpu/lowering.hpp> #include <migraphx/cpu/lowering.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/cpu/gemm.hpp> #include <migraphx/par_dfor.hpp>
#include <migraphx/cpu/gemm.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
template <typename T> template <typename T>
...@@ -19,6 +20,14 @@ T zero(const T&) ...@@ -19,6 +20,14 @@ T zero(const T&)
return T(0); return T(0);
} }
template <class T>
typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::enable_if<true, T>>::
type
make_signed(T x)
{
return x;
}
// //
// cpu implemenataion of batch norm for inference // cpu implemenataion of batch norm for inference
// //
...@@ -64,7 +73,7 @@ struct cpu_batch_norm_inference ...@@ -64,7 +73,7 @@ struct cpu_batch_norm_inference
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)( visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width)( par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) { [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c) + epsilon) > 0); assert((variance(c) + epsilon) > 0);
result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) / result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
...@@ -79,7 +88,7 @@ struct cpu_batch_norm_inference ...@@ -79,7 +88,7 @@ struct cpu_batch_norm_inference
visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)( visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width)( par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) { [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c, h, w) + epsilon) > 0); assert((variance(c, h, w) + epsilon) > 0);
result(n, c, h, w) = gamma(c, h, w) * result(n, c, h, w) = gamma(c, h, w) *
...@@ -141,28 +150,33 @@ struct cpu_convolution ...@@ -141,28 +150,33 @@ struct cpu_convolution
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_h = input.get_shape().lens()[2]; auto in = input.get_shape().lens();
auto in_w = input.get_shape().lens()[3]; auto in_h = in[2];
auto in_w = in[3];
auto wei_c = weights.get_shape().lens()[1];
auto wei_h = weights.get_shape().lens()[2]; auto wei = weights.get_shape().lens();
auto wei_w = weights.get_shape().lens()[3]; auto wei_n = wei[0];
auto wei_c = wei[1];
dfor(output_shape.lens()[0], auto wei_h = wei[2];
output_shape.lens()[1], auto wei_w = wei[3];
output_shape.lens()[2],
output_shape.lens()[3])( par_dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0]; const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1]; const int start_y = j * op.stride[1] - op.padding[1];
const int group_id = w / (wei_n / op.group);
double acc = 0; double acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) { dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
const int in_x = start_x + x; const int in_x = start_x + x;
const int in_y = start_y + y; const int in_y = start_y + y;
const int in_ch = group_id * wei_c + k;
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w) if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
{ {
acc += input(o, k, in_x, in_y) * weights(w, k, x, y); acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y);
} }
}); });
output(o, w, i, j) = acc; output(o, w, i, j) = acc;
...@@ -195,7 +209,8 @@ struct cpu_im2col ...@@ -195,7 +209,8 @@ struct cpu_im2col
const std::size_t& stride_h = op.stride[0]; const std::size_t& stride_h = op.stride[0];
const std::size_t& stride_w = op.stride[1]; const std::size_t& stride_w = op.stride[1];
int kdiv2_h, kdiv2_w; int kdiv2_h;
int kdiv2_w;
kdiv2_h = kernel_h / 2; kdiv2_h = kernel_h / 2;
kdiv2_w = kernel_w / 2; kdiv2_w = kernel_w / 2;
// calculate output sizes // calculate output sizes
...@@ -268,10 +283,10 @@ struct cpu_pooling ...@@ -268,10 +283,10 @@ struct cpu_pooling
auto in_h = input.get_shape().lens()[2]; auto in_h = input.get_shape().lens()[2];
auto in_w = input.get_shape().lens()[3]; auto in_w = input.get_shape().lens()[3];
dfor(output_shape.lens()[0], par_dfor(output_shape.lens()[0],
output_shape.lens()[1], output_shape.lens()[1],
output_shape.lens()[2], output_shape.lens()[2],
output_shape.lens()[3])( output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x0 = i * op.stride[0] - op.padding[0]; const int start_x0 = i * op.stride[0] - op.padding[0];
const int start_y0 = j * op.stride[1] - op.padding[1]; const int start_y0 = j * op.stride[1] - op.padding[1];
...@@ -320,34 +335,43 @@ struct cpu_contiguous ...@@ -320,34 +335,43 @@ struct cpu_contiguous
} }
}; };
struct cpu_concat struct cpu_pad
{ {
op::concat op; op::pad op;
std::string name() const { return "cpu::concat"; } std::string name() const { return "cpu::contiguous"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
assert(output_shape.standard());
argument result{output_shape}; argument result{output_shape};
std::vector<std::size_t> coffsets = op.compute_offsets(output_shape, args); result.visit([&](auto output) { std::fill(output.begin(), output.end(), op.value); });
for(std::size_t l = 0; l < args.size(); l++)
{ visit_all(result, args[0])([&](auto output, auto input) {
auto argl = args[l]; shape_for_each(input.get_shape(), [&](const auto& idx) {
std::size_t nelements = argl.get_shape().elements(); std::vector<std::size_t> new_idx(idx.size());
visit_all(result, argl)([&](auto output, auto input) { std::transform(
auto slice_shape = idx.begin(), idx.end(), op.pads.begin(), new_idx.begin(), [](auto i, auto j) {
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()}; return i + j;
auto slice = make_view(slice_shape, output.data() + coffsets[l]); });
// cppcheck-suppress useStlAlgorithm output(new_idx.begin(), new_idx.end()) = input(idx.begin(), idx.end());
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
}); });
} });
return result; return result;
} }
}; };
struct cpu_concat
{
op::concat op;
std::string name() const { return "cpu::concat"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, std::move(args));
}
};
struct cpu_gemm struct cpu_gemm
{ {
op::dot op; op::dot op;
...@@ -362,6 +386,18 @@ struct cpu_gemm ...@@ -362,6 +386,18 @@ struct cpu_gemm
} }
}; };
struct cpu_gather
{
op::gather op;
std::string name() const { return "cpu::gather"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
return op.compute(output_shape, std::move(args));
}
};
struct identity_op struct identity_op
{ {
std::string name() const { return "cpu::identity"; } std::string name() const { return "cpu::identity"; }
...@@ -376,7 +412,7 @@ struct abs_op ...@@ -376,7 +412,7 @@ struct abs_op
std::string name() const { return "cpu::abs"; } std::string name() const { return "cpu::abs"; }
auto fcn() const auto fcn() const
{ {
return [](auto x) { return std::abs(x); }; return [](auto x) { return std::abs(make_signed(x)); };
} }
}; };
...@@ -389,6 +425,15 @@ struct exp_op ...@@ -389,6 +425,15 @@ struct exp_op
} }
}; };
struct log_op
{
std::string name() const { return "cpu::log"; }
auto fcn() const
{
return [](auto x) { return std::log(x); };
}
};
struct sin_op struct sin_op
{ {
std::string name() const { return "cpu::sin"; } std::string name() const { return "cpu::sin"; }
...@@ -443,6 +488,24 @@ struct atan_op ...@@ -443,6 +488,24 @@ struct atan_op
} }
}; };
struct sinh_op
{
std::string name() const { return "cpu::sinh"; }
auto fcn() const
{
return [](auto x) { return std::sinh(x); };
}
};
struct cosh_op
{
std::string name() const { return "cpu::cosh"; }
auto fcn() const
{
return [](auto x) { return std::cosh(x); };
}
};
struct tanh_op struct tanh_op
{ {
std::string name() const { return "cpu::tanh"; } std::string name() const { return "cpu::tanh"; }
...@@ -490,6 +553,17 @@ struct leaky_relu_op ...@@ -490,6 +553,17 @@ struct leaky_relu_op
} }
}; };
struct elu_op
{
op::elu op;
std::string name() const { return "cpu::elu"; }
auto fcn() const
{
auto& a = op.alpha;
return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
}
};
template <typename Op> template <typename Op>
struct cpu_unary struct cpu_unary
{ {
...@@ -582,6 +656,24 @@ struct div_op ...@@ -582,6 +656,24 @@ struct div_op
} }
}; };
struct max_op
{
std::string name() const { return "max"; }
auto fcn() const
{
return [](auto x, auto y) { return std::max(x, y); };
}
};
struct min_op
{
std::string name() const { return "min"; }
auto fcn() const
{
return [](auto x, auto y) { return std::min(x, y); };
}
};
template <typename Op> template <typename Op>
struct cpu_binary struct cpu_binary
{ {
...@@ -635,21 +727,33 @@ struct cpu_apply ...@@ -635,21 +727,33 @@ struct cpu_apply
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>(); extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>(); apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>(); apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["gather"] = extend_op<cpu_gather, op::gather>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>(); apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["abs"] = simple_op<cpu_unary<abs_op>>();
apply_map["sinh"] = simple_op<cpu_unary<sinh_op>>();
apply_map["cosh"] = simple_op<cpu_unary<cosh_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>(); apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>(); apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>();
apply_map["exp"] = simple_op<cpu_unary<exp_op>>(); apply_map["exp"] = simple_op<cpu_unary<exp_op>>();
apply_map["log"] = simple_op<cpu_unary<log_op>>();
apply_map["neg"] = simple_op<cpu_unary<neg_op>>(); apply_map["neg"] = simple_op<cpu_unary<neg_op>>();
apply_map["sin"] = simple_op<cpu_unary<sin_op>>(); apply_map["sin"] = simple_op<cpu_unary<sin_op>>();
apply_map["cos"] = simple_op<cpu_unary<cos_op>>(); apply_map["cos"] = simple_op<cpu_unary<cos_op>>();
apply_map["tan"] = simple_op<cpu_unary<tan_op>>(); apply_map["tan"] = simple_op<cpu_unary<tan_op>>();
apply_map["asin"] = simple_op<cpu_unary<asin_op>>();
apply_map["acos"] = simple_op<cpu_unary<acos_op>>();
apply_map["atan"] = simple_op<cpu_unary<atan_op>>();
apply_map["relu"] = simple_op<cpu_unary<relu_op>>(); apply_map["relu"] = simple_op<cpu_unary<relu_op>>();
apply_map["add"] = simple_op<cpu_binary<add_op>>(); apply_map["add"] = simple_op<cpu_binary<add_op>>();
apply_map["sub"] = simple_op<cpu_binary<sub_op>>(); apply_map["sub"] = simple_op<cpu_binary<sub_op>>();
apply_map["mul"] = simple_op<cpu_binary<mul_op>>(); apply_map["mul"] = simple_op<cpu_binary<mul_op>>();
apply_map["div"] = simple_op<cpu_binary<div_op>>(); apply_map["div"] = simple_op<cpu_binary<div_op>>();
apply_map["max"] = simple_op<cpu_binary<max_op>>();
apply_map["min"] = simple_op<cpu_binary<min_op>>();
apply_map["softmax"] = simple_op<softmax2d>(); apply_map["softmax"] = simple_op<softmax2d>();
} }
...@@ -696,5 +800,5 @@ struct cpu_apply ...@@ -696,5 +800,5 @@ struct cpu_apply
void lowering::apply(program& p) const { cpu_apply{&p}.apply(); } void lowering::apply(program& p) const { cpu_apply{&p}.apply(); }
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#include <migraph/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraph/cpu/lowering.hpp> #include <migraphx/cpu/lowering.hpp>
#include <migraph/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
std::string target::name() const { return "cpu"; } std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraph::context&) const std::vector<pass> target::get_passes(migraphx::context&) const
{ {
return {auto_contiguous{}, lowering{}}; return {auto_contiguous{}, lowering{}};
} }
} // namespace cpu } // namespace cpu
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment