"src/vscode:/vscode.git/clone" did not exist on "d2a38cd41e454e93b05c8ffd7a787a6dcd3b02f9"
Unverified Commit 35d1bcc2 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add code generation for pointwise operators (#780)

* Add definitions for all pointwise operators

* Formatting

* Add cpp generator class

* Formatting

* Move compilation to core

* Formatting

* Add clock to tmp name

* Add dynamic loader

* Formatting

* Add tests for code gen

* Formatting

* Add test for literals

* Formatting

* Use with_char

* Add missing header

* Fix mismerge

* Ignore tidy warning

* Fxx gcc 5 errors

* Apply fixits

* Skip signed bitwise of status

* Remove unused parameters

* Explicitly add c++14 flag

* Fix tidy warning

* Remove .o files
parent 3e92ef7a
...@@ -18,6 +18,7 @@ namespace op { ...@@ -18,6 +18,7 @@ namespace op {
struct sub : binary<sub> struct sub : binary<sub>
{ {
std::string point_function() const { return "-"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x - y; }; return [](auto x, auto y) { return x - y; };
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
namespace migraphx { namespace migraphx {
...@@ -14,7 +15,27 @@ namespace op { ...@@ -14,7 +15,27 @@ namespace op {
template <class Derived> template <class Derived>
struct unary : op_name<Derived> struct unary : op_name<Derived>
{ {
value base_attributes() const { return {{"pointwise", true}}; } std::string point_function() const { return this->name(); }
std::string point_op() const
{
const auto& self = static_cast<const Derived&>(*this);
auto pf = self.point_function();
if(pf.empty())
return {};
if(with_char(::ispunct)(pf.front()))
{
return pf + "${0}";
}
else
{
return "${function:" + pf + "}(${0})";
}
}
value base_attributes() const
{
const auto& self = static_cast<const Derived&>(*this);
return {{"pointwise", true}, {"point_op", self.point_op()}};
}
value attributes() const { return base_attributes(); } value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -12,6 +12,7 @@ namespace op { ...@@ -12,6 +12,7 @@ namespace op {
struct unary_not : unary<unary_not> struct unary_not : unary<unary_not>
{ {
std::string point_function() const { return "!"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return not x; }; return [](auto x) { return not x; };
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <string>
#include <memory>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct process_impl;
struct process
{
process(const std::string& cmd);
// move constructor
process(process&&) noexcept;
// copy assignment operator
process& operator=(process rhs);
~process() noexcept;
process& cwd(const fs::path& p);
void exec();
private:
std::unique_ptr<process_impl> impl;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/tensor_view.hpp> #include <migraphx/tensor_view.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <sstream>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -146,6 +147,13 @@ struct raw_data : raw_data_base ...@@ -146,6 +147,13 @@ struct raw_data : raw_data_base
migraphx::shape::get_type<T>{}); migraphx::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer); return reinterpret_cast<T*>(buffer);
} }
std::string to_string() const
{
std::stringstream ss;
ss << static_cast<const Derived&>(*this);
return ss.str();
}
}; };
template <class T, template <class T,
......
...@@ -61,6 +61,9 @@ struct shape ...@@ -61,6 +61,9 @@ struct shape
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
static std::string name(type_t t);
static std::string cpp_type(type_t t);
shape(); shape();
shape(type_t t); shape(type_t t);
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
......
...@@ -15,6 +15,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -15,6 +15,12 @@ inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__ #define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__) #define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
template <class F>
auto with_char(F f)
{
return [=](unsigned char c) { return f(c); };
}
inline std::string inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace) replace_string(std::string subject, const std::string& search, const std::string& replace)
{ {
...@@ -70,7 +76,7 @@ std::string trim(const std::string& s, F f) ...@@ -70,7 +76,7 @@ std::string trim(const std::string& s, F f)
inline std::string trim(const std::string& s) inline std::string trim(const std::string& s)
{ {
return trim(s, [](int c) { return std::isspace(c); }); return trim(s, [](unsigned char c) { return std::isspace(c); });
} }
template <class F> template <class F>
...@@ -92,6 +98,14 @@ inline bool starts_with(const std::string& value, const std::string& prefix) ...@@ -92,6 +98,14 @@ inline bool starts_with(const std::string& value, const std::string& prefix)
return std::equal(prefix.begin(), prefix.end(), value.begin()); return std::equal(prefix.begin(), prefix.end(), value.begin());
} }
inline std::string remove_prefix(std::string s, const std::string& prefix)
{
if(starts_with(s, prefix))
return s.substr(prefix.length());
else
return s;
}
template <class F> template <class F>
inline std::string inline std::string
interpolate_string(const std::string& input, F f, std::string start = "${", std::string end = "}") interpolate_string(const std::string& input, F f, std::string start = "${", std::string end = "}")
...@@ -124,14 +138,6 @@ inline std::string interpolate_string(const std::string& input, ...@@ -124,14 +138,6 @@ inline std::string interpolate_string(const std::string& input,
}); });
} }
inline std::string remove_prefix(std::string s, const std::string& prefix)
{
if(starts_with(s, prefix))
return s.substr(prefix.length());
else
return s;
}
template <class Iterator> template <class Iterator>
inline std::string to_string_range(Iterator start, Iterator last) inline std::string to_string_range(Iterator start, Iterator last)
{ {
......
...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct tmp_dir struct tmp_dir
{ {
fs::path path; fs::path path;
tmp_dir(); tmp_dir(const std::string& prefix = "");
void execute(const std::string& exe, const std::string& args) const; void execute(const std::string& exe, const std::string& args) const;
......
#include <migraphx/process.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/env.hpp>
#include <functional>
#include <iostream>
#include <unistd.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_CMD_EXECUTE)
std::function<void(const char*)> redirect_to(std::ostream& os)
{
return [&](const char* x) { os << x; };
}
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl;
std::array<char, 128> buffer;
auto closer = [&](FILE* stream) {
auto status = pclose(stream);
ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT
};
{
// TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
if(!pipe)
MIGRAPHX_THROW("popen() failed: " + cmd);
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
}
return ec;
}
struct process_impl
{
std::string command{};
fs::path cwd{};
std::string get_command() const
{
std::string result;
if(not cwd.empty())
result += "cd " + cwd.string() + "; ";
result += command;
return result;
}
};
process::process(const std::string& cmd) : impl(std::make_unique<process_impl>())
{
impl->command = cmd;
}
process::process(process&&) noexcept = default;
process& process::operator=(process rhs)
{
std::swap(impl, rhs.impl);
return *this;
}
process::~process() noexcept = default;
process& process::cwd(const fs::path& p)
{
impl->cwd = p;
return *this;
}
void process::exec()
{
auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout));
if(ec != 0)
MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " +
std::to_string(ec));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -88,6 +88,29 @@ const std::vector<shape::type_t>& shape::types() ...@@ -88,6 +88,29 @@ const std::vector<shape::type_t>& shape::types()
return result; return result;
} }
std::string shape::name(shape::type_t t)
{
switch(t)
{
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
}
MIGRAPHX_THROW("Invalid type");
}
std::string shape::cpp_type(shape::type_t t)
{
switch(t)
{
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
case x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
}
MIGRAPHX_THROW("Invalid type");
}
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {} shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
...@@ -246,17 +269,7 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const ...@@ -246,17 +269,7 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
std::size_t shape::element_space() const { return impl->element_space(); } std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const std::string shape::type_string() const { return name(this->type()); }
{
switch(this->type())
{
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE(x, t) \
case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE
}
MIGRAPHX_THROW("Invalid type");
}
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
......
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp>
#include <cassert> #include <cassert>
namespace migraphx { namespace migraphx {
...@@ -24,13 +24,11 @@ bool is_hip_clang_compiler() ...@@ -24,13 +24,11 @@ bool is_hip_clang_compiler()
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
std::vector<std::vector<char>> hsacos; assert(not srcs.empty());
if(not is_hcc_compiler() and not is_hip_clang_compiler()) if(not is_hcc_compiler() and not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " + MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER))); std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
assert(not srcs.empty());
tmp_dir td{};
params += " -Wno-cuda-compat";
if(params.find("-std=") == std::string::npos) if(params.find("-std=") == std::string::npos)
params += " --std=c++17"; params += " --std=c++17";
params += " -fno-gpu-rdc"; params += " -fno-gpu-rdc";
...@@ -46,55 +44,31 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -46,55 +44,31 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
params += " -O3 "; params += " -O3 ";
} }
params += " -Wno-unused-command-line-argument -I. "; params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS); params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS);
std::string output_flags{}; src_compiler compiler;
compiler.flags = params;
for(const auto& src : srcs) compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
{
fs::path full_path = td.path / src.path;
fs::path parent_path = full_path.parent_path();
fs::create_directories(parent_path);
write_buffer(full_path.string(), src.content.first, src.len());
if(src.path.extension().string() == ".cpp")
{
params += " " + src.path.filename().string();
output_flags = " -o " + src.path.stem().string() + ".o";
}
}
params += output_flags;
td.execute(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), params); if(is_hcc_compiler())
compiler.process = [&](const fs::path& obj_path) -> fs::path {
for(const auto& entry : fs::directory_iterator{td.path}) process{MIGRAPHX_STRINGIZE(MIGRAPHX_EXTRACT_KERNEL) + std::string{" -i "} +
{ obj_path.string()}
const auto& obj_path = entry.path(); .cwd(obj_path.parent_path());
if(not fs::is_regular_file(obj_path)) for(const auto& entry : fs::directory_iterator{obj_path.parent_path()})
continue; {
if(obj_path.extension() != ".o") const auto& hsaco_path = entry.path();
continue; if(not fs::is_regular_file(hsaco_path))
if(is_hcc_compiler()) continue;
{ if(hsaco_path.extension() != ".hsaco")
// call extract kernel continue;
td.execute(MIGRAPHX_STRINGIZE(MIGRAPHX_EXTRACT_KERNEL), " -i " + obj_path.string()); return hsaco_path;
} }
} MIGRAPHX_THROW("Missing hsaco");
};
const std::string ext = is_hcc_compiler() ? ".hsaco" : ".o";
for(const auto& entry : fs::directory_iterator{td.path})
{
const auto& obj_path = entry.path();
if(not fs::is_regular_file(obj_path))
continue;
if(obj_path.extension() != ext)
continue;
hsacos.push_back(read_buffer(obj_path.string()));
}
return hsacos; return {compiler.compile(srcs)};
} }
} // namespace gpu } // namespace gpu
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/compile_src.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -11,13 +12,6 @@ namespace migraphx { ...@@ -11,13 +12,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct src_file
{
fs::path path;
std::pair<const char*, const char*> content;
std::size_t len() const { return content.second - content.first; }
};
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch); compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
......
#include <migraphx/tmp_dir.hpp> #include <migraphx/tmp_dir.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/process.hpp>
#include <algorithm> #include <algorithm>
#include <random> #include <random>
#include <thread> #include <thread>
...@@ -14,7 +15,6 @@ namespace migraphx { ...@@ -14,7 +15,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_SAVE_TEMP_DIR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_SAVE_TEMP_DIR)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_CMD_EXECUTE)
std::string random_string(std::string::size_type length) std::string random_string(std::string::size_type length)
{ {
...@@ -35,34 +35,22 @@ std::string unique_string(const std::string& prefix) ...@@ -35,34 +35,22 @@ std::string unique_string(const std::string& prefix)
{ {
auto pid = getpid(); auto pid = getpid();
auto tid = std::this_thread::get_id(); auto tid = std::this_thread::get_id();
auto clk = std::chrono::steady_clock::now().time_since_epoch().count();
std::stringstream ss; std::stringstream ss;
ss << prefix << "-" << pid << "-" << tid << "-" << random_string(64); ss << std::hex << prefix << "-" << pid << "-" << tid << "-" << clk << "-" << random_string(16);
return ss.str(); return ss.str();
} }
tmp_dir::tmp_dir() : path(fs::temp_directory_path() / unique_string("migraphx")) tmp_dir::tmp_dir(const std::string& prefix)
: path(fs::temp_directory_path() /
unique_string(prefix.empty() ? "migraphx" : "migraphx-" + prefix))
{ {
fs::create_directories(this->path); fs::create_directories(this->path);
} }
void system_cmd(const std::string& cmd)
{
// We shouldn't call system commands
#ifdef MIGRAPHX_USE_CLANG_TIDY
(void)cmd;
#else
if(std::system(cmd.c_str()) != 0)
MIGRAPHX_THROW("Can't execute " + cmd);
#endif
}
void tmp_dir::execute(const std::string& exe, const std::string& args) const void tmp_dir::execute(const std::string& exe, const std::string& args) const
{ {
std::string cd = "cd " + this->path.string() + "; "; process{exe + " " + args}.cwd(this->path).exec();
std::string cmd = cd + exe + " " + args; // + " > /dev/null";
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl;
system_cmd(cmd);
} }
tmp_dir::~tmp_dir() tmp_dir::~tmp_dir()
......
...@@ -69,7 +69,7 @@ int main() {} ...@@ -69,7 +69,7 @@ int main() {}
)__migraphx__"; )__migraphx__";
migraphx::gpu::src_file make_src_file(const std::string& name, const std::string& content) migraphx::src_file make_src_file(const std::string& name, const std::string& content)
{ {
return {name, std::make_pair(content.data(), content.data() + content.size())}; return {name, std::make_pair(content.data(), content.data() + content.size())};
} }
......
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <sstream>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -67,6 +69,27 @@ struct nop ...@@ -67,6 +69,27 @@ struct nop
} }
}; };
struct function
{
static std::string as_string() { return ""; }
template <class T>
static decltype(auto) call(T&& x)
{
return x();
}
};
template <class Iterator>
inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last)
{
if(start != last)
{
s << *start;
std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; });
}
return s;
}
inline std::ostream& operator<<(std::ostream& s, std::nullptr_t) inline std::ostream& operator<<(std::ostream& s, std::nullptr_t)
{ {
s << "nullptr"; s << "nullptr";
...@@ -77,10 +100,7 @@ template <class T> ...@@ -77,10 +100,7 @@ template <class T>
inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v) inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v)
{ {
s << "{ "; s << "{ ";
for(auto&& x : v) stream_range(s, v.begin(), v.end());
{
s << x << ", ";
}
s << "}"; s << "}";
return s; return s;
} }
...@@ -88,10 +108,7 @@ inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v) ...@@ -88,10 +108,7 @@ inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v)
inline std::ostream& operator<<(std::ostream& s, const std::vector<bool>& v) inline std::ostream& operator<<(std::ostream& s, const std::vector<bool>& v)
{ {
s << "{ "; s << "{ ";
for(auto x : v) stream_range(s, v.begin(), v.end());
{
s << x << ", ";
}
s << "}"; s << "}";
return s; return s;
} }
...@@ -142,7 +159,10 @@ struct lhs_expression ...@@ -142,7 +159,10 @@ struct lhs_expression
friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self) friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self)
{ {
s << Operator::as_string() << " " << self.lhs; std::string op = Operator::as_string();
if(not op.empty())
s << Operator::as_string() << " ";
s << self.lhs;
return s; return s;
} }
...@@ -180,6 +200,55 @@ struct lhs_expression ...@@ -180,6 +200,55 @@ struct lhs_expression
TEST_LHS_REOPERATOR (^) TEST_LHS_REOPERATOR (^)
}; };
template <class F>
struct predicate
{
std::string msg;
F f;
friend std::ostream& operator<<(std::ostream& s, const predicate& self)
{
s << self.msg;
return s;
}
decltype(auto) operator()() const { return f(); }
operator decltype(auto)() const { return f(); }
};
template <class F>
auto make_predicate(const std::string& msg, F f)
{
return make_lhs_expression(predicate<F>{msg, f}, function{});
}
template <class T>
std::string as_string(const T& x)
{
std::stringstream ss;
ss << x;
return ss.str();
}
template <class Iterator>
std::string as_string(Iterator start, Iterator last)
{
std::stringstream ss;
stream_range(ss, start, last);
return ss.str();
}
template <class F>
auto make_function(const std::string& name, F f)
{
return [=](auto&&... xs) {
std::vector<std::string> args = {as_string(xs)...};
return make_predicate(name + "(" + as_string(args.begin(), args.end()) + ")",
[=] { return f(xs...); });
};
}
struct capture struct capture
{ {
template <class T> template <class T>
...@@ -236,6 +305,13 @@ bool throws(F f, const std::string& msg = "") ...@@ -236,6 +305,13 @@ bool throws(F f, const std::string& msg = "")
} }
} }
template <class T, class U>
auto near(T px, U py, double ptol = 1e-6f)
{
return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })(
px, py, ptol);
}
using string_map = std::unordered_map<std::string, std::vector<std::string>>; using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class Keyword> template <class Keyword>
......
#include <migraphx/compile_src.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
// NOLINTNEXTLINE
const std::string add_42_src = R"migraphx(
extern "C" int add(int x)
{
return x+42;
}
)migraphx";
// NOLINTNEXTLINE
const std::string preamble = R"migraphx(
#include <cmath>
)migraphx";
template <class F>
std::function<F>
compile_function(const std::string& src, const std::string& flags, const std::string& fname)
{
migraphx::src_compiler compiler;
compiler.flags = flags + "-std=c++14 -fPIC -shared";
compiler.output = "libsimple.so";
migraphx::src_file f;
f.path = "main.cpp";
f.content = std::make_pair(src.data(), src.data() + src.size());
auto image = compiler.compile({f});
return migraphx::dynamic_loader{image}.get_function<F>(fname);
}
template <class F>
std::function<F> compile_module(const migraphx::module& m, const std::string& flags = "")
{
migraphx::cpp_generator g;
g.fmap([](auto&& name) { return "std::" + name; });
g.create_function(g.generate_module(m).set_attributes({"extern \"C\""}));
return compile_function<F>(preamble + g.str(), flags, m.name());
}
TEST_CASE(simple_run)
{
auto f = compile_function<int(int)>(add_42_src, "", "add");
EXPECT(f(8) == 50);
EXPECT(f(10) == 52);
}
TEST_CASE(generate_module)
{
migraphx::module m("foo");
auto x = m.add_parameter("x", migraphx::shape::float_type);
auto y = m.add_parameter("y", migraphx::shape::float_type);
auto sum = m.add_instruction(migraphx::make_op("add"), x, y);
m.add_instruction(migraphx::make_op("sqrt"), sum);
auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(2, 2), 2));
EXPECT(test::near(f(10, 6), 4));
EXPECT(test::near(f(1, 2), std::sqrt(3)));
}
TEST_CASE(generate_module_with_literals)
{
migraphx::module m("foo");
auto x = m.add_parameter("x", migraphx::shape::float_type);
auto y = m.add_parameter("y", migraphx::shape::float_type);
auto z = m.add_literal(1.f);
auto sum1 = m.add_instruction(migraphx::make_op("add"), x, z);
auto sum2 = m.add_instruction(migraphx::make_op("add"), sum1, y);
m.add_instruction(migraphx::make_op("sqrt"), sum2);
auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(1, 2), 2));
EXPECT(test::near(f(9, 6), 4));
EXPECT(test::near(f(0, 2), std::sqrt(3)));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment