Commit 2a1ae4df authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into padding_fix

parents 5d556dbd 66bae091
@@ -22,7 +22,7 @@ find_package(ROCM REQUIRED)
 include(ROCMSetupVersion)
-rocm_setup_version(VERSION 0.2)
+rocm_setup_version(VERSION 0.3)
 option( BUILD_SHARED_LIBS "Build as a shared library" ON )
@@ -39,8 +39,6 @@ else()
 set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "")
 endif()
-set(MIGRAPHX_ENABLE_TF Off CACHE BOOL "")
 add_compile_options(-std=c++14)
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
...
@@ -40,6 +40,7 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
 set(PACKAGE_DEPENDS)
+add_subdirectory(driver)
 add_subdirectory(onnx)
 add_subdirectory(tf)
...
add_executable(driver main.cpp verify.cpp perf.cpp)
rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_cpu migraphx_onnx migraphx_tf)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(driver migraphx_gpu)
target_compile_definitions(driver PRIVATE -DHAVE_GPU)
endif()
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#include <algorithm>
#include <functional>
#include <iostream>
#include <set>
#include <string>
#include <sstream>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef MIGRAPHX_USE_CLANG_TIDY
#define MIGRAPHX_DRIVER_STATIC
#else
#define MIGRAPHX_DRIVER_STATIC static
#endif
template <class T>
struct value_parser
{
template <MIGRAPHX_REQUIRES(not std::is_enum<T>{})>
static T apply(const std::string& x)
{
T result;
std::stringstream ss;
ss.str(x);
ss >> result;
if(ss.fail())
throw std::runtime_error("Failed to parse: " + x);
return result;
}
template <MIGRAPHX_REQUIRES(std::is_enum<T>{})>
static T apply(const std::string& x)
{
std::ptrdiff_t i;
std::stringstream ss;
ss.str(x);
ss >> i;
if(ss.fail())
throw std::runtime_error("Failed to parse: " + x);
return static_cast<T>(i);
}
};
struct argument_parser
{
struct argument
{
std::vector<std::string> flags;
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{};
std::string type = "";
std::string help = "";
std::string metavar = "";
std::string default_value = "";
unsigned nargs = 1;
};
template <class T, class... Fs>
void operator()(T& x, const std::vector<std::string>& flags, Fs... fs)
{
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty())
throw std::runtime_error("Flag with no value.");
x = value_parser<T>::apply(params.back());
return false;
}});
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
arg.default_value = to_string(x);
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
}
template <class... Fs>
void operator()(std::nullptr_t x, std::vector<std::string> flags, Fs... fs)
{
arguments.push_back({std::move(flags)});
argument& arg = arguments.back();
arg.type = "";
arg.nargs = 0;
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
}
MIGRAPHX_DRIVER_STATIC auto nargs(unsigned n = 1)
{
return [=](auto&&, auto& arg) { arg.nargs = n; };
}
template <class F>
MIGRAPHX_DRIVER_STATIC auto write_action(F f)
{
return [=](auto& x, auto& arg) {
arg.action = [&, f](auto& self, const std::vector<std::string>& params) {
f(self, x, params);
return false;
};
};
}
template <class F>
MIGRAPHX_DRIVER_STATIC auto do_action(F f)
{
return [=](auto&, auto& arg) {
arg.nargs = 0;
arg.action = [&, f](auto& self, const std::vector<std::string>&) {
f(self);
return true;
};
};
}
MIGRAPHX_DRIVER_STATIC auto append()
{
return write_action([](auto&, auto& x, auto& params) {
using type = typename decltype(params)::value_type;
std::transform(params.begin(),
params.end(),
std::inserter(x, x.end()),
[](std::string y) { return value_parser<type>::apply(y); });
});
}
MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "")
{
return do_action([=](auto& self) {
for(auto&& arg : self.arguments)
{
std::cout << std::endl;
std::string prefix = " ";
if(arg.flags.empty())
{
std::cout << prefix;
std::cout << arg.metavar;
}
for(const std::string& a : arg.flags)
{
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
if(not arg.type.empty())
{
std::cout << " [" << arg.type << "]";
if(not arg.default_value.empty())
std::cout << " (Default: " << arg.default_value << ")";
}
std::cout << std::endl;
std::cout << " " << arg.help << std::endl;
}
std::cout << std::endl;
if(not msg.empty())
std::cout << msg << std::endl;
});
}
MIGRAPHX_DRIVER_STATIC auto help(const std::string& help)
{
return [=](auto&, auto& arg) { arg.help = help; };
}
MIGRAPHX_DRIVER_STATIC auto metavar(const std::string& metavar)
{
return [=](auto&, auto& arg) { arg.metavar = metavar; };
}
template <class T>
MIGRAPHX_DRIVER_STATIC auto set_value(T value)
{
return [=](auto& x, auto& arg) {
arg.nargs = 0;
arg.type = "";
arg.action = [&, value](auto&, const std::vector<std::string>&) {
x = value;
return false;
};
};
}
bool parse(std::vector<std::string> args)
{
std::unordered_map<std::string, unsigned> keywords;
for(auto&& arg : arguments)
{
for(auto&& flag : arg.flags)
keywords[flag] = arg.nargs + 1;
}
auto arg_map =
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
for(auto&& arg : arguments)
{
auto flags = arg.flags;
if(flags.empty())
flags = {""};
for(auto&& flag : flags)
{
if(arg_map.count(flag) > 0)
{
if(arg.action(*this, arg_map[flag]))
return true;
}
}
}
return false;
}
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class IsKeyword>
static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
{
string_map result;
std::string flag;
bool clear = false;
for(auto&& x : as)
{
auto k = is_keyword(x);
if(k > 0)
{
flag = x;
result[flag]; // Ensure the flag exists
if(k == 1)
flag = "";
else if(k == 2)
clear = true;
else
clear = false;
}
else
{
result[flag].push_back(x);
if(clear)
flag = "";
clear = false;
}
}
return result;
}
private:
std::vector<argument> arguments;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
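For reference, here is a minimal sketch of how the argument_parser above is driven. It only exercises the operator(), help, metavar, set_value, show_help and parse entry points defined in this header; the example_main function and its --threads/--verbose options are hypothetical names invented for illustration.
// Hypothetical usage sketch of the argument_parser API (not part of this commit).
#include "argument_parser.hpp"
#include <iostream>
#include <string>
#include <vector>
int example_main(std::vector<std::string> args)
{
    using namespace migraphx::driver;
    argument_parser ap;
    std::string file;         // positional value: empty flag list plus a metavar
    unsigned threads = 1;     // hypothetical option, parsed through value_parser<unsigned>
    bool verbose     = false; // hypothetical flag with no value, toggled through set_value
    ap(file, {}, ap.metavar("<input file>"));
    ap(threads, {"--threads", "-j"}, ap.help("Number of threads"));
    ap(verbose, {"--verbose", "-v"}, ap.help("Enable verbose output"), ap.set_value(true));
    ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
    // parse() returns true when an action (such as show_help) requests an early exit.
    if(ap.parse(std::move(args)))
        return 0;
    std::cout << file << " " << threads << " " << verbose << std::endl;
    return 0;
}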
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
#include "argument_parser.hpp"
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_map>
#include <utility>
#include <vector>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands()
{
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m;
return m;
}
template <class T>
std::string compute_command_name()
{
static const std::string& tname = get_type_name<T>();
auto name = tname.substr(tname.rfind("::") + 2);
if(ends_with(name, "_command"))
name = name.substr(0, name.size() - 8);
if(ends_with(name, "_cmd"))
name = name.substr(0, name.size() - 4);
return name;
}
template <class T>
const std::string& command_name()
{
static const std::string& name = compute_command_name<T>();
return name;
}
template <class T>
void run_command(std::vector<std::string> args, bool add_help = false)
{
T x;
argument_parser ap;
if(add_help)
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap);
if(ap.parse(std::move(args)))
return;
x.run();
}
template <class T>
int auto_register_command()
{
auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); };
return 0;
}
template <class T>
struct command
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
template <class T>
int command<T>::static_register = auto_register_command<T>(); // NOLINT
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
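For reference, a command plugs into this registry by deriving from command<T> (the static_register member then runs auto_register_command when the class is instantiated) and by providing parse(argument_parser&) and run(); a trailing _cmd or _command is stripped from the registered name. A minimal hypothetical sketch, assuming the two headers above; hello_cmd and its --name option are invented for illustration.
// Hypothetical command: would be registered under the name "hello".
#include "argument_parser.hpp"
#include "command.hpp"
#include <iostream>
#include <string>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
struct hello_cmd : command<hello_cmd>
{
    std::string name = "world";
    void parse(argument_parser& ap)
    {
        ap(name, {"--name"}, ap.help("Who to greet"));
    }
    void run() { std::cout << "Hello " << name << std::endl; }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
With the dispatch in main() below, an invocation such as driver hello --name MIGraphX would look up "hello" in get_commands() and call run_command<hello_cmd> with -h/--help support wired in.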
#include "argument_parser.hpp"
#include "command.hpp"
#include "verify.hpp"
#include "perf.hpp"
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
struct loader
{
std::string file;
std::string file_type;
bool is_nhwc = true;
unsigned trim = 0;
void parse(argument_parser& ap)
{
ap(file, {}, ap.metavar("<input file>"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
}
program load()
{
program p;
if(file_type.empty())
{
if(ends_with(file, ".onnx"))
file_type = "onnx";
else if(ends_with(file, ".pb"))
file_type = "tf";
}
std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx")
p = parse_onnx(file);
else if(file_type == "tf")
p = parse_tf(file, is_nhwc);
if(trim > 0)
{
auto last = std::prev(p.end(), trim);
p.remove_instructions(last, p.end());
}
return p;
}
};
struct compiler
{
loader l;
bool gpu = true;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true));
ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false));
}
program compile()
{
auto p = l.load();
compile_program(p, gpu);
return p;
}
auto params(const program& p) { return create_param_map(p, gpu); }
};
struct read : command<read>
{
loader l;
void parse(argument_parser& ap) { l.parse(ap); }
void run()
{
auto p = l.load();
std::cout << p << std::endl;
}
};
struct verify : command<verify>
{
loader l;
double tolerance = 80;
bool per_instruction = false;
bool reduce = false;
void parse(argument_parser& ap)
{
l.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(per_instruction,
{"-i", "--per-instruction"},
ap.help("Verify each instruction"),
ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
}
void run()
{
auto p = l.load();
std::cout << p << std::endl;
if(per_instruction)
{
verify_instructions(p, tolerance);
}
else if(reduce)
{
verify_reduced_program(p, tolerance);
}
else
{
verify_program(l.file, p, tolerance);
}
}
};
struct compile : command<compile>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << p << std::endl;
}
};
struct run_cmd : command<run_cmd>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
p.eval(m);
std::cout << p << std::endl;
}
};
struct perf : command<perf>
{
compiler c;
unsigned n = 100;
void parse(argument_parser& ap)
{
c.parse(ap);
ap(n, {"--iterations", "-n"}, ap.help("Number of iterations to run for perf report"));
}
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m);
}
};
struct main_command
{
static std::string get_command_help()
{
std::string result = "Commands:\n";
return std::accumulate(get_commands().begin(),
get_commands().end(),
result,
[](auto r, auto&& p) { return r + " " + p.first + "\n"; });
}
void parse(argument_parser& ap)
{
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
}
void run() {}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
using namespace migraphx::driver; // NOLINT
int main(int argc, const char* argv[])
{
std::vector<std::string> args(argv + 1, argv + argc);
if(args.empty())
return 0;
auto&& m = get_commands();
auto cmd = args.front();
if(m.count(cmd) > 0)
{
m.at(cmd)({args.begin() + 1, args.end()});
}
else
{
run_command<main_command>(args);
}
return 0;
}
#include "perf.hpp"
#include <migraphx/cpu/target.hpp>
#include <migraphx/generate.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map create_param_map(const program& p, bool gpu)
{
program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
#ifdef HAVE_GPU
if(gpu)
m[x.first] = gpu::to_gpu(generate_argument(x.second));
else
#else
(void)gpu;
#endif
m[x.first] = generate_argument(x.second);
}
return m;
}
void compile_program(program& p, bool gpu)
{
if(gpu)
{
#ifdef HAVE_GPU
p.compile(gpu::target{});
#else
MIGRAPHX_THROW("Gpu not supported.");
#endif
}
else
{
p.compile(cpu::target{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_PERF_HPP
#define MIGRAPHX_GUARD_RTGLIB_PERF_HPP
#include <migraphx/program.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map create_param_map(const program& p, bool gpu = true);
void compile_program(program& p, bool gpu = true);
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
#include "verify.hpp"
#include <migraphx/cpu/target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
auto get_hash(const T& x)
{
return std::hash<T>{}(x);
}
argument run_cpu(program p)
{
p.compile(cpu::target{});
program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = generate_argument(x.second, get_hash(x.first));
}
auto out = p.eval(m);
std::cout << p << std::endl;
return out;
}
argument run_gpu(program p)
{
#ifdef HAVE_GPU
p.compile(gpu::target{});
program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = gpu::to_gpu(generate_argument(x.second, get_hash(x.first)));
}
auto out = gpu::from_gpu(p.eval(m));
std::cout << p << std::endl;
return gpu::from_gpu(out);
#else
(void)p;
MIGRAPHX_THROW("Gpu unsupported!");
#endif
}
void verify_program(const std::string& name, const program& p, double tolerance)
{
auto x = run_cpu(p);
auto y = run_gpu(p);
verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void verify_instructions(const program& prog, double tolerance)
{
for(auto&& ins : prog)
{
if(ins.name().front() == '@')
continue;
if(ins.name() == "broadcast")
continue;
if(ins.name() == "transpose")
continue;
if(ins.name() == "reshape")
continue;
program p;
std::vector<instruction_ref> inputs;
for(auto&& arg : ins.inputs())
{
if(arg->name() == "@literal")
inputs.push_back(p.add_literal(arg->get_literal()));
else
inputs.push_back(p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
}
p.add_instruction(ins.get_operator(), inputs);
try
{
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl;
verify_program(ins.name(), p, tolerance);
}
catch(...)
{
std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
throw;
}
}
}
void verify_reduced(program p, int n, double tolerance)
{
auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end());
std::cout << "Verify: " << std::endl;
std::cout << p << std::endl;
verify_program(std::to_string(n), p, tolerance);
}
void verify_reduced_program(const program& p, double tolerance)
{
auto n = std::distance(p.begin(), p.end());
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(p, i, tolerance);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#include <migraphx/program.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
argument run_cpu(program p);
argument run_gpu(program p);
void verify_program(const std::string& name, const program& p, double tolerance = 100);
void verify_instructions(const program& prog, double tolerance = 80);
void verify_reduced_program(const program& p, double tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
@@ -31,7 +31,7 @@ struct pooling
 {
     return pack(f(self.mode, "mode"),
                 f(self.padding, "padding"),
-                f(self.padding, "padding_mode"),
+                f(self.padding_mode, "padding_mode"),
                 f(self.stride, "stride"),
                 f(self.lengths, "lengths"));
 }
...
@@ -15,35 +15,18 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
 template <bool B>
 using bool_c = std::integral_constant<bool, B>;
-template <int N>
-struct requires_enum
-{
-    enum e
-    {
-        a = 0
-    };
-};
-#define MIGRAPHX_REQUIRES_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
+#define MIGRAPHX_REQUIRES_VAR() MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__)
 #ifdef CPPCHECK
 #define MIGRAPHX_REQUIRES(...) class = void
 #else
-#if 0
-// TODO: This currently crashed on clang
-#define MIGRAPHX_REQUIRES(...) \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
-        PrivateRequires, \
-        __LINE__) = migraphx::requires_enum<__LINE__>::a, \
-    class = typename std::enable_if<and_<__VA_ARGS__, \
-                                         MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__) == \
-                                             migraphx::requires_enum<__LINE__>::a>{}>::type
-#else
-#define MIGRAPHX_REQUIRES(...) \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
-        PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
-    class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
-#endif
+#define MIGRAPHX_REQUIRES(...) \
+    bool MIGRAPHX_REQUIRES_VAR() = true, \
+    typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \
+                            int>::type = 0
 #endif
 } // namespace MIGRAPHX_INLINE_NS
...
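The rewritten macro replaces the old requires_enum machinery with two trailing non-type template parameters (a defaulted bool plus an enable_if'd int), so existing call sites such as the value_parser overloads above keep the same shape. A minimal standalone sketch of its use, assuming this header is the <migraphx/requires.hpp> included by argument_parser.hpp above; print_value is a hypothetical function invented for illustration.
// Hypothetical illustration of constraining overloads with MIGRAPHX_REQUIRES.
#include <iostream>
#include <type_traits>
#include <migraphx/requires.hpp>
template <class T, MIGRAPHX_REQUIRES(std::is_integral<T>{})>
void print_value(T x)
{
    std::cout << "integral: " << x << std::endl;
}
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
void print_value(T x)
{
    std::cout << "floating point: " << x << std::endl;
}
int main()
{
    print_value(1);   // selects the integral overload
    print_value(2.5); // selects the floating-point overload
}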
@@ -42,7 +42,9 @@ template <class Range>
 auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
     -> decltype(r.begin(), r.end(), void())
 {
+    os << "{";
     os << stream_range(r);
+    os << "}";
 }
 template <class T>
...
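The added braces wrap every streamed range; assuming stream_range joins the elements with a separator, a vector would now render as {1, 2, 3} rather than a bare element list. A standalone sketch of the same idea (print_range and its comma separator are assumptions, not the library's actual entry point):
// Standalone sketch of brace-wrapped range printing, mirroring the change above.
#include <iostream>
#include <vector>
template <class Range>
void print_range(std::ostream& os, const Range& r)
{
    os << "{";
    const char* sep = "";
    for(auto&& x : r)
    {
        os << sep << x;
        sep = ", ";
    }
    os << "}";
}
int main()
{
    std::vector<int> v = {1, 2, 3};
    print_range(std::cout, v); // prints {1, 2, 3}
    std::cout << std::endl;
}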
@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON)
 C_VISIBILITY_PRESET hidden
 CXX_VISIBILITY_PRESET hidden
 )
-if(MIGRAPHX_ENABLE_TF)
-target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
-target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
-else()
-target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
-endif()
+target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
 if(MIGRAPHX_ENABLE_GPU)
 target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
 target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
...
@@ -6,11 +6,8 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/stringutils.hpp>
-#ifdef ENABLE_TF
 #include <migraphx/tf.hpp>
-#else
 #include <migraphx/onnx.hpp>
-#endif
 #ifdef HAVE_GPU
 #include <migraphx/gpu/target.hpp>
@@ -161,16 +158,13 @@ PYBIND11_MODULE(migraphx, m)
 .def("__ne__", std::not_equal_to<migraphx::program>{})
 .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
-#ifdef ENABLE_TF
 m.def("parse_tf",
       &migraphx::parse_tf,
       "Parse tf protobuf (default format is nhwc)",
       py::arg("filename"),
       py::arg("is_nhwc") = true);
-#else
 m.def("parse_onnx", &migraphx::parse_onnx);
-#endif
 m.def("get_target", [](const std::string& name) -> migraphx::target {
 if(name == "cpu")
 return migraphx::cpu::target{};
...