Unverified Commit 35d1bcc2 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add code generation for pointwise operators (#780)

* Add definitions for all pointwise operators

* Formatting

* Add cpp generator class

* Formatting

* Move compilation to core

* Formatting

* Add clock to tmp name

* Add dynamic loader

* Formatting

* Add tests for code gen

* Formatting

* Add test for literals

* Formatting

* Use with_char

* Add missing header

* Fix mismerge

* Ignore tidy warning

* Fxx gcc 5 errors

* Apply fixits

* Skip signed bitwise of status

* Remove unused parameters

* Explicitly add c++14 flag

* Fix tidy warning

* Remove .o files
parent 3e92ef7a
......@@ -11,7 +11,10 @@ add_library(migraphx
eliminate_common_subexpression.cpp
decompose.cpp
propagate_constant.cpp
compile_src.cpp
cpp_generator.cpp
dead_code_elimination.cpp
dynamic_loader.cpp
eliminate_allocation.cpp
eliminate_contiguous.cpp
eliminate_concat.cpp
......@@ -30,6 +33,7 @@ add_library(migraphx
msgpack.cpp
operation.cpp
permutation.cpp
process.cpp
program.cpp
module.cpp
quantization.cpp
......@@ -169,6 +173,8 @@ if(HAS_LIB_STD_FILESYSTEM)
target_link_libraries(migraphx PRIVATE -lstdc++fs)
endif()
target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
find_package(msgpack REQUIRED)
......
#include <migraphx/compile_src.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
assert(not srcs.empty());
tmp_dir td{"compile"};
auto params = flags;
params += " -I.";
auto out = output;
for(const auto& src : srcs)
{
fs::path full_path = td.path / src.path;
fs::path parent_path = full_path.parent_path();
fs::create_directories(parent_path);
write_buffer(full_path.string(), src.content.first, src.len());
if(src.path.extension().string() == ".cpp")
{
params += " " + src.path.filename().string();
if(out.empty())
out = src.path.stem().string() + ".o";
}
}
params += " -o" + out;
td.execute(compiler, params);
auto out_path = td.path / out;
if(not fs::exists(out_path))
MIGRAPHX_THROW("Output file missing: " + out);
if(process)
out_path = process(out_path);
return read_buffer(out_path.string());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/iterator_for.hpp>
#include <map>
#include <sstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
cpp_generator::function&
cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g)
{
std::unordered_map<migraphx::instruction_ref, std::string> names;
std::stringstream ss;
auto return_ins = std::prev(m.end());
for(auto ins : iterator_for(m))
{
ss << "// " << ins->get_operator() << " -> " << ins->get_shape() << "\n";
if(ins->name() == "@param")
{
names[ins] =
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter;
continue;
}
if(ins->name() == "@return")
{
assert(ins->inputs().size() == 1);
return_ins = ins->inputs().front();
}
std::string n = "z" + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
}
ss << "return " << names.at(return_ins) << ";\n";
body = ss.str();
return *this;
}
cpp_generator::function& cpp_generator::function::set_types(const module& m)
{
return cpp_generator::function::set_types(m, [](auto s) { return shape::cpp_type(s.type()); });
}
cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, parse(p.second)};
});
auto output_shapes = m.get_output_shapes();
assert(not output_shapes.empty());
this->return_type = parse(output_shapes.front());
return *this;
}
struct cpp_generator_impl
{
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
cpp_generator::cpp_generator(cpp_generator&&) noexcept = default;
cpp_generator& cpp_generator::operator=(cpp_generator rhs)
{
std::swap(impl, rhs.impl);
return *this;
}
cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args)
{
auto v = op.to_value();
return interpolate_string(op.attributes()["point_op"].to<std::string>(),
[&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
}
std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m)
{
function f;
f.set_name(m.name()).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" +
ins->get_literal().to_string() + ")";
std::vector<std::string> args;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
auto s = this->generate_point_op(ins->get_operator(), args);
return this->generate_point_op(ins->get_operator(), args);
});
return f;
}
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
impl->function_count++;
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '(';
for(auto&& p : f.params)
{
impl->fs << delim << p.type << " " << p.name;
delim = ',';
}
impl->fs << ") {\n" << f.body << "\n}\n";
return name;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <utility>
#include <dlfcn.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_loader_impl
{
dynamic_loader_impl() = default;
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), &dlclose), temp(std::move(t))
{
}
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
{
auto t = std::make_shared<tmp_dir>("dloader");
auto f = t->path / "libtmp.so";
write_buffer(f.string(), image, size);
return std::make_shared<dynamic_loader_impl>(f, t);
}
std::shared_ptr<void> handle = nullptr;
std::shared_ptr<tmp_dir> temp = nullptr;
};
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
{
}
dynamic_loader::dynamic_loader(const char* image, std::size_t size)
: impl(dynamic_loader_impl::from_buffer(image, size))
{
}
dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
: impl(dynamic_loader_impl::from_buffer(buffer.data(), buffer.size()))
{
}
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{
void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name);
return {impl, symbol};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <functional>
#include <string>
#include <utility>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct src_file
{
fs::path path;
std::pair<const char*, const char*> content;
std::size_t len() const { return content.second - content.first; }
};
struct src_compiler
{
std::string compiler = "c++";
std::string flags = "";
std::string output = "";
std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct operation;
struct module;
struct shape;
struct cpp_generator_impl;
struct cpp_generator
{
using generate_module_callback = std::function<std::string(
instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>;
struct param
{
std::string name;
std::string type;
};
struct function
{
std::vector<param> params = {};
std::string body = "";
std::string return_type = "void";
std::string name = "";
std::vector<std::string> attributes = {};
function& set_body(const module& m, const generate_module_callback& g);
function& set_body(const std::string& s)
{
body = s;
return *this;
}
function& set_name(const std::string& s)
{
name = s;
return *this;
}
function& set_attributes(std::vector<std::string> attrs)
{
attributes = std::move(attrs);
return *this;
}
function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
};
cpp_generator();
// move constructor
cpp_generator(cpp_generator&&) noexcept;
// copy assignment operator
cpp_generator& operator=(cpp_generator rhs);
~cpp_generator() noexcept;
void fmap(const std::function<std::string(std::string)>& f);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
std::string str() const;
function generate_module(const module& m, const generate_module_callback& g);
function generate_module(const module& m);
std::string create_function(const function& f);
private:
std::unique_ptr<cpp_generator_impl> impl;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <functional>
#include <memory>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_loader_impl;
struct dynamic_loader
{
dynamic_loader() = default;
dynamic_loader(const fs::path& p);
dynamic_loader(const char* image, std::size_t size);
dynamic_loader(const std::vector<char>& buffer);
std::shared_ptr<void> get_symbol(const std::string& name) const;
template <class F>
std::function<F> get_function(const std::string& name) const
{
auto s = get_symbol(name);
return [=](auto&&... xs) -> decltype(auto) {
auto f = reinterpret_cast<std::add_pointer_t<F>>(s.get());
return f(std::forward<decltype(xs)>(xs)...);
};
}
private:
std::shared_ptr<dynamic_loader_impl> impl;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP
......@@ -24,6 +24,7 @@ struct add : binary<add>
a["commutative"] = true;
return a;
}
std::string point_function() const { return "+"; }
auto apply() const
{
return [](auto x, auto y) { return x + y; };
......
......@@ -5,6 +5,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
......@@ -14,7 +15,27 @@ namespace op {
template <class Derived>
struct binary : op_name<Derived>
{
value base_attributes() const { return {{"pointwise", true}}; }
std::string point_function() const { return this->name(); }
std::string point_op() const
{
const auto& self = static_cast<const Derived&>(*this);
auto pf = self.point_function();
if(pf.empty())
return {};
if(with_char(::ispunct)(pf.front()))
{
return "${0} " + pf + " ${1}";
}
else
{
return "${function:" + pf + "}(${0}, ${1})";
}
}
value base_attributes() const
{
const auto& self = static_cast<const Derived&>(*this);
return {{"pointwise", true}, {"point_op", self.point_op()}};
}
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -18,6 +18,7 @@ namespace op {
struct div : binary<div>
{
std::string point_function() const { return "/"; }
auto apply() const
{
return [](auto x, auto y) { return x / y; };
......
......@@ -19,6 +19,7 @@ struct equal : binary<equal>
a["commutative"] = true;
return a;
}
std::string point_function() const { return "=="; }
auto apply() const
{
return [](auto x, auto y) { return float_equal(x, y); };
......
......@@ -12,6 +12,7 @@ namespace op {
struct greater : binary<greater>
{
std::string point_function() const { return ">"; }
auto apply() const
{
return [](auto x, auto y) { return x > y; };
......
......@@ -12,6 +12,7 @@ namespace op {
struct less : binary<less>
{
std::string point_function() const { return "<"; }
auto apply() const
{
return [](auto x, auto y) { return x < y; };
......
......@@ -12,6 +12,7 @@ namespace op {
struct logical_and : binary<logical_and>
{
std::string point_function() const { return "&&"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) and static_cast<bool>(y); };
......
......@@ -12,6 +12,7 @@ namespace op {
struct logical_or : binary<logical_or>
{
std::string point_function() const { return "||"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) or static_cast<bool>(y); };
......
......@@ -12,6 +12,7 @@ namespace op {
struct logical_xor : binary<logical_xor>
{
std::string point_function() const { return "^"; }
auto apply() const
{
return [](auto x, auto y) { return static_cast<bool>(x) xor static_cast<bool>(y); };
......
......@@ -24,6 +24,7 @@ struct mul : binary<mul>
a["commutative"] = true;
return a;
}
std::string point_function() const { return "*"; }
auto apply() const
{
return [](auto x, auto y) { return x * y; };
......
......@@ -18,6 +18,7 @@ namespace op {
struct neg : unary<neg>
{
std::string point_function() const { return "-"; }
auto apply() const
{
return [](auto x) { return -x; };
......
......@@ -18,6 +18,7 @@ namespace op {
struct relu : unary<relu>
{
std::string point_op() const { return "${function:max}(decltype(${0}){0}, ${0})"; }
auto apply() const
{
return [](auto x) { return std::max(decltype(x){0}, x); };
......
......@@ -9,6 +9,7 @@ namespace op {
struct sqdiff : binary<sqdiff>
{
std::string point_op() const { return "(${0} - ${1}) * (${0} - ${1})"; }
auto apply() const
{
return [](auto x, auto y) { return (x - y) * (x - y); };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment