Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parser for the onnx Size operator: Size returns the total number of
// elements of its input tensor, emitted here as a scalar int64 literal.
struct parse_size : op_parser<parse_size>
{
    std::vector<op_desc> operators() const { return {{"Size"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // The element count of the first (only) input is the Size result.
        const auto count = args[0]->get_shape().elements();
        return info.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::int64_type}, {count}});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
// the strings for the enum are the same as the values used for onnx parsing
// but this enum is not onnx-specific: strings must be converted when parsing tf
static const std::vector<std::string> pooling_mode_str = {"average", "max"};
os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
return os;
}
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
static const std::vector<std::string> rnn_direction_str = {
"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -353,13 +353,20 @@ std::vector<argument> program::eval(parameter_map params) const
if(trace_level > 0)
{
std::unordered_map<instruction_ref, std::string> ins_out;
// get instruction names
this->print([&](auto x, auto ins_names) {
std::stringstream ss;
instruction::print(ss, x, ins_names);
ins_out[x] = ss.str();
});
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish();
std::cout << "Run instruction: ";
this->debug_print(ins);
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
......@@ -742,6 +749,14 @@ void program::print(
}
}
// Convenience overload: print the program using a freshly-created (empty)
// instruction-name table, delegating to the two-argument print().
void program::print(
    const std::function<void(instruction_ref ins,
                             std::unordered_map<instruction_ref, std::string>)>& print_func) const
{
    std::unordered_map<instruction_ref, std::string> ins_names;
    this->print(ins_names, print_func);
}
void program::print_graph(std::ostream& os, bool brief) const
{
const auto* mm = this->get_main_module();
......@@ -809,11 +824,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputItera
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
transform_if(m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
transform_if(
m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
}
std::vector<const module*> program::get_modules() const
......
......@@ -3,6 +3,8 @@
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp>
......@@ -95,7 +97,6 @@ migraphx::value to_value(py::kwargs kwargs)
auto&& val = arg.second;
visit_py(val, [&](auto py_val) { v[key] = py_val; });
}
return v;
}
} // namespace migraphx
......@@ -211,12 +212,21 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
.def(py::init<>())
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
else
return migraphx::shape(t, lens);
}))
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
......@@ -247,13 +257,38 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module>(m, "module")
py::class_<migraphx::instruction_ref>(m, "instruction_ref");
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
.def("__eq__", std::equal_to<migraphx::module>{})
.def("__ne__", std::not_equal_to<migraphx::module>{})
.def(
"add_instruction",
[](migraphx::module& mm,
const migraphx::operation& op,
std::vector<migraphx::instruction_ref>& args,
std::vector<migraphx::module*>& mod_args) {
return mm.add_instruction(op, args, mod_args);
},
py::arg("op"),
py::arg("args"),
py::arg("mod_args") = std::vector<migraphx::module*>{})
.def(
"add_parameter",
[](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
return mm.add_parameter(name, shape);
},
py::arg("name"),
py::arg("shape"))
.def(
"add_return",
[](migraphx::module& mm, std::vector<migraphx::instruction_ref>& args) {
return mm.add_return(args);
},
py::arg("args"))
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
py::class_<migraphx::program>(m, "program")
.def(py::init([]() { return migraphx::program(); }))
.def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_output_shapes", &migraphx::program::get_output_shapes)
......@@ -268,11 +303,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto* mm = p.get_main_module();
return *mm;
})
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def(
"create_module",
[](migraphx::program& p, const std::string& name) { return p.create_module(name); },
py::arg("name"))
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::parameter_map pm;
......@@ -303,89 +338,94 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("name", &migraphx::operation::name);
m.def("parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename,
migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def("load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def(
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def(
"parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def(
"load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("get_target", &migraphx::make_target);
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
m.def("quantize_fp16",
&migraphx::quantize_fp16,
py::arg("prog"),
......
......@@ -39,9 +39,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
// Repeatedly merge dimension n of all shapes with its neighbor until no
// further reduction is possible or n runs past the end.  Returns the index
// after the last reduced dimension.
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
    // NOTE: the merged source contained this loop twice; the duplicate's
    // condition check invoked reduce_dim (which mutates shapes) one extra
    // time.  A single loop is the intended behavior.
    while(reduce_dim(shapes, n) and n < shapes.size())
    {
    }
    return n + 1;
}
......
#include <migraphx/register_target.hpp>
#include <unordered_map>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -11,7 +11,17 @@ std::unordered_map<std::string, target>& target_map()
}
void register_target(const target& t) { target_map()[t.name()] = t; }
// Look up a registered target by name.
// Throws a descriptive error when the target was not compiled in or never
// registered, instead of the opaque out_of_range from map::at.
// (The merged source contained a second, older one-line definition of this
// function, which would be a redefinition error; only this version is kept.)
target make_target(const std::string& name)
{
    const auto it = target_map().find(name);
    if(it == target_map().end())
    {
        MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported");
    }
    return it->second;
}
std::vector<std::string> get_targets()
{
std::vector<std::string> result;
......
......@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const
instruction_ref pooling{};
// average pooling
if(op.mode == "average")
if(op.mode == op::pooling_mode::average)
{
pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
......
......@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
return hs_padded;
}
namespace op {
// Stream an rnn_direction as its human-readable name.
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
    // static const: build the name table once instead of reconstructing the
    // vector on every call (matches the definition in op/common.cpp).
    static const std::vector<std::string> rnn_direction_str = {
        "forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
    return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -86,6 +86,8 @@ struct shape_impl
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
};
const std::vector<shape::type_t>& shape::types()
......@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
......@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return this->with_lens(this->type(), l);
}
// Return a copy of this shape whose element type is replaced by t; all
// other state (lens, strides, sub-shapes) is carried over via impl->copy().
shape shape::with_type(type_t t) const
{
    auto cloned    = impl->copy();
    cloned->m_type = t;
    return shape{std::move(cloned)};
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
......
......@@ -335,7 +335,6 @@ struct find_concat_op
}
auto y = p.insert_instruction(ins, op, concats);
return {y};
};
std::vector<instruction_ref> args;
......
......@@ -316,7 +316,6 @@ struct find_nested_concat
else
args.push_back(i);
}
})(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args);
}
......
......@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
bool is_vectorizable(const Xs&... xs)
{
return all_of({xs...}, [](const auto& s) {
if(s.standard() and (s.lens().back() % N) == 0)
return true;
if(s.broadcasted())
......
......@@ -460,10 +460,10 @@ struct cpu_apply
if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and
not v["ceil_mode"].to<bool>())
return replace(ins, make_op("dnnl::pooling", op.to_value()));
std::string mode = v["mode"].to<std::string>();
if(mode == "max")
op::pooling_mode mode = v["mode"].to<op::pooling_mode>();
if(mode == op::pooling_mode::max)
return replace(ins, make_op("cpu::pooling_max", v));
else if(mode == "average")
else if(mode == op::pooling_mode::average)
return replace(ins, make_op("cpu::pooling_average", v));
return ins;
}
......
......@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg;
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
......@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen)
endif()
include(Embed)
file(GLOB KERNEL_FILES
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES})
......@@ -119,6 +119,7 @@ target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
add_library(migraphx_gpu
abs.cpp
analyze_streams.cpp
......@@ -131,8 +132,7 @@ add_library(migraphx_gpu
compile_ops.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_pointwise.cpp
compile_roialign.cpp
compiler.cpp
concat.cpp
convert.cpp
convolution.cpp
......@@ -172,6 +172,7 @@ add_library(migraphx_gpu
target.cpp
topk.cpp
write_literals.cpp
${JIT_GPU_SRCS}
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
......@@ -342,6 +343,12 @@ target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
"-DMIGRAPHX_USE_HIPRTC=0"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
    # Resolve the launcher (e.g. ccache) to an absolute path so the generated
    # define is usable from any working directory.  find_program is portable,
    # unlike shelling out to `which` (absent on Windows, non-POSIX).
    find_program(MIGRAPHX_HIP_COMPILER_LAUNCHER ${CMAKE_CXX_COMPILER_LAUNCHER})
    if(MIGRAPHX_HIP_COMPILER_LAUNCHER)
        target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
    endif()
endif()
endif()
# Check miopen find mode api
......
......@@ -178,6 +178,12 @@ bool is_hip_clang_compiler()
return result;
}
// True when the build configured a compiler launcher (e.g. ccache) and the
// launcher binary actually exists on this machine.
bool has_compiler_launcher()
{
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
    static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
    return result;
#else
    // Without the macro, MIGRAPHX_STRINGIZE would yield the literal string
    // "MIGRAPHX_HIP_COMPILER_LAUNCHER" and fs::exists would probe a bogus
    // relative path; the only call site is #ifdef-guarded anyway.
    return false;
#endif
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
......@@ -210,6 +216,10 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif
if(is_hcc_compiler())
compiler.process = [&](const fs::path& obj_path) -> fs::path {
......@@ -238,13 +248,6 @@ std::string enum_params(std::size_t count, std::string param)
return join_strings(items, ",");
}
// Compute a global launch size for n elements with the given workgroup
// size: round n up to whole groups, cap at 256 groups, and scale back to
// a thread count that is a multiple of local.
std::size_t compute_global(std::size_t n, std::size_t local)
{
    const std::size_t groups = (n + local - 1) / local;
    const std::size_t capped = std::min<std::size_t>(groups, 256);
    return capped * local;
}
#endif // MIGRAPHX_USE_HIPRTC
} // namespace gpu
......
......@@ -93,6 +93,32 @@ const std::vector<std::string>& compiler_warnings()
return warnings;
}
// Fill in the kernel launch parameters from a value map: "local" falls back
// to default_local, and "global" falls back to compute_global(local) when
// not supplied explicitly.
void hip_compile_options::set_launch_params(
    const value& v,
    const std::function<std::size_t(std::size_t local)>& compute_global,
    std::size_t default_local)
{
    local  = v.get("local", default_local);
    global = v.contains("global") ? v.at("global").to<std::size_t>() : compute_global(local);
}
// Build a function that maps a workgroup size to a global launch size for n
// elements.  The result is capped by the device occupancy (CU count times
// max workitems per CU) scaled by the oversubscription factor `over`.
std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over)
{
    assert(over > 0);
    const std::size_t max_global = ctx.get_current_device().get_cu_count() *
                                   ctx.get_current_device().get_max_workitems_per_cu();
    return [n, over, max_global](std::size_t local) {
        const std::size_t groups     = (n + local - 1) / local;
        const std::size_t max_blocks = max_global / local;
        const std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
        return nglobal;
    };
}
operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{
std::vector<src_file> srcs;
......
......@@ -6,12 +6,14 @@
#include <migraphx/par_for.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compiler.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op
{
operation op = op::identity{};
......@@ -38,41 +40,22 @@ struct precompile_op
MIGRAPHX_REGISTER_OP(precompile_op);
struct pointwise_compiler
struct compiled_result
{
std::string name() const { return "pointwise"; }
operation apply(context& ctx, instruction_ref ins, const operation&) const
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
return compile_pointwise(ctx, to_shapes(ins->inputs()), *pm);
}
compiler_replace replace;
instruction_ref ins;
};
using compiler_function = std::function<operation(context&, instruction_ref, operation)>;
template <class T>
compiler_function make_compiler_function(T x)
// Run f over n compile jobs in parallel.  The MIGRAPHX_GPU_COMPILE_PARALLEL
// env var overrides the parallelism (defaults to n, i.e. one job per item).
// (The merged source left a stray `return {...};` line from a deleted
// function here, which referenced an undeclared name and made the rest of
// the body unreachable; it is removed.)
template <class F>
void par_compile(std::size_t n, F f)
{
    if(n == 0)
        return;
    par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
template <class... Ts>
std::unordered_map<std::string, compiler_function> make_compilers(Ts... xs)
{
return {{xs.name(), make_compiler_function(xs)}...};
}
struct compiled_result
{
operation op;
instruction_ref ins;
};
void compile_ops::apply(module& m) const
{
auto compilers = make_compilers(pointwise_compiler{});
std::vector<std::function<compiled_result()>> compiles;
for(auto ins : iterator_for(m))
......@@ -80,15 +63,15 @@ void compile_ops::apply(module& m) const
if(ins->name() != "gpu::precompile_op")
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
assert(contains(compilers, preop.name()));
auto c = compilers[preop.name()];
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
compiles.emplace_back([=]() -> compiled_result {
return {compile(*ctx, ins, preop), ins};
});
}
std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); });
par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results)
{
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
cr.replace(m, cr.ins);
}
}
......
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(${params})
{
pointwise(${lambda}, ${args});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
// Compile a pointwise kernel from a C++ lambda expression: instantiate the
// kernel source template with the lambda/preamble and hand it to the hip
// code-object compiler.
operation compile_pointwise(context&,
                            const std::vector<shape>& inputs,
                            const std::string& lambda,
                            const std::string& preamble)
{
    hip_compile_options options;
    options.local          = 1024;
    options.global         = compute_global(inputs.front().elements());
    options.inputs         = inputs;
    options.output         = inputs.back();
    options.virtual_inputs = reduce_dims(inputs);
    options.params         = "-Wno-float-equal";

    const auto src =
        interpolate_string(pointwise_kernel,
                           {{"params", enum_params(inputs.size(), "void * private_p")},
                            {"args", enum_params(inputs.size(), "private_p")},
                            {"lambda", lambda},
                            {"preamble", preamble}});
    return compile_hip_code_object(src, options);
}
// Compile a pointwise module into a gpu operation: the module is lowered to
// C++ device code via cpp_generator, then compiled through the string-based
// compile_pointwise overload above.
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m)
{
    // Simplify the module before codegen so the emitted kernel is minimal.
    run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
    cpp_generator g;
    // Qualify generated function calls with the migraphx namespace.
    g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
    // Point-op templates: ${N} expands to the Nth argument, ${function:...}
    // to the mapped function name.
    g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
    g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
    g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
    // Comparison/logical ops are wrapped in abs to normalize bool results to
    // the value type expected by the kernel.
    g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
    g.add_point_op("less", "migraphx::abs(${0} < ${1})");
    g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
    g.add_point_op("not", "migraphx::abs(not ${0})");
    // Emit the module as a generic __device__ function and get its name.
    auto name =
        g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
    // MIGRAPHX_LIFT turns the named function into a lambda for the kernel.
    return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Registry mapping operator name -> instruction-level compile callback.
auto& compiler_map()
{
    static std::unordered_map<std::string, compiler_compile> instance; // NOLINT
    return instance;
}
// Registry mapping operator name -> shape/value-level compile callback.
auto& compiler_op_map()
{
    static std::unordered_map<std::string, compiler_compile_op> instance; // NOLINT
    return instance;
}
// Register both compile callbacks under the given operator name,
// overwriting any previous registration.
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop)
{
    compiler_op_map()[name] = std::move(cop);
    compiler_map()[name]    = std::move(c);
}
bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }
// Dispatch instruction compilation to the compiler registered for op's
// name; throws (map::at) when no compiler is registered.
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op)
{
    const auto& compile_fn = compiler_map().at(op.name());
    return compile_fn(ctx, ins, op);
}
// Dispatch op-level compilation (shapes + value, no instruction) to the
// registered callback; throws (map::at) when no compiler is registered.
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
{
    const auto& compile_fn = compiler_op_map().at(name);
    return compile_fn(ctx, inputs, v);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment