"vscode:/vscode.git/clone" did not exist on "240cbda06c309bf4ba51ef808c8b075b7ae3d818"
Unverified Commit 40fbef9b authored by Ted Themistokleous, committed by GitHub

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
@@ -113,8 +113,7 @@ void subgraph::apply(module_pass_manager& mpm) const
// TODO(varunsh): this code may be replaceable by code in the fuse_pointwise pass
// assuming all FPGA instructions are in one contiguous range
-pm->insert_instructions(pm->end(), first, last, {});
+pm->insert_instructions(pm->end(), first, std::next(last), {});
migraphx::instruction_ref placeholder_ins;
for(auto it : iterator_for(mod))
{
...
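Note on the `std::next(last)` fix: `insert_instructions` consumes a half-open range, while `last` here points at the final instruction to be moved, so the old call silently dropped that last instruction. A minimal standalone sketch of the same off-by-one, using only the standard library (not the MIGraphX API):

```cpp
#include <iostream>
#include <iterator>
#include <list>

int main()
{
    std::list<int> src{1, 2, 3, 4};
    auto first = src.begin();
    auto last  = std::prev(src.end()); // points AT the final element (inclusive bound)

    // Copying [first, last) would drop the element `last` points to;
    // std::next(last) turns the inclusive bound into a half-open end.
    std::list<int> dst(first, std::next(last));
    std::cout << dst.size() << '\n'; // prints 4, not 3
}
```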
@@ -33,7 +33,7 @@ migraphx::shape x_model::get_shape() const { return shape; };
void x_model::set_shape(migraphx::shape s) { shape = s; }
-x_model create_xmodel(const migraphx::module_ref mod)
+x_model create_xmodel(migraphx::const_module_ref mod)
{
std::cout << "Calling an external function: create_xmodel!\n";
x_model xmodel;
...
@@ -33,6 +33,11 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
+if(NOT WIN32)
+# TODO: re-enable when CK is ported to Windows
+find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
+endif()
if(BUILD_DEV)
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
else()

@@ -40,12 +45,12 @@ else()
endif()
include(Embed)
-file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
+file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
-add_embed_library(migraphx_kernels ${KERNEL_FILES})
+add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
-file(GLOB DEVICE_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
+file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS})
add_library(compile_for_gpu INTERFACE)

@@ -65,6 +70,8 @@ target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
+target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
+migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
add_library(kernel_file_check EXCLUDE_FROM_ALL)

@@ -80,7 +87,13 @@ target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check(kernel_file_check)
-file(GLOB JIT_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
+file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
+if(WIN32)
+# TODO: re-enable when CK is ported to Windows
+list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
+endif()
add_library(migraphx_gpu
abs.cpp
analyze_streams.cpp

@@ -95,6 +108,7 @@ add_library(migraphx_gpu
compile_miopen.cpp
compiler.cpp
device_name.cpp
+fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp

@@ -123,11 +137,14 @@ add_library(migraphx_gpu
schedule_model.cpp
sync_device.cpp
target.cpp
+time_op.cpp
topk.cpp
write_literals.cpp
${JIT_GPU_SRCS}
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
+migraphx_generate_export_header(migraphx_gpu)
function(register_migraphx_gpu_ops PREFIX)
foreach(OP ${ARGN})

@@ -169,7 +186,7 @@ register_op(migraphx_gpu
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
-OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
+OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::convolution_backwards> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu)

@@ -181,7 +198,9 @@ if(MIGRAPHX_ENABLE_MLIR)
find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}")
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
-target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
+# Make this private to avoid multiple inclusions of LLVM symbols.
+# TODO: Fix rocMLIR's library to hide LLVM internals.
+target_link_libraries(migraphx_gpu PRIVATE rocMLIR::rockCompiler)
endif()
if(MIGRAPHX_USE_HIPRTC)

@@ -227,7 +246,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
if(MIGRAPHX_USE_FIND_2_API)
-target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
+check_library_exists(MIOpen "miopenSetFindOptionPreallocatedTensor" "${MIOPEN_LOCATION}" HAS_PREALLOCATION_API)
+if(HAS_PREALLOCATION_API)
+target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API -DMIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS)
+else()
+target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
+endif()
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else()
message(STATUS "MIGraphx is using legacy Find API in MIOpen")

@@ -242,6 +266,10 @@ endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
+if(NOT WIN32)
+# TODO: re-enable when CK is ported to Windows
+target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
+endif()
add_subdirectory(driver)
add_subdirectory(hiprtc)
...
@@ -29,6 +29,7 @@
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>

@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{
module m = pm;
-run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
+run_passes(m,
+{rewrite_quantization{}, eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");

@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
not input->get_shape().broadcasted();
});
auto inner_names = names;
+for(auto input : ins->inputs())
+{
+if(input->name() != "@param")
+continue;
+if(contains(tensors, input))
+continue;
+inner_names[input] += "[out_idx]";
+}
for(auto input : tensors)
inner_names[input] += "_lambda_param";
auto call_function =

@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
});
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
f.add_generic_param("r");
+f.add_generic_param("out_idx");
+f.unused_param("out_idx");
g.create_function(f);
return g.str();
}
...
@@ -56,9 +56,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#ifdef MIGRAPHX_USE_HIPRTC
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));

@@ -194,6 +191,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
+options.push_back("-Wno-unused-parameter");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}

@@ -216,6 +214,15 @@ std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
+if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
+{
+for(const auto& src : srcs)
+{
+if(src.path.extension() != ".cpp")
+continue;
+std::cout << std::string(src.content.first, src.len()) << std::endl;
+}
+}
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
...
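Note: the new block above prints every `.cpp` source handed to hipRTC when `MIGRAPHX_GPU_DUMP_SRC` is set, which makes generated kernels easy to inspect. A rough standalone sketch of the same env-var-gated dump, with a plain `getenv` check standing in for MIGraphX's `enabled()` helper and a simplified `src_file` (both are assumptions, not the real types):

```cpp
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

struct src_file
{
    std::string path;
    std::string content;
};

// Print only the generated .cpp sources when the debug env var is set.
void maybe_dump(const std::vector<src_file>& srcs)
{
    if(std::getenv("MIGRAPHX_GPU_DUMP_SRC") == nullptr)
        return;
    for(const auto& src : srcs)
    {
        if(src.path.size() < 4 or src.path.substr(src.path.size() - 4) != ".cpp")
            continue;
        std::cout << src.content << std::endl;
    }
}

int main()
{
    maybe_dump({{"main.cpp", "__global__ void f() {}"}, {"kernels.hpp", "// skipped"}});
}
```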
@@ -135,10 +135,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) {
-std::size_t groups = (n + local - 1) / local;
-std::size_t max_blocks = max_global / local;
-std::size_t nglobal = std::min(max_blocks * over, groups) * local;
-return std::min(nglobal, n);
+// HIP requires the global workitem count to be a multiple of the local count; this may degrade performance.
+// [TODO]: consider adding the "fno-hip-uniform-block" flag when it becomes available.
+// https://reviews.llvm.org/D155213
+std::size_t num_elements = ((n + local - 1) / local) * local;
+std::size_t groups = (num_elements + local - 1) / local;
+std::size_t max_blocks = max_global / local;
+std::size_t nglobal = std::min(max_blocks * over, groups) * local;
+return std::min(nglobal, num_elements);
};
}
@@ -156,14 +160,14 @@ operation compile_hip_code_object(const std::string& content, hip_compile_options
assert(not options.inputs.empty());
assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty());
-std::vector<src_file> srcs;
+std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(),
std::back_inserter(srcs),
[](auto&& p) {
auto&& name = p.first;
auto&& c = p.second;
-auto path = fs::path{"migraphx"} / "kernels" / name;
+auto path = name;
return src_file{path, c};
});
srcs.push_back(src_file{fs::path{"main.cpp"},
...
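Note on the `compute_global_for` change earlier in this file: the launch size is now rounded up to a multiple of `local`, since HIP assumes uniform workgroups, so up to `local - 1` padding workitems may be launched (the performance caveat in the comment). A standalone sketch of the arithmetic with illustrative numbers (not taken from the commit):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// Mirrors the updated lambda: round n up to a multiple of local,
// then cap the global size at max_blocks * over blocks.
std::size_t nglobal_for(std::size_t n, std::size_t over, std::size_t max_global, std::size_t local)
{
    std::size_t num_elements = ((n + local - 1) / local) * local;
    std::size_t groups       = (num_elements + local - 1) / local;
    std::size_t max_blocks   = max_global / local;
    std::size_t nglobal      = std::min(max_blocks * over, groups) * local;
    return std::min(nglobal, num_elements);
}

int main()
{
    // n = 1000, local = 256: num_elements = 1024 (4 full blocks), so the
    // launch is padded from 1000 to 1024 workitems.
    std::cout << nglobal_for(1000, 1, 1 << 20, 256) << '\n'; // 1024
}
```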
@@ -79,7 +79,7 @@ void compile_miopen::apply(module& m) const
std::size_t ws = 0;
try
{
-// for the regular convolution and deconvolution, this try would always succeed
+// for the regular convolution and convolution_backwards, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
...
@@ -30,6 +30,7 @@
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compiler.hpp>
+#include <migraphx/gpu/time_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

@@ -76,33 +77,201 @@ struct compiled_result
instruction_ref ins;
};
struct problem_cache
{
bool has(const std::string& name, const value& problem) const
{
return contains(cache, create_key(name, problem));
}
void insert(const std::string& name, const value& problem, const value& solution)
{
assert(not solution.is_null());
cache[create_key(name, problem)] = solution;
}
void mark(const std::string& name, const value& problem)
{
cache.insert(std::make_pair(create_key(name, problem), value{}));
}
optional<value> get(const std::string& name, const value& problem) const
{
auto it = cache.find(create_key(name, problem));
if(it == cache.end())
return nullopt;
return it->second;
}
static value create_key(const std::string& name, const value& problem)
{
return {{"name", name}, {"problem", problem}};
}
std::unordered_map<value, value> cache;
};
struct compile_plan
{
context* ctx;
operation preop;
instruction_ref ins;
optional<tuning_config> config = nullopt;
std::vector<optional<compiled_result>> results = {};
void update_config(bool exhaustive)
{
config = get_tuning_config(*ctx, ins, preop, exhaustive);
}
template <class Vector>
void insert_compiles(Vector& compiles, const value& solution, std::size_t i)
{
compiles.emplace_back([=] {
try
{
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
}
catch(...)
{
results[i] = nullopt;
}
});
}
template <class Vector>
void add_compiles(Vector& compiles, problem_cache& pc)
{
if(config.has_value())
{
const auto& problem = config->problem;
if(auto sol = pc.get(preop.name(), problem))
{
auto solution = sol.value();
// No solution yet until benchmarked so skip for now
if(solution.is_null())
return;
results.resize(1);
insert_compiles(compiles, solution, 0);
}
else
{
pc.mark(preop.name(), problem);
const auto& solutions = config->solutions;
results.resize(solutions.size());
for(auto i : range(solutions.size()))
{
auto solution = solutions[i];
insert_compiles(compiles, solution, i);
}
}
}
else
{
results.resize(1);
insert_compiles(compiles, value{}, 0);
}
}
const compiled_result& benchmark(problem_cache& pc) const
{
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
{
if(not results.front().has_value())
MIGRAPHX_THROW("No configs to tune");
return *results.front();
}
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
std::vector<double> times;
times.reserve(results.size());
std::transform(
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
if(not cr.has_value())
return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
.first;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation.");
return *results[i];
}
void replace(module& m, problem_cache& pc) const
{
const auto& cr = benchmark(pc);
cr.replace.replace(m, cr.ins);
}
};
template <class F>
void par_compile(std::size_t n, F f)
{
if(n == 0)
return;
-par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
+auto d = value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{});
+if(d == 0)
+d = n;
+par_for(n, n / d, f);
}
struct compile_manager
{
problem_cache pc;
std::vector<compile_plan> cps;
bool exhaustive = false;
template <class... Ts>
void add_plan(Ts&&... xs)
{
cps.push_back({std::forward<Ts>(xs)...});
}
void update_configs()
{
par_compile(cps.size(), [&](auto i) { cps[i].update_config(exhaustive); });
}
void compile(module& m)
{
std::vector<std::function<void()>> compiles;
for(auto& cp : cps)
{
cp.add_compiles(compiles, pc);
}
par_compile(compiles.size(), [&](auto i) { compiles[i](); });
// Replace and/or benchmark
for(const auto& cp : cps)
{
if(cp.results.empty())
continue;
cp.replace(m, pc);
}
// Remove compile_plan already executed
cps.erase(std::remove_if(cps.begin(),
cps.end(),
[](const auto& cp) { return not cp.results.empty(); }),
cps.end());
}
};
void compile_ops::apply(module& m) const
{
-std::vector<std::function<compiled_result()>> compiles;
+compile_manager cm;
+cm.exhaustive = exhaustive_tune;
+// Find all precompile ops
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::precompile_op")
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
-compiles.emplace_back([=]() -> compiled_result {
-return {compile(*ctx, ins, preop), ins};
-});
+cm.add_plan(ctx, preop, ins);
}
-std::vector<compiled_result> results(compiles.size());
-par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
-for(const auto& cr : results)
-{
-cr.replace(m, cr.ins);
-}
+cm.update_configs();
+cm.compile(m);
+// Compile already tuned configs
+cm.compile(m);
+assert(cm.cps.empty());
}
} // namespace gpu
...
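Note on the tuning flow above: `problem_cache` uses a null `value` as a sentinel. `mark` records that a problem has been seen (benchmark pending) while `insert` later stores the winning solution, which is why `compile_ops::apply` runs `cm.compile(m)` twice: once to benchmark, then once more to compile the now-cached solutions. A simplified illustration of the mark-then-insert protocol, with an empty string playing the null sentinel (toy types, not MIGraphX's `value`):

```cpp
#include <cassert>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

struct toy_problem_cache
{
    // emplace does not overwrite, so re-marking never clobbers a stored solution
    void mark(const std::string& key) { cache.emplace(key, ""); }
    void insert(const std::string& key, const std::string& solution)
    {
        assert(not solution.empty());
        cache[key] = solution;
    }
    std::optional<std::string> get(const std::string& key) const
    {
        auto it = cache.find(key);
        if(it == cache.end())
            return std::nullopt;
        return it->second;
    }
    std::unordered_map<std::string, std::string> cache;
};

int main()
{
    toy_problem_cache pc;
    pc.mark("gemm/64x64x2048"); // first pass: benchmark still pending
    assert(pc.get("gemm/64x64x2048")->empty());
    pc.insert("gemm/64x64x2048", "tile_128x64"); // after benchmarking
    std::cout << *pc.get("gemm/64x64x2048") << '\n';
}
```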
@@ -28,33 +28,45 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
-auto& compiler_map()
-{
-static std::unordered_map<std::string, compiler_compile> m; // NOLINT
-return m;
-}
+namespace {
+struct compiler_handle
+{
+compiler_compile compile;
+compiler_compile_op compile_op;
+compiler_tuning_config get_tuning_config;
+};
+} // namespace
-auto& compiler_op_map()
+auto& compiler_map()
{
-static std::unordered_map<std::string, compiler_compile_op> m; // NOLINT
+static std::unordered_map<std::string, compiler_handle> m; // NOLINT
return m;
}
-void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop)
+void register_compiler(const std::string& name,
+compiler_compile c,
+compiler_compile_op cop,
+compiler_tuning_config ctg)
{
-compiler_map()[name] = std::move(c);
-compiler_op_map()[name] = std::move(cop);
+compiler_map()[name] = {std::move(c), std::move(cop), std::move(ctg)};
}
bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }
-compiler_replace compile(context& ctx, instruction_ref ins, const operation& op)
+compiler_replace
+compile(context& ctx, instruction_ref ins, const operation& op, const value& solution)
{
-return compiler_map().at(op.name())(ctx, ins, op);
+return compiler_map().at(op.name()).compile(ctx, ins, op, solution);
}
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
{
-return compiler_op_map().at(name)(ctx, inputs, v);
+return compiler_map().at(name).compile_op(ctx, inputs, v);
}
+optional<tuning_config>
+get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive)
+{
+return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op, exhaustive);
+}
} // namespace gpu
...
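Note: folding the three callbacks into one `compiler_handle` keeps a single registry map, so entries registered under a name can never drift out of sync the way the old parallel maps could. A stripped-down sketch of the same registry pattern (toy callback signatures, not the real `compiler_compile` types):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// One handle bundles every callback registered under a name,
// so lookup happens once and the entries cannot disagree.
struct handle
{
    std::function<std::string()> compile;
    std::function<std::string()> tuning_config;
};

std::unordered_map<std::string, handle>& registry()
{
    static std::unordered_map<std::string, handle> m; // NOLINT
    return m;
}

void register_handle(const std::string& name, handle h) { registry()[name] = std::move(h); }

int main()
{
    register_handle("pointwise",
                    {[] { return std::string{"compiled"}; },
                     [] { return std::string{"configs"}; }});
    std::cout << registry().at("pointwise").compile() << '\n';
}
```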
@@ -94,6 +94,10 @@ template <>
struct is_hip_type<std::uint8_t> : std::true_type
{
};
+template <>
+struct is_hip_type<std::int32_t> : std::true_type
+{
+};
template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T as, V&& v)

@@ -120,12 +124,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if(not std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
-std::initializer_list<index_int> ranks = {
-static_cast<index_int>(get_shape(xs).lens().size())...};
-if(not std::all_of(
-ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
MIGRAPHX_THROW("Ranks must be the same");
-visit_tensor_size(s.lens().size(), [&](auto ndim) {
+visit_tensor_size(s.ndim(), [&](auto ndim) {
s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
});
}

@@ -133,12 +135,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
-std::initializer_list<index_int> ranks = {
-static_cast<index_int>(get_shape(xs).lens().size())...};
-if(not std::all_of(
-ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
MIGRAPHX_THROW("Ranks must be the same");
-visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
+visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
}
template <class F>
...
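Note: both visitors now spell rank as `s.ndim()` instead of `s.lens().size()`. The check itself expands the parameter pack into an `initializer_list` of ranks and requires each to equal the lead shape's rank; the same variadic trick in isolation (toy `shape`, not the MIGraphX one):

```cpp
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <vector>

struct shape
{
    std::vector<std::size_t> lens;
    std::size_t ndim() const { return lens.size(); }
};

// Expand the pack into an initializer_list of ranks, then require all equal.
template <class... Shapes>
bool same_rank(const shape& s, const Shapes&... xs)
{
    std::initializer_list<std::size_t> ranks = {xs.ndim()...};
    return std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.ndim(); });
}

int main()
{
    shape a{{2, 3}}, b{{4, 5}}, c{{2, 3, 1}};
    std::cout << same_rank(a, b) << ' ' << same_rank(a, b, c) << '\n'; // 1 0
}
```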
@@ -67,18 +67,19 @@ void multinomial(hipStream_t stream,
size_t class_size = arg0.get_shape().lens().back();
size_t sample_size = result.get_shape().lens().back();
-hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) {
-result.visit([&](auto out) {
-hip_visit_views(out)([&](auto output) {
-gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
-auto idx = output.get_shape().multi(i);
-auto cdf_begin = cdf.begin() + (idx.front() * class_size);
-auto cdf_end = cdf_begin + class_size;
-auto sample_iter =
-upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
-output[i] = std::distance(cdf_begin, sample_iter);
-});
+visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
+result.visit([&](auto output_host) {
+hip_visit_views(cdf_host, dist_host, output_host)(
+[&](auto cdf, auto dist, auto output) {
+gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
+auto idx = output.get_shape().multi(i);
+auto cdf_begin = cdf.begin() + (idx.front() * class_size);
+auto cdf_end = cdf_begin + class_size;
+auto* sample_iter =
+upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
+output[i] = std::distance(cdf_begin, sample_iter);
+});
+});
});
});
}
...
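Note on the multinomial kernel: a sample is drawn by scaling a uniform value by the row's total mass (the last CDF entry) and binary-searching the unnormalized CDF with `upper_bound`. The same logic on the host, with illustrative values:

```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

// Map a uniform u in [0, 1) to a class index via the unnormalized CDF.
int sample_class(const std::vector<float>& cdf, float u)
{
    float total = cdf.back(); // plays the role of *(std::prev(cdf_end))
    auto it     = std::upper_bound(cdf.begin(), cdf.end(), u * total);
    return std::distance(cdf.begin(), it);
}

int main()
{
    std::vector<float> cdf{0.1f, 0.4f, 1.0f, 2.0f}; // class weights 0.1, 0.3, 0.6, 1.0
    std::cout << sample_class(cdf, 0.0f) << ' '     // 0
              << sample_class(cdf, 0.3f) << ' '     // 2 (0.3 * 2.0 = 0.6 > 0.4)
              << sample_class(cdf, 0.9f) << '\n';   // 3 (1.8 > 1.0)
}
```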
@@ -37,22 +37,26 @@ argument scatter(
hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
{
auto ds = arg0.get_shape();
-auto inds = arg1.get_shape();
+auto s1 = arg1.get_shape();
auto axis_dim_size = ds.lens()[axis];
-hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
+hip_visit_all(result, arg0, arg2)([&](auto output, auto data, auto update) {
auto* output_ptr = device_cast(output.data());
const auto* data_ptr = device_cast(data.data());
gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; });
-hip_visit_all(arg1, arg2)([&](auto indices, auto update) {
-const auto* upd_ptr = device_cast(update.data());
-const auto* indices_ptr = device_cast(indices.data());
-gs_launch(stream, inds.elements())([=](auto i) __device__ {
-auto out_idx = s1.multi(i);
-auto index = indices_ptr[i];
-index = index < 0 ? index + axis_dim_size : index;
-out_idx[axis] = index;
-output[out_idx] = upd_ptr[i];
-});
+hip_visit_all(arg1)([&](auto indices) {
+if constexpr(indices.get_shape().lens.size() == output.get_shape().lens.size())
+{
+const auto* upd_ptr = device_cast(update.data());
+const auto* indices_ptr = device_cast(indices.data());
+gs_launch(stream, s1.elements())([=](auto i) __device__ {
+auto out_idx = indices.get_shape().multi(i);
+auto index = indices_ptr[i];
+index = index < 0 ? index + axis_dim_size : index;
+out_idx[axis] = index;
+output[out_idx] = upd_ptr[i];
+});
+}
});
});
...
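Note on the scatter kernel: every update is written through its index along `axis`, and negative indices wrap by adding `axis_dim_size`, exactly as in the device lambda. A host-side sketch of that indexing rule for the 1-D case (illustrative only):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Copy data to the output, then overwrite the positions named by indices
// (negative indices count back from the end, as in the kernel).
std::vector<float> scatter_1d(std::vector<float> data,
                              const std::vector<int64_t>& indices,
                              const std::vector<float>& updates)
{
    auto axis_dim_size = static_cast<int64_t>(data.size());
    for(std::size_t i = 0; i < indices.size(); ++i)
    {
        auto index  = indices[i];
        index       = index < 0 ? index + axis_dim_size : index;
        data[index] = updates[i];
    }
    return data;
}

int main()
{
    auto out = scatter_1d({0, 0, 0, 0}, {1, -1}, {5.0f, 7.0f});
    for(auto v : out)
        std::cout << v << ' '; // 0 5 0 7
}
```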
@@ -43,6 +43,8 @@ auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
return std::string(props.gcnArchName);
}
+std::string get_arch_name(const hipDeviceProp_t& props) { return get_arch_name(rank<1>{}, props); }
int get_device_id()
{
int device;

@@ -58,7 +60,7 @@ std::string get_device_name()
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
-return get_arch_name(rank<1>{}, props);
+return get_arch_name(props);
}
} // namespace gpu
...
@@ -22,7 +22,7 @@
# THE SOFTWARE.
#####################################################################################
-file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
+file(GLOB GPU_DRIVER_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(gpu-driver
${GPU_DRIVER_SRCS}
)
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
-#include <migraphx/gpu/driver/perf.hpp>
+#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
...
@@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
-#include <migraphx/gpu/driver/perf.hpp>
+#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/make_op.hpp>
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct ck_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
check_gemm_shape(input);
auto r = op.compute_shape({a, b});
if(mods.empty())
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
namespace {
bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back();
auto k = a.lens().back();
// Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{
if(m % 4 != 0)
return false;
if(n % 4 != 0)
return false;
if(k % 4 != 0)
return false;
}
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return k <= 2048;
}
struct find_ck_gemm_pointwise
{
// Find a gemm followed by a pointwise operation.
auto matcher() const
{
auto gemm = match::skip(match::name("contiguous"))(
match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
if(gemm_ins->get_shape().type() != shape::int32_type and
ins->get_shape().type() != gemm_ins->get_shape().type())
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type());
}))
return;
assert(gemm_it != inputs.end());
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
auto new_first_param =
pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
pm->remove_instruction(gemm_param);
}
inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
mpm.get_module().replace_instruction(ins, ck_gemm{gemm_ins->get_operator()}, inputs, {pm});
}
};
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
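Note on `check_gemm_shape` in the new pass: it accepts a shape only when one of the last three strides equals 1, i.e. the tensor is packed along a GEMM-relevant dimension, which CK's kernels expect. A standalone check along those lines, using plain std algorithms in place of MIGraphX's `range`/`contains` utilities (hypothetical helper):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// True when one of the (up to) last three strides is unit, meaning a
// GEMM-relevant dimension is contiguous in memory.
bool ck_compatible_strides(const std::vector<std::size_t>& strides)
{
    auto n = std::min<std::size_t>(3, strides.size());
    return std::find(strides.rbegin(), strides.rbegin() + n, std::size_t{1}) !=
           strides.rbegin() + n;
}

int main()
{
    std::cout << ck_compatible_strides({2048, 64, 1}) << ' '         // 1: packed row-major
              << ck_compatible_strides({8192, 2048, 64, 2}) << '\n'; // 0: padded innermost dims
}
```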
@@ -38,6 +38,27 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
#else
return false;
#endif
}
#ifdef MIGRAPHX_MLIR
struct mlir_op

@@ -58,8 +79,41 @@ struct mlir_op
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
-auto n = inputs.size();
-return op.compute_shape({inputs[n - 2], inputs[n - 1]});
+module_ref mod = mods[0];
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes;
size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end());
for(std::string param_name : names)
{
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
}
for(auto ins : iterator_for(*mod))
{
if(ins->name() == "@param")
{
continue;
}
if(ins->name() == "@literal")
{
ins_shapes[ins] = ins->get_shape();
continue;
}
if(ins->name() == "@return")
{
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
std::transform(ins->inputs().begin(),
ins->inputs().end(),
input_shapes.begin(),
[&](auto in) { return ins_shapes[in]; });
ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
}
MIGRAPHX_THROW("No return found in the submodule");
}
};
MIGRAPHX_REGISTER_OP(mlir_op);

@@ -68,7 +122,7 @@ namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
-if(ins->name() != "convolution")
+if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
@@ -85,10 +139,121 @@ struct find_mlir_op
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
-match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op"));
+match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
+.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{
if(ins->name() != "@literal")
{
continue;
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
}
return ins_map;
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
return false;
}
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {"convolution",
"quant_convolution",
"dot",
"quant_dot",
"add",
"clip",
"relu",
"sub",
"mul",
"div",
"pow",
"where",
"quantizelinear",
"dequantizelinear",
"abs",
"neg"};
const std::initializer_list<std::string> fp_only_ops = {"ceil",
"erf",
"exp",
"floor",
"log",
"recip",
"rsqrt",
"sigmoid"
"softmax",
"tanh"};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type && contains(no_bool_ops, name))
return true;
if(is_float && contains(fp_only_ops, name))
return true;
// Only conversions between floating types are known to be unambiguously
// supported.
if(is_float && name == "convert")
{
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
}
return false;
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;

@@ -96,35 +261,25 @@
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
-// Whitelist pointwise operators
-if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
-return not contains(
-{"@literal", "@param", "@return", "convolution", "dot", "add", "relu"},
-i.name());
-}))
-return;
-// Only fuse with fp32/fp16
-if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
-return not contains({shape::type_t::float_type, shape::type_t::half_type},
-i->get_shape().type());
-}))
-return;
+// Whitelist pointwise operators.
+if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
+return not is_pointwise_op_supported_by_mlir(i);
+}))
+return;
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
-std::unordered_map<instruction_ref, instruction_ref> param_map;
-auto x = mm->add_parameter("x" + std::to_string(names.size()),
-gemm_based_op->inputs().at(0)->get_shape());
-auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
-gemm_based_op->inputs().at(1)->get_shape());
-auto conv = mm->add_instruction(gemm_based_op->get_operator(), {x, w});
+std::unordered_map<instruction_ref, instruction_ref> param_map =
+create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
+auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
-[&](auto name, auto input) {
+[&, &anchor_op = anchor_op](auto name, auto input) {
if(input == x_ins)
-return std::make_pair(pm->get_parameter(name), conv);
+return std::make_pair(pm->get_parameter(name), anchor_op);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});

@@ -135,7 +290,7 @@ struct find_mlir_op
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != gemm_based_op; });
-inputs.insert(inputs.end(), gemm_based_op->inputs().begin(), gemm_based_op->inputs().end());
+inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction(
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
}

@@ -148,17 +303,7 @@
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
-const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
-if(mlir_enabled)
-{
-match::find_matches(mpm, find_mlir_op{});
-}
-else
-{
-std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
-"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
-<< std::endl;
-}
+match::find_matches(mpm, find_mlir_op{});
#else
(void)mpm;
#endif
...
@@ -165,7 +165,8 @@ struct fusion
const std::unordered_set<std::string>& get_supported_archs()
{
-static std::unordered_set<std::string> supported_archs{"gfx900", "gfx906", "gfx908", "gfx1030"};
+static std::unordered_set<std::string> supported_archs{
+"gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"};
return supported_archs;
}
...
@@ -140,12 +140,10 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r;
}
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
-rocblas_gemm_flags flag =
-int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
-#else
-(void)int8_x4_format;
-int flag = 0;
+rocblas_gemm_flags flag = rocblas_gemm_flags_none;
+#if ROCBLAS_VERSION_MAJOR < 3
+if(int8_x4_format)
+flag = rocblas_gemm_flags_pack_int8x4;
#endif
auto a_lens = args[0].get_shape().lens();
...
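Note: the rewritten block defaults to `rocblas_gemm_flags_none` and only opts into int8x4 packing on rocBLAS releases before 3.0, where the packed-int8 format still exists (an assumption about rocBLAS versioning, inferred from the guard itself). The shape of the version gate in isolation, with stand-in macros and enum values rather than the real rocBLAS headers:

```cpp
#include <iostream>

// Stand-ins for the rocBLAS version macro and flag values.
#define ROCBLAS_VERSION_MAJOR 3
enum rocblas_gemm_flags { rocblas_gemm_flags_none = 0, rocblas_gemm_flags_pack_int8x4 = 1 };

int main()
{
    bool int8_x4_format     = true;
    rocblas_gemm_flags flag = rocblas_gemm_flags_none;
#if ROCBLAS_VERSION_MAJOR < 3
    if(int8_x4_format)
        flag = rocblas_gemm_flags_pack_int8x4;
#else
    (void)int8_x4_format; // packed int8x4 no longer exists on this rocBLAS
#endif
    std::cout << flag << '\n'; // 0 when the major version is >= 3
}
```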