Commit 80bf741a authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-int8-fusion

parents 99626b4c 0e6ee3f7
...@@ -100,10 +100,10 @@ struct find_static_2in_broadcasts ...@@ -100,10 +100,10 @@ struct find_static_2in_broadcasts
} // namespace } // namespace
/** /**
* Makes all the shapes in the dynamic_dimension range. * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those * and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other * work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions. * instructions. Skips if the dynamic parameter outputs to a select_module operator.
*/ */
void split_single_dyn_dim::apply(module_pass_manager& mpm) const void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{ {
...@@ -111,7 +111,13 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -111,7 +111,13 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
auto param_names = mm->get_parameter_names(); auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes(); auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes); optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
if(dd_check.has_value()) auto any_sm_next = [&](auto ddc) {
auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs();
return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
};
if(dd_check.has_value() and not any_sm_next(dd_check))
{ {
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str); const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str); auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
......
...@@ -41,7 +41,7 @@ class x_model ...@@ -41,7 +41,7 @@ class x_model
void set_shape(migraphx::shape); void set_shape(migraphx::shape);
}; };
x_model create_xmodel(migraphx::module_ref mod); x_model create_xmodel(migraphx::const_module_ref mod);
migraphx::argument execute(const x_model& xmodel, migraphx::argument execute(const x_model& xmodel,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
......
...@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; }; ...@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void x_model::set_shape(migraphx::shape s) { shape = s; } void x_model::set_shape(migraphx::shape s) { shape = s; }
x_model create_xmodel(const migraphx::module_ref mod) x_model create_xmodel(migraphx::const_module_ref mod)
{ {
std::cout << "Calling an external function: create_xmodel!\n"; std::cout << "Calling an external function: create_xmodel!\n";
x_model xmodel; x_model xmodel;
......
...@@ -33,11 +33,7 @@ if(NOT TARGET MIOpen) ...@@ -33,11 +33,7 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
if(BUILD_DEV) set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipClang APIs")
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
else()
set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif()
include(Embed) include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS} file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers) ...@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name) void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{ {
module m = pm; module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}}); run_passes(m,
{rewrite_quantization{}, eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g; cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
...@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name) ...@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
not input->get_shape().broadcasted(); not input->get_shape().broadcasted();
}); });
auto inner_names = names; auto inner_names = names;
for(auto input : ins->inputs())
{
if(input->name() != "@param")
continue;
if(contains(tensors, input))
continue;
inner_names[input] += "[out_idx]";
}
for(auto input : tensors) for(auto input : tensors)
inner_names[input] += "_lambda_param"; inner_names[input] += "_lambda_param";
auto call_function = auto call_function =
...@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name) ...@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
}); });
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name); f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
f.add_generic_param("r"); f.add_generic_param("r");
f.add_generic_param("out_idx");
f.unused_param("out_idx");
g.create_function(f); g.create_function(f);
return g.str(); return g.str();
} }
......
...@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC); ...@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
std::string hiprtc_error(hiprtcResult err, const std::string& msg) std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{ {
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
...@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options.push_back("-DMIGRAPHX_HAS_DPP=0"); options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"); options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier"); options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-unused-parameter");
options.push_back("-Wno-gnu-line-marker"); options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast"); options.push_back("-Wno-old-style-cast");
} }
...@@ -216,6 +214,15 @@ std::vector<std::vector<char>> ...@@ -216,6 +214,15 @@ std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()}; std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc); auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver"; auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
......
...@@ -135,10 +135,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -135,10 +135,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t max_global = ctx.get_current_device().get_cu_count() * std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu(); ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) { return [n, over, max_global](std::size_t local) {
std::size_t groups = (n + local - 1) / local; std::size_t num_elements = n;
std::size_t max_blocks = max_global / local; std::size_t groups = (num_elements + local - 1) / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local; std::size_t max_blocks = max_global / local;
return std::min(nglobal, n); std::size_t nglobal = std::min(max_blocks * over, groups) * local;
#ifdef MIGRAPHX_USE_HIPRTC
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
num_elements = ((num_elements + local - 1) / local) * local;
#endif
return std::min(nglobal, num_elements);
}; };
} }
......
...@@ -94,6 +94,10 @@ template <> ...@@ -94,6 +94,10 @@ template <>
struct is_hip_type<std::uint8_t> : std::true_type struct is_hip_type<std::uint8_t> : std::true_type
{ {
}; };
template <>
struct is_hip_type<std::int32_t> : std::true_type
{
};
template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})> template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T as, V&& v) void hip_visitor_invoke(T as, V&& v)
...@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) ...@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if(not std::all_of( if(not std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); })) types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
std::initializer_list<index_int> ranks = { std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
static_cast<index_int>(get_shape(xs).lens().size())...}; if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { visit_tensor_size(s.ndim(), [&](auto ndim) {
s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); })); s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
}); });
} }
...@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) ...@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template <class V, class F, class... Ts> template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs) void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{ {
std::initializer_list<index_int> ranks = { std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
static_cast<index_int>(get_shape(xs).lens().size())...}; if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); }); visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
} }
template <class F> template <class F>
......
...@@ -37,22 +37,26 @@ argument scatter( ...@@ -37,22 +37,26 @@ argument scatter(
hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis) hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
{ {
auto ds = arg0.get_shape(); auto ds = arg0.get_shape();
auto inds = arg1.get_shape(); auto s1 = arg1.get_shape();
auto axis_dim_size = ds.lens()[axis]; auto axis_dim_size = ds.lens()[axis];
hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) { hip_visit_all(result, arg0, arg2)([&](auto output, auto data, auto update) {
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
const auto* data_ptr = device_cast(data.data()); const auto* data_ptr = device_cast(data.data());
gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; }); gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; });
hip_visit_all(arg1, arg2)([&](auto indices, auto update) {
const auto* upd_ptr = device_cast(update.data()); hip_visit_all(arg1)([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); if constexpr(indices.get_shape().lens.size() == output.get_shape().lens.size())
gs_launch(stream, inds.elements())([=](auto i) __device__ { {
auto out_idx = s1.multi(i); const auto* upd_ptr = device_cast(update.data());
auto index = indices_ptr[i]; const auto* indices_ptr = device_cast(indices.data());
index = index < 0 ? index + axis_dim_size : index; gs_launch(stream, s1.elements())([=](auto i) __device__ {
out_idx[axis] = index; auto out_idx = indices.get_shape().multi(i);
output[out_idx] = upd_ptr[i]; auto index = indices_ptr[i];
}); index = index < 0 ? index + axis_dim_size : index;
out_idx[axis] = index;
output[out_idx] = upd_ptr[i];
});
}
}); });
}); });
......
...@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ...@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
add_executable(gpu-driver add_executable(gpu-driver
${GPU_DRIVER_SRCS} ${GPU_DRIVER_SRCS}
) )
rocm_clang_tidy_check(gpu-driver)
target_include_directories(gpu-driver PRIVATE include) target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu) target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
...@@ -44,7 +44,7 @@ struct auto_register_action ...@@ -44,7 +44,7 @@ struct auto_register_action
template <class T> template <class T>
static void apply() static void apply()
{ {
auto name = get_type_name<T>(); const auto& name = get_type_name<T>();
register_action(name.substr(name.rfind("::") + 2), register_action(name.substr(name.rfind("::") + 2),
[](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); }); [](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); });
} }
......
...@@ -38,6 +38,27 @@ namespace gpu { ...@@ -38,6 +38,27 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
#else
return false;
#endif
}
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
struct mlir_op struct mlir_op
...@@ -58,8 +79,41 @@ struct mlir_op ...@@ -58,8 +79,41 @@ struct mlir_op
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
return op.compute_shape({inputs[n - 2], inputs[n - 1]}); module_ref mod = mods[0];
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes;
size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end());
for(std::string param_name : names)
{
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
}
for(auto ins : iterator_for(*mod))
{
if(ins->name() == "@param")
{
continue;
}
if(ins->name() == "@literal")
{
ins_shapes[ins] = ins->get_shape();
continue;
}
if(ins->name() == "@return")
{
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
std::transform(ins->inputs().begin(),
ins->inputs().end(),
input_shapes.begin(),
[&](auto in) { return ins_shapes[in]; });
ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
}
MIGRAPHX_THROW("No return found in the submodule");
} }
}; };
MIGRAPHX_REGISTER_OP(mlir_op); MIGRAPHX_REGISTER_OP(mlir_op);
...@@ -68,7 +122,7 @@ namespace { ...@@ -68,7 +122,7 @@ namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{ {
if(ins->name() != "convolution") if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false; return false;
value v = ins->get_operator().to_value(); value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>(); auto group = v.at("group").to<int>();
...@@ -89,6 +143,53 @@ struct find_mlir_op ...@@ -89,6 +143,53 @@ struct find_mlir_op
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
} }
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{
if(ins->name() != "@literal")
{
continue;
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
}
return ins_map;
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -98,33 +199,42 @@ struct find_mlir_op ...@@ -98,33 +199,42 @@ struct find_mlir_op
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
// Whitelist pointwise operators // Whitelist pointwise operators
if(std::any_of(pm->begin(), pm->end(), [](const auto& i) { if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
return not contains( return not contains({"@literal",
{"@literal", "@param", "@return", "convolution", "dot", "add", "relu"}, "@param",
i.name()); "@return",
"convolution",
"quant_convolution",
"dot",
"add",
"relu",
"dequantizelinear",
"quantizelinear",
"mul"},
i.name());
})) }))
return; return;
// Only fuse with fp32/fp16 // Only fuse with fp32/fp16/int8/int32
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) { if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type, shape::type_t::half_type}, return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::int32_type},
i->get_shape().type()); i->get_shape().type());
})) }))
return; return;
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name()); module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass(); mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map; std::unordered_map<instruction_ref, instruction_ref> param_map =
auto x = mm->add_parameter("x" + std::to_string(names.size()), create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
gemm_based_op->inputs().at(0)->get_shape()); auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
gemm_based_op->inputs().at(1)->get_shape());
auto conv = mm->add_instruction(gemm_based_op->get_operator(), {x, w});
std::transform(names.begin(), std::transform(names.begin(),
names.end(), names.end(),
ins->inputs().begin(), ins->inputs().begin(),
std::inserter(param_map, param_map.end()), std::inserter(param_map, param_map.end()),
[&](auto name, auto input) { [&, &anchor_op = anchor_op](auto name, auto input) {
if(input == x_ins) if(input == x_ins)
return std::make_pair(pm->get_parameter(name), conv); return std::make_pair(pm->get_parameter(name), anchor_op);
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape())); mm->add_parameter(name, input->get_shape()));
}); });
...@@ -135,7 +245,7 @@ struct find_mlir_op ...@@ -135,7 +245,7 @@ struct find_mlir_op
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(inputs), std::back_inserter(inputs),
[&](auto input) { return input != gemm_based_op; }); [&](auto input) { return input != gemm_based_op; });
inputs.insert(inputs.end(), gemm_based_op->inputs().begin(), gemm_based_op->inputs().end()); inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction( mpm.get_module().replace_instruction(
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm}); ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
} }
...@@ -148,17 +258,7 @@ struct find_mlir_op ...@@ -148,17 +258,7 @@ struct find_mlir_op
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{}); match::find_matches(mpm, find_mlir_op{});
if(mlir_enabled)
{
match::find_matches(mpm, find_mlir_op{});
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
}
#else #else
(void)mpm; (void)mpm;
#endif #endif
......
...@@ -157,13 +157,8 @@ void gemm_impl(context& ctx, ...@@ -157,13 +157,8 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag = rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none; int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
int flag = 0;
#endif
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -36,6 +37,11 @@ namespace migraphx { ...@@ -36,6 +37,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
#endif
struct hiprtc_src_file struct hiprtc_src_file
{ {
hiprtc_src_file() = default; hiprtc_src_file() = default;
......
...@@ -34,6 +34,8 @@ struct module_pass_manager; ...@@ -34,6 +34,8 @@ struct module_pass_manager;
namespace gpu { namespace gpu {
bool mlir_enabled();
struct fuse_mlir struct fuse_mlir
{ {
context* ctx = nullptr; context* ctx = nullptr;
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/dyn_output.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -112,7 +113,7 @@ struct hip_copy_to_gpu ...@@ -112,7 +113,7 @@ struct hip_copy_to_gpu
std::string name() const { return "hip::copy_to_gpu"; } std::string name() const { return "hip::copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1, 2).same_type(); check_shapes{inputs, *this, true}.has(1, 2).same_type();
return inputs.at(0); return inputs.at(0);
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...@@ -121,6 +122,10 @@ struct hip_copy_to_gpu ...@@ -121,6 +122,10 @@ struct hip_copy_to_gpu
if(args.size() == 1) if(args.size() == 1)
return input; return input;
argument result = args[1].share(); argument result = args[1].share();
if(result.get_shape().dynamic())
{
result = result.reshape(args[0].get_shape());
}
gpu_copy(ctx, input, result); gpu_copy(ctx, input, result);
// Associate the input since it was registered with hip // Associate the input since it was registered with hip
return {result.get_shape(), [input, result]() mutable { return result.data(); }}; return {result.get_shape(), [input, result]() mutable { return result.data(); }};
...@@ -138,19 +143,24 @@ struct hip_copy_from_gpu ...@@ -138,19 +143,24 @@ struct hip_copy_from_gpu
std::string name() const { return "hip::copy_from_gpu"; } std::string name() const { return "hip::copy_from_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1, 2).same_type(); check_shapes{inputs, *this, true}.has(1, 2).same_type();
return inputs.at(0); return inputs.at(0);
} }
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const dyn_output& dyn_out, const std::vector<argument>& args) const
{ {
if(args.size() == 1) if(args.size() == 1)
{ {
argument result = allocate_gpu(output_shape, true); argument result = allocate_gpu(dyn_out.computed_shape, true);
gpu_copy(ctx, args[0], result); gpu_copy(ctx, args[0], result);
return result; return result;
} }
copy_from_gpu(ctx, args[0], args[1]); argument input = args[0].share();
if(input.get_shape().dynamic())
{
input = input.reshape(args[1].get_shape());
}
copy_from_gpu(ctx, input, args[1]);
return args[1]; return args[1];
} }
std::ptrdiff_t output_alias(const std::vector<shape>& args) const std::ptrdiff_t output_alias(const std::vector<shape>& args) const
......
...@@ -28,10 +28,6 @@ ...@@ -28,10 +28,6 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/math_functions.h> #include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif #endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP #endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
...@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor) ...@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
// Use float to compute half overload // Use float to compute half overload
...@@ -176,7 +176,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log) ...@@ -176,7 +176,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10) MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2) MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt) MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin) MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt) MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
template <class T, class U> template <class T, class U>
...@@ -217,14 +217,6 @@ constexpr auto min(const T& a, const U& b) ...@@ -217,14 +217,6 @@ constexpr auto min(const T& a, const U& b)
return min<common_type_t<T, U>>(a, b); return min<common_type_t<T, U>>(a, b);
} }
// Sin for half is broken on hip, so use cos instead
template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
constexpr T sin(T x)
{
constexpr const T shift = HIP_PIO2_F;
return migraphx::cos(shift - x);
}
MIGRAPHX_DEVICE_MATH_VEC(abs) MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos) MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh) MIGRAPHX_DEVICE_MATH_VEC(acosh)
......
...@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs) ...@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
template <class... Ts> template <class... Ts>
__device__ void println(Ts... xs) __device__ void println(Ts... xs)
{ {
print_each(&coutln, xs...); print_each(&cout, xs..., '\n');
} }
template <class... Ts> template <class... Ts>
__device__ void println_once(Ts... xs) __device__ void println_once(Ts... xs)
{ {
print_each_once(&coutln, xs...); print_each_once(&cout, xs..., '\n');
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
#endif #endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix) \ #define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \ __device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \ __device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \ __device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \ __device__ inline void dpp_reduce(int32_t& x, op) \
{ \ { \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); \ MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
} \ } \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); } __device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
MIGRAPHX_DPP_REDUCE(op::sum, v_add) // Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE(op::max, v_max) MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
MIGRAPHX_DPP_REDUCE(op::min, v_min) MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::product, v_mul) MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
template <class Op, class T, class Index, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
...@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F> ...@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f) __device__ void fused_reduce(Output output, F f)
{ {
Algo::template run<Reduced>([&](auto out_idx, auto r) { Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r); auto result = f(r, out_idx);
if constexpr(reduce::is_inner_storage<decltype(result)>{}) if constexpr(reduce::is_inner_storage<decltype(result)>{})
{ {
r.inner([&](auto& y, auto x) { y = x; })(output, result); r.inner([&](auto& y, auto x) { y = x; })(output, result);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment