Commit baac1dab authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-host-lib

parents 830dff7a 77042e30
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
*/ */
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/check_context.hpp>
#include <migraphx/adjust_allocation.hpp> #include <migraphx/adjust_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
...@@ -83,7 +82,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -83,7 +82,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
simplify_reshapes{}, simplify_reshapes{},
layout_nhwc{},
dead_code_elimination{}, dead_code_elimination{},
simplify_reshapes{}, simplify_reshapes{},
simplify_algebra{}, simplify_algebra{},
......
...@@ -43,14 +43,11 @@ struct target ...@@ -43,14 +43,11 @@ struct target
std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const; std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; } migraphx::context get_context() const { return context{}; }
supported_segments find_supported(const_module_ref mod, support_metric m) const; supported_segments find_supported(const_module_ref mod, support_metric m) const;
argument copy_to(const argument& arg) const { return arg; } argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return arg; } argument copy_from(const argument& arg) const { return arg; }
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
}; };
MIGRAPHX_REGISTER_TARGET(target);
} // namespace fpga } // namespace fpga
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -41,7 +41,7 @@ class x_model ...@@ -41,7 +41,7 @@ class x_model
void set_shape(migraphx::shape); void set_shape(migraphx::shape);
}; };
x_model create_xmodel(migraphx::module_ref mod); x_model create_xmodel(migraphx::const_module_ref mod);
migraphx::argument execute(const x_model& xmodel, migraphx::argument execute(const x_model& xmodel,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
......
...@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; }; ...@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void x_model::set_shape(migraphx::shape s) { shape = s; } void x_model::set_shape(migraphx::shape s) { shape = s; }
x_model create_xmodel(const migraphx::module_ref mod) x_model create_xmodel(migraphx::const_module_ref mod)
{ {
std::cout << "Calling an external function: create_xmodel!\n"; std::cout << "Calling an external function: create_xmodel!\n";
x_model xmodel; x_model xmodel;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# THE SOFTWARE. # THE SOFTWARE.
# #################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc) list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(miopen) find_package(miopen)
# rocblas # rocblas
...@@ -33,7 +33,7 @@ if(NOT TARGET MIOpen) ...@@ -33,7 +33,7 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs") set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipClang APIs")
include(Embed) include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS} file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
...@@ -172,21 +172,6 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp ...@@ -172,21 +172,6 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
get_filename_component(CMAKE_CXX_COMPILER_PATH "${CMAKE_CXX_COMPILER}" PATH)
if(NOT CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
find_program(MIGRAPHX_EXTRACT_KERNEL extractkernel
PATH_SUFFIXES bin
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATHS
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
)
endif()
message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
...@@ -228,7 +213,6 @@ else() ...@@ -228,7 +213,6 @@ else()
target_compile_definitions(migraphx_gpu PRIVATE target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}" "-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}" "-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
) )
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER) if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
...@@ -244,8 +228,7 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION) ...@@ -244,8 +228,7 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# TODO: Set default to HAS_FIND_2_API set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
if(MIGRAPHX_USE_FIND_2_API) if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API) target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
...@@ -263,12 +246,11 @@ endif() ...@@ -263,12 +246,11 @@ endif()
find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED) find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED)
# Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library)
add_subdirectory(driver) add_subdirectory(driver)
add_subdirectory(hiprtc)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_gpu migraphx_device compile_for_gpu TARGETS migraphx_gpu migraphx_device compile_for_gpu
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -168,10 +169,11 @@ std::string make_transformer_args(std::vector<std::string> transformers) ...@@ -168,10 +169,11 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return join_strings(std::move(transformers), ", "); return join_strings(std::move(transformers), ", ");
} }
std::string generate_pointwise(const module& pm, const std::string& name) void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{ {
module m = pm; module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}}); run_passes(m,
{rewrite_quantization{}, eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g; cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
...@@ -184,8 +186,141 @@ std::string generate_pointwise(const module& pm, const std::string& name) ...@@ -184,8 +186,141 @@ std::string generate_pointwise(const module& pm, const std::string& name)
// Add explict conversions // Add explict conversions
g.fresult( g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; }); [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
g.create_function( gg.create_function(g.generate_module(m)
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name)); .set_attributes({"__device__", "__attribute__((const))"})
.set_generic_types(m)
.set_name(name));
}
std::string generate_pointwise(const module& pm, const std::string& name)
{
cpp_generator g;
generate_pointwise(g, pm, name);
return g.str();
}
std::string reduce_op::str() const
{
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
}
void reduce_op::set(instruction_ref ins, const operation& op)
{
if(op.name() == "reduce_sum")
{
reduction = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
auto s = ins->inputs().front()->get_shape();
auto reduce_elements = s.elements() / ins->get_shape().elements();
auto reduce_type = s.type();
reduction = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
read = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type}, reduce_type))
read = mean;
else
write = mean;
}
else if(op.name() == "reduce_max")
{
reduction = "op::max{}";
init = "lowest{}";
}
else if(op.name() == "reduce_min")
{
reduction = "op::min{}";
init = "highest{}";
}
else if(op.name() == "reduce_prod")
{
reduction = "op::product{}";
init = "1";
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
}
}
std::string reduce_op::generate(instruction_ref ins, const std::string& x)
{
reduce_op r{x};
r.set(ins, ins->get_operator());
return r.str();
}
static bool use_lazy_inner(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
auto output = ins->outputs().front();
return contains(output->name(), "reduce") or output->name() == "@return";
}
std::string generate_reduce(const module& m, const std::string& name)
{
cpp_generator g;
auto ilens = m.get_parameter_shapes().begin()->second.lens();
std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
{
return reduce_op::generate(ins, names.at(ins->inputs().front()));
}
else if(ins->name() == "pointwise")
{
auto pointwise_name = "pointwise" + std::to_string(i);
i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
std::vector<instruction_ref> tensors;
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(tensors),
[&](auto input) {
return input->get_shape().lens() == ilens and
not input->get_shape().broadcasted();
});
auto inner_names = names;
for(auto input : ins->inputs())
{
if(input->name() != "@param")
continue;
if(contains(tensors, input))
continue;
inner_names[input] += "[out_idx]";
}
for(auto input : tensors)
inner_names[input] += "_lambda_param";
auto call_function =
pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
if(tensors.empty())
return call_function;
const std::string inner_template =
"r.${inner}([=](${params}) { return ${call}; })(${args})";
std::string inner_name = use_lazy_inner(ins) ? "lazy_inner" : "inner";
auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names);
std::transform(
params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
return interpolate_string(inner_template,
{{"inner", inner_name},
{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
else if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
MIGRAPHX_THROW("Unknown operator: " + ins->name());
});
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
f.add_generic_param("r");
f.add_generic_param("out_idx");
f.unused_param("out_idx");
g.create_function(f);
return g.str(); return g.str();
} }
...@@ -196,7 +331,17 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -196,7 +331,17 @@ static std::vector<std::string> get_op_names(const module& m)
{ {
if(starts_with(ins.name(), "@")) if(starts_with(ins.name(), "@"))
continue; continue;
result.push_back(ins.name()); if(ins.name() == "multibroadcast")
continue;
if(ins.name() == "pointwise")
{
auto names = get_op_names(*ins.module_inputs().front());
result.insert(result.end(), names.begin(), names.end());
}
else
{
result.push_back(ins.name());
}
} }
return result; return result;
} }
......
...@@ -32,6 +32,13 @@ ...@@ -32,6 +32,13 @@
#ifdef MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h> #include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/value.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/file_buffer.hpp>
#else #else
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp> #include <migraphx/process.hpp>
...@@ -49,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC); ...@@ -49,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
std::string hiprtc_error(hiprtcResult err, const std::string& msg) std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{ {
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
...@@ -63,6 +67,7 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str ...@@ -63,6 +67,7 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
throw make_exception(ctx, hiprtc_error(err, msg)); throw make_exception(ctx, hiprtc_error(err, msg));
} }
// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \ #define MIGRAPHX_HIPRTC(...) \
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, MIGRAPHX_MAKE_SOURCE_CTX()) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, MIGRAPHX_MAKE_SOURCE_CTX())
...@@ -110,21 +115,19 @@ struct hiprtc_program ...@@ -110,21 +115,19 @@ struct hiprtc_program
std::string cpp_src = ""; std::string cpp_src = "";
std::string cpp_name = ""; std::string cpp_name = "";
hiprtc_program(const std::vector<src_file>& srcs) hiprtc_program(std::vector<hiprtc_src_file> srcs)
{ {
for(auto&& src : srcs) for(auto&& src : srcs)
{ {
std::string content{src.content.first, src.content.second}; if(ends_with(src.path, ".cpp"))
std::string path = src.path.string();
if(src.path.extension().string() == ".cpp")
{ {
cpp_src = std::move(content); cpp_src = std::move(src.content);
cpp_name = std::move(path); cpp_name = std::move(src.path);
} }
else else
{ {
headers.push_back(std::move(content)); headers.push_back(std::move(src.content));
include_names.push_back(std::move(path)); include_names.push_back(std::move(src.path));
} }
} }
prog = hiprtc_program_create(cpp_src.c_str(), prog = hiprtc_program_create(cpp_src.c_str(),
...@@ -134,7 +137,7 @@ struct hiprtc_program ...@@ -134,7 +137,7 @@ struct hiprtc_program
include_names.data()); include_names.data());
} }
void compile(const std::vector<std::string>& options) void compile(const std::vector<std::string>& options) const
{ {
if(enabled(MIGRAPHX_TRACE_HIPRTC{})) if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl; std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
...@@ -175,10 +178,11 @@ struct hiprtc_program ...@@ -175,10 +178,11 @@ struct hiprtc_program
} }
}; };
std::vector<std::vector<char>> std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) std::string params,
const std::string& arch)
{ {
hiprtc_program prog(srcs); hiprtc_program prog(std::move(srcs));
auto options = split_string(params, ' '); auto options = split_string(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1"); options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in // remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
...@@ -187,6 +191,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -187,6 +191,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
options.push_back("-DMIGRAPHX_HAS_DPP=0"); options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"); options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier"); options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-unused-parameter");
options.push_back("-Wno-gnu-line-marker"); options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast"); options.push_back("-Wno-old-style-cast");
} }
...@@ -205,12 +210,48 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -205,12 +210,48 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {prog.get_code_obj()}; return {prog.get_code_obj()};
} }
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
if(fs::exists(driver))
{
value v;
v["srcs"] = to_value(hsrcs);
v["params"] = to_value(params);
v["arch"] = to_value(arch);
tmp_dir td{};
auto out = td.path / "output";
process(driver.string() + " " + out.string()).write([&](auto writer) {
to_msgpack(v, writer);
});
if(fs::exists(out))
return {read_buffer(out.string())};
}
return compile_hip_src_with_hiprtc(std::move(hsrcs), std::move(params), arch);
}
#else // MIGRAPHX_USE_HIPRTC #else // MIGRAPHX_USE_HIPRTC
bool is_hcc_compiler() std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT
std::string, // NOLINT
const std::string&)
{ {
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "hcc"); MIGRAPHX_THROW("Not using hiprtc");
return result;
} }
bool is_hip_clang_compiler() bool is_hip_clang_compiler()
...@@ -236,7 +277,7 @@ std::vector<std::vector<char>> ...@@ -236,7 +277,7 @@ std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
assert(not srcs.empty()); assert(not srcs.empty());
if(not is_hcc_compiler() and not is_hip_clang_compiler()) if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " + MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER))); std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
...@@ -246,16 +287,9 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -246,16 +287,9 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{})) if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g"; params += " -g";
params += " -c"; params += " -c";
if(is_hcc_compiler()) params += " --offload-arch=" + arch;
{ params += " --cuda-device-only";
params += " -amdgpu-target=" + arch; params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
else if(is_hip_clang_compiler())
{
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG"; params += " -DMIGRAPHX_DEBUG";
...@@ -270,24 +304,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -270,24 +304,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(has_compiler_launcher()) if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER); compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif #endif
if(is_hcc_compiler())
compiler.process = [&](const fs::path& obj_path) -> fs::path {
process{MIGRAPHX_STRINGIZE(MIGRAPHX_EXTRACT_KERNEL) + std::string{" -i "} +
obj_path.string()}
.cwd(obj_path.parent_path());
for(const auto& entry : fs::directory_iterator{obj_path.parent_path()})
{
const auto& hsaco_path = entry.path();
if(not fs::is_regular_file(hsaco_path))
continue;
if(hsaco_path.extension() != ".hsaco")
continue;
return hsaco_path;
}
MIGRAPHX_THROW("Missing hsaco");
};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{ {
for(const auto& src : srcs) for(const auto& src : srcs)
......
...@@ -136,10 +136,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -136,10 +136,15 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t max_global = ctx.get_current_device().get_cu_count() * std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu(); ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) { return [n, over, max_global](std::size_t local) {
std::size_t groups = (n + local - 1) / local; std::size_t num_elements = n;
std::size_t max_blocks = max_global / local; std::size_t groups = (num_elements + local - 1) / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local; std::size_t max_blocks = max_global / local;
return std::min(nglobal, n); std::size_t nglobal = std::min(max_blocks * over, groups) * local;
#ifdef MIGRAPHX_USE_HIPRTC
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
num_elements = ((num_elements + local - 1) / local) * local;
#endif
return std::min(nglobal, num_elements);
}; };
} }
......
...@@ -112,14 +112,8 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) ...@@ -112,14 +112,8 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
#ifdef MIGRAPHX_USE_CLANG_TIDY #ifdef MIGRAPHX_USE_CLANG_TIDY
#define MIGRAPHX_DEVICE_SHARED #define MIGRAPHX_DEVICE_SHARED
#else #else
// Workaround hcc's broken tile_static macro
#ifdef tile_static
#undef tile_static
#define MIGRAPHX_DEVICE_SHARED __attribute__((tile_static))
#else
#define MIGRAPHX_DEVICE_SHARED __shared__ #define MIGRAPHX_DEVICE_SHARED __shared__
#endif #endif
#endif
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -94,6 +94,10 @@ template <> ...@@ -94,6 +94,10 @@ template <>
struct is_hip_type<std::uint8_t> : std::true_type struct is_hip_type<std::uint8_t> : std::true_type
{ {
}; };
template <>
struct is_hip_type<std::int32_t> : std::true_type
{
};
template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})> template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T as, V&& v) void hip_visitor_invoke(T as, V&& v)
...@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) ...@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if(not std::all_of( if(not std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); })) types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
std::initializer_list<index_int> ranks = { std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
static_cast<index_int>(get_shape(xs).lens().size())...}; if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { visit_tensor_size(s.ndim(), [&](auto ndim) {
s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); })); s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
}); });
} }
...@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) ...@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template <class V, class F, class... Ts> template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs) void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{ {
std::initializer_list<index_int> ranks = { std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
static_cast<index_int>(get_shape(xs).lens().size())...}; if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); }); visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
} }
template <class F> template <class F>
......
...@@ -37,22 +37,26 @@ argument scatter( ...@@ -37,22 +37,26 @@ argument scatter(
hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis) hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
{ {
auto ds = arg0.get_shape(); auto ds = arg0.get_shape();
auto inds = arg1.get_shape(); auto s1 = arg1.get_shape();
auto axis_dim_size = ds.lens()[axis]; auto axis_dim_size = ds.lens()[axis];
hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) { hip_visit_all(result, arg0, arg2)([&](auto output, auto data, auto update) {
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
const auto* data_ptr = device_cast(data.data()); const auto* data_ptr = device_cast(data.data());
gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; }); gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; });
hip_visit_all(arg1, arg2)([&](auto indices, auto update) {
const auto* upd_ptr = device_cast(update.data()); hip_visit_all(arg1)([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); if constexpr(indices.get_shape().lens.size() == output.get_shape().lens.size())
gs_launch(stream, inds.elements())([=](auto i) __device__ { {
auto out_idx = s1.multi(i); const auto* upd_ptr = device_cast(update.data());
auto index = indices_ptr[i]; const auto* indices_ptr = device_cast(indices.data());
index = index < 0 ? index + axis_dim_size : index; gs_launch(stream, s1.elements())([=](auto i) __device__ {
out_idx[axis] = index; auto out_idx = indices.get_shape().multi(i);
output[out_idx] = upd_ptr[i]; auto index = indices_ptr[i];
}); index = index < 0 ? index + axis_dim_size : index;
out_idx[axis] = index;
output[out_idx] = upd_ptr[i];
});
}
}); });
}); });
......
...@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ...@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
add_executable(gpu-driver add_executable(gpu-driver
${GPU_DRIVER_SRCS} ${GPU_DRIVER_SRCS}
) )
rocm_clang_tidy_check(gpu-driver)
target_include_directories(gpu-driver PRIVATE include) target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu) target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
...@@ -44,7 +44,7 @@ struct auto_register_action ...@@ -44,7 +44,7 @@ struct auto_register_action
template <class T> template <class T>
static void apply() static void apply()
{ {
auto name = get_type_name<T>(); const auto& name = get_type_name<T>();
register_action(name.substr(name.rfind("::") + 2), register_action(name.substr(name.rfind("::") + 2),
[](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); }); [](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); });
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -35,9 +36,34 @@ struct module; ...@@ -35,9 +36,34 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
struct mlir_conv const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
#else
return false;
#endif
}
#ifdef MIGRAPHX_MLIR
struct mlir_op
{ {
std::string name() const { return "gpu::mlir_op"; }
operation op = make_op("convolution"); operation op = make_op("convolution");
template <class Self, class F> template <class Self, class F>
...@@ -46,7 +72,6 @@ struct mlir_conv ...@@ -46,7 +72,6 @@ struct mlir_conv
return pack(f(self.op, "op")); return pack(f(self.op, "op"));
} }
std::string name() const { return "gpu::mlir_conv"; }
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{ {
check_shapes{inputs, *this}.packed_or_broadcasted(); check_shapes{inputs, *this}.packed_or_broadcasted();
...@@ -54,17 +79,50 @@ struct mlir_conv ...@@ -54,17 +79,50 @@ struct mlir_conv
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
return op.compute_shape({inputs[n - 2], inputs[n - 1]}); module_ref mod = mods[0];
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes;
size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end());
for(std::string param_name : names)
{
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
}
for(auto ins : iterator_for(*mod))
{
if(ins->name() == "@param")
{
continue;
}
if(ins->name() == "@literal")
{
ins_shapes[ins] = ins->get_shape();
continue;
}
if(ins->name() == "@return")
{
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
std::transform(ins->inputs().begin(),
ins->inputs().end(),
input_shapes.begin(),
[&](auto in) { return ins_shapes[in]; });
ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
}
MIGRAPHX_THROW("No return found in the submodule");
} }
}; };
MIGRAPHX_REGISTER_OP(mlir_conv); MIGRAPHX_REGISTER_OP(mlir_op);
namespace { namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{ {
if(ins->name() != "convolution") if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false; return false;
value v = ins->get_operator().to_value(); value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>(); auto group = v.at("group").to<int>();
...@@ -76,51 +134,107 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) ...@@ -76,51 +134,107 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return true; return true;
} }
struct find_conv_pointwise struct find_mlir_op
{ {
// Find a convolution followed by a pointwise operation.
auto matcher() const auto matcher() const
{ {
auto convolution = auto dot_or_conv = match::skip(match::name("contiguous"))(
match::skip(match::name("contiguous"))(is_mlir_conv().bind("convolution")); match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](convolution.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{
if(ins->name() != "@literal")
{
continue;
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
}
return ins_map;
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto conv_ins = r.instructions["convolution"]; auto gemm_based_op = r.instructions["gemm_based_op"];
auto x_ins = r.instructions["x"]; // input after contiguous auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
// Whitelist pointwise operators // Whitelist pointwise operators
if(std::any_of(pm->begin(), pm->end(), [](const auto& i) { if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
return not contains({"@literal", "@param", "@return", "convolution", "add", "relu"}, return not contains({"@literal",
"@param",
"@return",
"convolution",
"quant_convolution",
"dot",
"add",
"relu",
"dequantizelinear",
"quantizelinear",
"mul"},
i.name()); i.name());
})) }))
return; return;
// Only fuse with fp32/fp16 // Only fuse with fp32/fp16/int8/int32
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) { if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type, shape::type_t::half_type}, return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::int32_type},
i->get_shape().type()); i->get_shape().type());
})) }))
return; return;
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name()); module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass(); mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map; std::unordered_map<instruction_ref, instruction_ref> param_map =
auto x = mm->add_parameter("x" + std::to_string(names.size()), create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
conv_ins->inputs().at(0)->get_shape()); auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
conv_ins->inputs().at(1)->get_shape());
auto conv = mm->add_instruction(conv_ins->get_operator(), {x, w});
std::transform(names.begin(), std::transform(names.begin(),
names.end(), names.end(),
ins->inputs().begin(), ins->inputs().begin(),
std::inserter(param_map, param_map.end()), std::inserter(param_map, param_map.end()),
[&](auto name, auto input) { [&, &anchor_op = anchor_op](auto name, auto input) {
if(input == x_ins) if(input == x_ins)
return std::make_pair(pm->get_parameter(name), conv); return std::make_pair(pm->get_parameter(name), anchor_op);
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape())); mm->add_parameter(name, input->get_shape()));
}); });
...@@ -130,12 +244,13 @@ struct find_conv_pointwise ...@@ -130,12 +244,13 @@ struct find_conv_pointwise
std::copy_if(ins->inputs().begin(), std::copy_if(ins->inputs().begin(),
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(inputs), std::back_inserter(inputs),
[&](auto input) { return input != conv_ins; }); [&](auto input) { return input != gemm_based_op; });
inputs.insert(inputs.end(), conv_ins->inputs().begin(), conv_ins->inputs().end()); inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction( mpm.get_module().replace_instruction(
ins, mlir_conv{conv_ins->get_operator()}, inputs, {mm}); ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
} }
}; };
} // namespace } // namespace
#endif #endif
...@@ -143,7 +258,7 @@ struct find_conv_pointwise ...@@ -143,7 +258,7 @@ struct find_conv_pointwise
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_conv_pointwise{}); match::find_matches(mpm, find_mlir_op{});
#else #else
(void)mpm; (void)mpm;
#endif #endif
......
...@@ -157,13 +157,8 @@ void gemm_impl(context& ctx, ...@@ -157,13 +157,8 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag = rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none; int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
int flag = 0;
#endif
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
......
...@@ -189,8 +189,20 @@ argument register_on_gpu(const argument& arg) ...@@ -189,8 +189,20 @@ argument register_on_gpu(const argument& arg)
argument to_gpu(const argument& arg, bool host) argument to_gpu(const argument& arg, bool host)
{ {
auto p = write_to_gpu(arg.data(), arg.get_shape().bytes(), host); argument result;
return {arg.get_shape(), p}; arg.visit(
[&](auto x) {
auto p = write_to_gpu(arg.data(), arg.get_shape().bytes(), host);
result = {x.get_shape(), p};
},
[&](const auto& xs) {
std::vector<argument> args;
std::transform(xs.begin(), xs.end(), std::back_inserter(args), [&](auto x) {
return to_gpu(x, host);
});
result = argument{args};
});
return result;
} }
argument from_gpu(const argument& arg) argument from_gpu(const argument& arg)
......
...@@ -21,71 +21,13 @@ ...@@ -21,71 +21,13 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
project(migraphx-doc)
find_package(ROCM REQUIRED)
include(ROCMDoxygenDoc) add_executable(migraphx-hiprtc-driver
main.cpp
set(DOXYGEN_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/doxygen)
rocm_add_doxygen_doc(
OUTPUT_DIRECTORY ${DOXYGEN_OUTPUT}
INPUT
${CMAKE_SOURCE_DIR}/src
INCLUDE_PATH
${CMAKE_SOURCE_DIR}/src/include
${CMAKE_SOURCE_DIR}/src/targets/cpu/include
${CMAKE_SOURCE_DIR}/src/targets/gpu/include
STRIP_FROM_INC_PATH
${CMAKE_SOURCE_DIR}/src/include
${CMAKE_SOURCE_DIR}/src/targets/cpu/include
${CMAKE_SOURCE_DIR}/src/targets/gpu/include
EXCLUDE_PATTERNS
${CMAKE_SOURCE_DIR}/src/targets/gpu/kernels
${CMAKE_SOURCE_DIR}/src/targets/gpu/device
SEARCH_INCLUDES YES
MACRO_EXPANSION YES
RECURSIVE YES
GENERATE_XML YES
GENERATE_LATEX YES
USE_PDFLATEX YES
CALL_GRAPH YES
CALLER_GRAPH YES
BUILTIN_STL_SUPPORT YES
PROJECT_NAME MIGraphX
SORT_MEMBERS_CTORS_1ST YES
SOURCE_BROWSER YES
GENERATE_TREEVIEW YES
REFERENCED_BY_RELATION YES
REFERENCES_RELATION YES
REFERENCES_LINK_SOURCE YES
EXTRACT_ALL YES
ENUM_VALUES_PER_LINE 1
FULL_PATH_NAMES YES
WARN_LOGFILE "${DOXYGEN_OUTPUT}/DoxygenWarningLog.txt"
PREDEFINED DOXYGEN
) )
rocm_clang_tidy_check(migraphx-hiprtc-driver)
include(ROCMSphinxDoc) target_link_libraries(migraphx-hiprtc-driver PRIVATE migraphx_gpu)
rocm_add_sphinx_doc(src add_dependencies(migraphx_all_targets migraphx-hiprtc-driver)
BUILDER html rocm_install_targets(
OUTPUT_DIR html TARGETS migraphx-hiprtc-driver
VARS
breathe_projects.proj=${DOXYGEN_OUTPUT}/xml
breathe_default_project=proj
DEPENDS doxygen
) )
find_package(LATEX)
if(LATEX_FOUND)
rocm_add_sphinx_doc(src
BUILDER latex
OUTPUT_DIR pdf
VARS
breathe_projects.proj=${DOXYGEN_OUTPUT}/xml
breathe_default_project=proj
DEPENDS doxygen
)
else()
message("Latex builder not found. Latex builder is required only for building the PDF documentation for MIGraphX and is not necessary for building the library, or any other components. To build PDF documentation run make in ${CMAKE_CURRENT_SOURCE_DIR}/pdf, once a latex builder is installed.")
endif()
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/value.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp>
#include <iostream>
#include <cstring>
std::vector<char> read_stdin()
{
std::vector<char> result;
std::array<char, 1024> buffer;
std::size_t len = 0;
while((len = std::fread(buffer.data(), 1, buffer.size(), stdin)) > 0)
{
if(std::ferror(stdin) != 0 and std::feof(stdin) == 0)
MIGRAPHX_THROW(std::strerror(errno));
result.insert(result.end(), buffer.data(), buffer.data() + len);
}
return result;
}
int main(int argc, char const* argv[])
{
if(argc < 2 or migraphx::contains({"-h", "--help", "-v", "--version"}, std::string(argv[1])))
{
std::cout << "USAGE:" << std::endl;
std::cout << " ";
std::cout << "Used internally by migraphx to compile hip programs out-of-process."
<< std::endl;
std::exit(0);
}
std::string output_name = argv[1];
auto v = migraphx::from_msgpack(read_stdin());
std::vector<migraphx::gpu::hiprtc_src_file> srcs;
migraphx::from_value(v.at("srcs"), srcs);
auto out = migraphx::gpu::compile_hip_src_with_hiprtc(
std::move(srcs), v.at("params").to<std::string>(), v.at("arch").to<std::string>());
if(not out.empty())
migraphx::write_buffer(output_name, out.front());
}
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp> #include <migraphx/module_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -34,6 +35,7 @@ namespace migraphx { ...@@ -34,6 +35,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct shape; struct shape;
struct operation;
namespace gpu { namespace gpu {
...@@ -72,8 +74,23 @@ std::string make_transformer_args(Ts... xs) ...@@ -72,8 +74,23 @@ std::string make_transformer_args(Ts... xs)
std::string generate_pointwise(const module& pm, const std::string& name); std::string generate_pointwise(const module& pm, const std::string& name);
std::string generate_reduce(const module& m, const std::string& name);
std::string generate_name_from_ops(const module& m); std::string generate_name_from_ops(const module& m);
struct reduce_op
{
std::string input = "";
std::string reduction = "";
std::string init = "0";
std::string read = "op::id{}";
std::string write = "op::id{}";
void set(instruction_ref ins, const operation& op);
std::string str() const;
static std::string generate(instruction_ref ins, const std::string& x);
};
} // namespace gen } // namespace gen
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -27,6 +27,8 @@ ...@@ -27,6 +27,8 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/functional.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -35,6 +37,31 @@ namespace migraphx { ...@@ -35,6 +37,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
#endif
struct hiprtc_src_file
{
hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s)
: path(s.path.string()), content(s.content.first, s.content.second)
{
}
std::string path;
std::string content;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.path, "path"), f(self.content, "content"));
}
};
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
std::string params,
const std::string& arch);
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch); compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment