Commit 8d32c6b8 authored by Paul

Merge branch 'develop' into blas_tuning

parents 23cb7917 f25606f9
...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes) ...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max}; dds_it->max};
} }
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/** /**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if` * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those * and `loop` instructions, depending on how the submodules for those
...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens}); dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return // redirect to select_module operator and return
......
...@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op> ...@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
} }
}; };
template <class F>
struct execute_wrapper
{
F f;
argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};
template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
return {std::move(f)};
}
template <class Derived, class Primitive> template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived> struct dnnl_op : auto_register_op<Derived>
{ {
...@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
#ifndef NDEBUG #ifndef NDEBUG
auto prim_attr = get_primitive_attr(md); auto prim_attr = get_primitive_attr(md);
#endif #endif
execute = [=](context&, const std::vector<argument>& args) { execute = make_execute_wrapper([=](const std::vector<argument>& args) {
#ifndef NDEBUG #ifndef NDEBUG
// Check that the memory descriptors have not changed // Check that the memory descriptors have not changed
auto debug_args = args; auto debug_args = args;
...@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]); m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m); prim.execute(get_dnnl_context().stream, m);
return args.back(); return args.back();
}; });
} }
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{ {
......
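The execute_wrapper/make_execute_wrapper addition above lets the stored DNNL execute callable be a small named functor, so the capturing lambda no longer has to accept the unused context& parameter itself. A minimal, self-contained sketch of the same pattern (the context/argument types here are illustrative stand-ins, not the MIGraphX definitions):

#include <functional>
#include <vector>

// Stand-ins for the real MIGraphX types, for illustration only.
struct context {};
struct argument { int value = 0; };

// Adapts a callable taking only the argument list to the
// (context&, args) signature expected by the stored callable.
template <class F>
struct execute_wrapper
{
    F f;
    argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};

template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
    return {std::move(f)};
}

int main()
{
    std::function<argument(context&, const std::vector<argument>&)> execute =
        make_execute_wrapper([](const std::vector<argument>& args) { return args.back(); });
    context ctx;
    std::vector<argument> args(3);
    execute(ctx, args); // the wrapped lambda never sees the context
}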
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP #ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP #define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/config.hpp> #include <migraphx/cpu/context.hpp>
#include <string> #include <string>
namespace migraphx { namespace migraphx {
...@@ -34,9 +34,7 @@ struct module; ...@@ -34,9 +34,7 @@ struct module;
namespace cpu { namespace cpu {
struct context; struct MIGRAPHX_CPU_EXPORT fuse_ops
struct fuse_ops
{ {
context* ctx = nullptr; context* ctx = nullptr;
std::string name() const { return "cpu::fuse_ops"; } std::string name() const { return "cpu::fuse_ops"; }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -23,6 +23,10 @@ ...@@ -23,6 +23,10 @@
# #################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm) list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS; please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your GPU.")
endif()
find_package(miopen) find_package(miopen)
# rocblas # rocblas
...@@ -44,12 +48,21 @@ else() ...@@ -44,12 +48,21 @@ else()
set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs") set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif() endif()
include(Embed)
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(migraphx_device ${DEVICE_GPU_SRCS})
...@@ -69,6 +82,7 @@ rocm_clang_tidy_check(migraphx_device) ...@@ -69,6 +82,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device PUBLIC migraphx) target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu) target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>) target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes) target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device) migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
...@@ -123,6 +137,7 @@ add_library(migraphx_gpu ...@@ -123,6 +137,7 @@ add_library(migraphx_gpu
lrn.cpp lrn.cpp
mlir.cpp mlir.cpp
multinomial.cpp multinomial.cpp
no_device.cpp
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp pack_int8_args.cpp
...@@ -191,7 +206,7 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp ...@@ -191,7 +206,7 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR ON CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
# Find package rocMLIR # Find package rocMLIR
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include <deque>
#ifdef MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h> #include <hip/hiprtc.h>
...@@ -92,7 +93,7 @@ struct hiprtc_program ...@@ -92,7 +93,7 @@ struct hiprtc_program
{ {
struct string_array struct string_array
{ {
std::vector<std::string> strings{}; std::deque<std::string> strings{};
std::vector<const char*> c_strs{}; std::vector<const char*> c_strs{};
string_array() {} string_array() {}
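Presumably the switch from std::vector<std::string> to std::deque<std::string> just above is about pointer stability: c_strs caches raw const char* pointers into the owned strings, and deque::push_back never invalidates references to existing elements, whereas a growing vector may reallocate and leave earlier pointers dangling. A small stand-alone illustration of that property:

#include <cassert>
#include <deque>
#include <string>
#include <vector>

int main()
{
    std::deque<std::string> strings; // owns the data
    std::vector<const char*> c_strs; // raw pointers handed to a C API

    for(int i = 0; i < 1000; i++)
    {
        strings.push_back("header_" + std::to_string(i));
        // deque::push_back never invalidates references to existing
        // elements, so pointers captured earlier stay valid; with a
        // std::vector this loop could reallocate and dangle them.
        c_strs.push_back(strings.back().c_str());
    }
    assert(std::string(c_strs.front()) == "header_0");
}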
...@@ -115,6 +116,12 @@ struct hiprtc_program ...@@ -115,6 +116,12 @@ struct hiprtc_program
std::string cpp_src = ""; std::string cpp_src = "";
std::string cpp_name = ""; std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<hiprtc_src_file> srcs) hiprtc_program(std::vector<hiprtc_src_file> srcs)
{ {
for(auto&& src : srcs) for(auto&& src : srcs)
...@@ -130,6 +137,14 @@ struct hiprtc_program ...@@ -130,6 +137,14 @@ struct hiprtc_program
include_names.push_back(std::move(src.path)); include_names.push_back(std::move(src.path));
} }
} }
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(), prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(), cpp_name.c_str(),
headers.size(), headers.size(),
...@@ -137,7 +152,7 @@ struct hiprtc_program ...@@ -137,7 +152,7 @@ struct hiprtc_program
include_names.data()); include_names.data());
} }
void compile(const std::vector<std::string>& options) const void compile(const std::vector<std::string>& options, bool quiet = false) const
{ {
if(enabled(MIGRAPHX_TRACE_HIPRTC{})) if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl; std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
...@@ -148,7 +163,7 @@ struct hiprtc_program ...@@ -148,7 +163,7 @@ struct hiprtc_program
[](const std::string& s) { return s.c_str(); }); [](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log(); auto prog_log = log();
if(not prog_log.empty()) if(not prog_log.empty() and not quiet)
{ {
std::cerr << prog_log << std::endl; std::cerr << prog_log << std::endl;
} }
...@@ -195,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -195,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options.push_back("-Wno-gnu-line-marker"); options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast"); options.push_back("-Wno-old-style-cast");
} }
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG"); options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) { if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return {prog.get_code_obj()}; return {prog.get_code_obj()};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
...@@ -220,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -220,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{ {
if(src.path.extension() != ".cpp") if(src.path.extension() != ".cpp")
continue; continue;
std::cout << std::string(src.content.first, src.len()) << std::endl; std::cout << std::string(src.content) << std::endl;
} }
} }
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc); auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
...@@ -310,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -310,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{ {
if(src.path.extension() != ".cpp") if(src.path.extension() != ".cpp")
continue; continue;
std::cout << std::string(src.content.first, src.len()) << std::endl; std::cout << std::string(src.content) << std::endl;
} }
} }
...@@ -323,6 +351,27 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -323,6 +351,27 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)}; return {compiler.compile(srcs)};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src;
src_file input{"main.cpp", src};
try
{
compiler.compile({input});
return true;
}
catch(...)
{
return false;
}
}
#endif // MIGRAPHX_USE_HIPRTC #endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param) std::string enum_params(std::size_t count, std::string param)
......
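Both hip_has_flags implementations added above use the same feature-probe idea: compile a trivial source with the candidate flags and report whether compilation succeeded, so callers can enable a flag only when the installed compiler understands it. A generic sketch of the pattern, with a dummy try_compile standing in for the hiprtc or clang-driver path (names here are illustrative):

#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>

// Dummy compile step for illustration; the real code invokes hiprtc or the
// HIP compiler driver and throws on failure. Here we pretend only "-Wall" exists.
static void try_compile(const std::string&, const std::vector<std::string>& flags)
{
    if(std::any_of(flags.begin(), flags.end(),
                   [](const std::string& f) { return f != "-Wall"; }))
        throw std::runtime_error("unknown flag");
}

static bool has_flags(const std::vector<std::string>& flags)
{
    try
    {
        try_compile(" ", flags); // trivial source; only the flags matter
        return true;
    }
    catch(...)
    {
        return false;
    }
}

int main() { return has_flags({"-Wall"}) && !has_flags({"-Wfoo"}) ? 0 : 1; }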
...@@ -91,28 +91,39 @@ __content__ ...@@ -91,28 +91,39 @@ __content__
return replace_string(args_hpp, "__content__", inner); return replace_string(args_hpp, "__content__", inner);
} }
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings() const std::vector<std::string>& compiler_warnings()
{ {
static std::vector<std::string> warnings = {"-Weverything", static std::vector<std::string> warnings = get_compiler_warnings();
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions"};
return warnings; return warnings;
} }
...@@ -161,21 +172,17 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -161,21 +172,17 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert(options.inputs.size() == options.virtual_inputs.size() or assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty()); options.virtual_inputs.empty());
std::vector<src_file> srcs = options.additional_src_files; std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(), static auto kernels{::migraphx_kernels()};
migraphx_kernels().end(), std::transform(
std::back_inserter(srcs), kernels.begin(),
[](auto&& p) { kernels.end(),
auto&& name = p.first; std::back_inserter(srcs),
auto&& c = p.second; [](const std::pair<std::string_view, std::string_view>& elem) { return src_file{elem}; });
auto path = name; srcs.emplace_back("main.cpp", content);
return src_file{path, c};
});
srcs.push_back(src_file{fs::path{"main.cpp"},
std::make_pair(content.data(), content.data() + content.size())});
auto args_hpp = auto args_hpp =
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs); generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.push_back(src_file{fs::path{"args.hpp"}, srcs.emplace_back("args.hpp", args_hpp);
std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
......
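The compiler_warnings() refactor above leans on the function-local static idiom: get_compiler_warnings(), which may run the hip_has_flags compile probe, is evaluated once on first call (thread-safe since C++11) and the result is reused on every subsequent call. A small sketch of the idiom, with an illustrative expensive_probe standing in for the flag check:

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for a costly, environment-dependent check
// such as the hip_has_flags() probe.
static bool expensive_probe()
{
    std::cout << "probe runs once\n";
    return true;
}

static std::vector<std::string> build_warnings()
{
    std::vector<std::string> warnings = {"-Weverything", "-Wno-padded"};
    if(expensive_probe())
        warnings.push_back("-Wno-unsafe-buffer-usage");
    return warnings;
}

const std::vector<std::string>& compiler_warnings()
{
    // Initialized exactly once, on first use, even with concurrent callers.
    static std::vector<std::string> warnings = build_warnings();
    return warnings;
}

int main()
{
    compiler_warnings();
    compiler_warnings(); // "probe runs once" is printed only once
}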
...@@ -185,8 +185,7 @@ struct compile_plan ...@@ -185,8 +185,7 @@ struct compile_plan
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) { results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
if(not cr.has_value()) if(not cr.has_value())
return std::numeric_limits<double>::max(); return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20) return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
.first;
}); });
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl; std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
......
...@@ -26,7 +26,9 @@ ...@@ -26,7 +26,9 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/targets.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -79,13 +81,28 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) ...@@ -79,13 +81,28 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
using f_type = decltype(f); using f_type = decltype(f);
dim3 nblocks(global / local); dim3 nblocks(global / local);
dim3 nthreads(local); dim3 nthreads(local);
/*
hipGetLastError() returns the error code of a HIP call that failed earlier.
MIGraphX calls into various backend libraries, where failed HIP calls can also happen.
Calling hipGetLastError() here resets the error code to hipSuccess, so that a
failed hipLaunchKernelGGL() call inside MIGraphX can be captured below.
*/
hipError_t flush_call = hipGetLastError();
(void)(flush_call);
// cppcheck-suppress UseDeviceLaunch // cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f); hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
hipError_t kernel_launch_status = hipGetLastError(); hipError_t kernel_launch_status = hipGetLastError();
if(kernel_launch_status != hipSuccess) if(kernel_launch_status != hipSuccess)
{ {
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " + std::string message = hipGetErrorString(kernel_launch_status);
std::string(hipGetErrorString(kernel_launch_status))); if(not contains(get_targets(), get_device_name()))
{
message += ". Trying to run a kernel for " + get_device_name() +
" but MIGraphX was built for targets " + get_targets_as_string() +
". Please rebuild MIGraphX with -DGPU_TARGETS='" + get_device_name() +
"'.";
}
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " + message);
} }
}; };
} }
......
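The flush-then-check sequence above is a common HIP idiom: the first hipGetLastError() clears any error left behind by earlier calls (possibly made inside other libraries), so the second call reports only a failure of this hipLaunchKernelGGL(). A minimal sketch of the idiom, assuming a trivial kernel, with error handling reduced to a message:

#include <hip/hip_runtime.h>
#include <iostream>

__global__ void noop() {}

void checked_launch(hipStream_t stream)
{
    // Clear any error left over from earlier HIP calls (possibly made by
    // other libraries), so the check below only reflects this launch.
    (void)hipGetLastError();

    hipLaunchKernelGGL(noop, dim3(1), dim3(1), 0, stream);

    hipError_t status = hipGetLastError();
    if(status != hipSuccess)
        std::cerr << "kernel launch failed: " << hipGetErrorString(status) << "\n";
}

int main()
{
    checked_launch(nullptr); // default stream
}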
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/device/targets.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
static std::vector<std::string> parse_targets() { return split_string(MIGRAPHX_GPU_TARGETS, ';'); }
const std::vector<std::string>& get_targets()
{
static auto result = parse_targets();
return result;
}
std::string get_targets_as_string() { return join_strings(get_targets(), ", "); }
static int get_device_id()
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
MIGRAPHX_THROW("No device");
return device;
}
std::string get_device_name()
{
hipDeviceProp_t props{};
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
return props.gcnArchName;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/gpu/device/config.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name();
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
...@@ -38,10 +38,8 @@ struct compile_op : action<compile_op> ...@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context ctx; context ctx;
auto inputs = p.parse_shapes(v.at("inputs")); auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v); auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms"; std::cout << op << ": " << t << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
std::cout << std::endl; std::cout << std::endl;
} }
}; };
......
...@@ -43,8 +43,8 @@ struct run_op : action<run_op> ...@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms" << std::endl; std::cout << op << ": " << t << "ms" << std::endl;
} }
}; };
......
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/fuse_ck.hpp> #include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
...@@ -55,7 +55,7 @@ struct ck_gemm ...@@ -55,7 +55,7 @@ struct ck_gemm
{ {
check_shapes{inputs, *this}.same_ndims(); check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0]; auto a = inputs[0];
auto b = inputs[1]; auto b = inputs[1];
for(const auto& input : inputs) for(const auto& input : inputs)
...@@ -65,21 +65,27 @@ struct ck_gemm ...@@ -65,21 +65,27 @@ struct ck_gemm
return r; return r;
return r.with_type(mods.front()->get_output_shapes().front().type()); return r.with_type(mods.front()->get_output_shapes().front().type());
} }
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
namespace { struct ck_gemm_softmax_gemm : gemm_softmax_gemm
bool is_ck_supported_type(shape::type_t t)
{ {
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t); std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
} };
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{ {
if(ins->name() != "dot" and ins->name() != "quant_dot") if(ins->name() != "dot" and ins->name() != "quant_dot")
return false; return false;
if(not is_ck_supported_type(ins->get_shape().type())) if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
...@@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise ...@@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise
ins->get_shape().type() != gemm_ins->get_shape().type()) ins->get_shape().type() != gemm_ins->get_shape().type())
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) { if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type()); return not ck_gemm::is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
})) }))
return; return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
...@@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise ...@@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm struct find_ck_gemm
{ {
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); } auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
...@@ -161,11 +171,26 @@ struct find_ck_gemm ...@@ -161,11 +171,26 @@ struct find_ck_gemm
} }
}; };
struct find_ck_gemm_softmax_gemm
{
auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto v = ins->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
}
};
} // namespace } // namespace
void fuse_ck::apply(module_pass_manager& mpm) const void fuse_ck::apply(module_pass_manager& mpm) const
{ {
match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{}); match::find_matches(mpm, find_ck_gemm{});
} }
......
...@@ -36,24 +36,14 @@ struct module; ...@@ -36,24 +36,14 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
bool mlir_enabled() bool mlir_enabled()
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{}); const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
if(mlir_enabled) return not mlir_disabled;
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
#else #else
return false; return false;
#endif #endif
...@@ -103,7 +93,10 @@ struct mlir_op ...@@ -103,7 +93,10 @@ struct mlir_op
} }
if(ins->name() == "@return") if(ins->name() == "@return")
{ {
return ins_shapes[ins->inputs().at(0)].with_type(type); auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
} }
std::vector<shape> input_shapes; std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size()); input_shapes.resize(ins->inputs().size());
...@@ -119,28 +112,107 @@ struct mlir_op ...@@ -119,28 +112,107 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op); MIGRAPHX_REGISTER_OP(mlir_op);
namespace { namespace {
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) enum class mlir_mode
{ {
if(ins->name() != "convolution" and ins->name() != "quant_convolution") all,
return false; fast,
value v = ins->get_operator().to_value(); int8,
auto group = v.at("group").to<int>(); none
if(group != 1) };
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!" auto is_mlir_dot(mlir_mode mode)
if(ins->get_shape().lens().size() != 4) {
return false; return match::make_basic_pred_matcher([=](instruction_ref ins) {
return true; if(mode == mlir_mode::none)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
});
} }
struct find_mlir_op auto is_mlir_conv(mlir_mode mode)
{ {
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
return true;
if(w.lens()[2] != w.lens()[3])
return true;
return (w.lens()[3] % 3) != 0;
});
}
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const auto matcher() const
{ {
auto dot_or_conv = match::skip(match::name("contiguous"))( auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv()) match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
} }
...@@ -163,34 +235,6 @@ struct find_mlir_op ...@@ -163,34 +235,6 @@ struct find_mlir_op
return ins_map; return ins_map;
} }
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
// Whitelist supported fusion options, including imposing type constraints // Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function) // for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types. // on particular types.
...@@ -236,8 +280,7 @@ struct find_mlir_op ...@@ -236,8 +280,7 @@ struct find_mlir_op
"log", "log",
"recip", "recip",
"rsqrt", "rsqrt",
// There are bugs in MLIR right now for models using sigmoid so disable it for now "sigmoid",
// "sigmoid",
"softmax", "softmax",
"tanh", "tanh",
}; };
...@@ -282,9 +325,9 @@ struct find_mlir_op ...@@ -282,9 +325,9 @@ struct find_mlir_op
names.end(), names.end(),
ins->inputs().begin(), ins->inputs().begin(),
std::inserter(param_map, param_map.end()), std::inserter(param_map, param_map.end()),
[&, &anchor_op = anchor_op](auto name, auto input) { [&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins) if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor_op); return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape())); mm->add_parameter(name, input->get_shape()));
}); });
...@@ -301,20 +344,90 @@ struct find_mlir_op ...@@ -301,20 +344,90 @@ struct find_mlir_op
} }
}; };
template <auto Matcher>
struct find_mlir_standalone_op
{
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
}
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
/**
 * @brief Declares a new MIGraphX environment variable that forces MIGraphX to
 * generate only specific operations with MLIR.
 *
 * If defined, the variable forces MIGraphX to use only the listed operations
 * with MLIR, regardless of the underlying GPU architecture. It accepts a
 * comma-separated list of operations and recognizes the following values:
 * "fused", "convolution", "dot". If the variable is not defined, MIGraphX
 * decides by itself which operations to delegate to MLIR. The variable is
 * intended primarily for rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_requested(std::string_view option, bool fallback = false)
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
} // namespace } // namespace
#endif #endif // MIGRAPHX_MLIR
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_op{}); const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
const bool is_navi = starts_with(device_name, "gfx110");
auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
if(is_requested(option))
return mlir_mode::all;
if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};
mlir_mode mode =
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else #else
(void)mpm; (void)mpm;
#endif #endif
} }
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
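For example, running with MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,dot makes is_requested("fused") and is_requested("dot") return true regardless of architecture. A toy, std-only sketch of the comma-separated check (the real code uses the MIGraphX helpers string_value_of, split_string and contains; the names below are illustrative):

#include <algorithm>
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

static bool option_requested(const std::string& list, const std::string& option)
{
    std::vector<std::string> options;
    std::stringstream ss(list);
    for(std::string item; std::getline(ss, item, ',');)
        options.push_back(item);
    return std::find(options.begin(), options.end(), option) != options.end();
}

int main()
{
    // e.g. MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,dot
    assert(option_requested("fused,dot", "dot"));
    assert(not option_requested("fused,dot", "convolution"));
}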
...@@ -790,22 +790,26 @@ struct find_layernorm_pointwise ...@@ -790,22 +790,26 @@ struct find_layernorm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")(match::arg(0)( return precompile_name("pointwise")(match::any_of[match::inputs()](
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm"))); precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto pw_ins = r.result;
auto layernorm = r.instructions["layernorm"]; auto layernorm = r.instructions["layernorm"];
if(not layernorm->module_inputs().empty()) if(not layernorm->module_inputs().empty())
return; return;
auto* pm = ins->module_inputs().front(); auto* pm = pw_ins->module_inputs().front();
auto pw_inputs = pw_ins->inputs();
auto ln_pos = std::find(pw_inputs.begin(), pw_inputs.end(), layernorm);
assert(ln_pos != pw_inputs.end());
pw_inputs.erase(ln_pos);
auto inputs = layernorm->inputs(); auto inputs = layernorm->inputs();
inputs.pop_back(); inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end()); inputs.insert(inputs.end(), pw_inputs.begin(), pw_inputs.end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm}); m.replace_instruction(pw_ins, layernorm->get_operator(), inputs, {pm});
} }
}; };
......
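Because the matcher above now accepts the layernorm at any input position of the pointwise instruction (match::any_of[match::inputs()] instead of match::arg(0)), apply() first removes the layernorm from the pointwise input list before splicing the remaining inputs onto the layernorm's inputs. A toy sketch of that list surgery, with plain integers standing in for instruction_refs and the last layernorm input assumed to be the output allocation:

#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
    // 42 plays the role of the layernorm instruction among the pointwise inputs.
    std::vector<int> pw_inputs = {7, 42, 9};
    std::vector<int> ln_inputs = {1, 2, 3};

    // Drop the layernorm itself from the pointwise inputs...
    auto ln_pos = std::find(pw_inputs.begin(), pw_inputs.end(), 42);
    assert(ln_pos != pw_inputs.end());
    pw_inputs.erase(ln_pos);

    // ...drop the layernorm's trailing input (presumably the output allocation)...
    ln_inputs.pop_back();

    // ...and append the remaining pointwise inputs to the layernorm inputs.
    ln_inputs.insert(ln_inputs.end(), pw_inputs.begin(), pw_inputs.end());
    assert((ln_inputs == std::vector<int>{1, 2, 7, 9}));
}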
...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr) ...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr)
auto status = hipPointerGetAttributes(&attr, ptr); auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess) if(status != hipSuccess)
return false; return false;
return attr.memoryType == hipMemoryTypeDevice; return attr.type == hipMemoryTypeDevice;
} }
std::size_t get_available_gpu_memory() std::size_t get_available_gpu_memory()
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp> #include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <array>
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(header_strings.begin(),
header_strings.end(),
std::back_inserter(srcs),
[&](auto& p) { return src_file{p}; });
return srcs;
}
static inline const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
inline bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
inline ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}
inline std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
inline void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}
inline void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}
inline bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
inline bool can_fold_batch(const std::vector<shape>& inputs)
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
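As a concrete illustration of fold_batch_dims above: for a non-transposed operand with lens {B, M, N} = {6, 64, 32}, the batch count is 6 and the shape folds to {384, 32}; a transposed operand would instead fold to {64, 192}. A minimal stand-alone computation of the non-transposed case (plain C++, no MIGraphX types; a sketch under the assumption that the batch dimensions are contiguous):

#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Fold all leading (batch) dimensions of `lens` into the row dimension,
// mirroring the non-transposed branch of fold_batch_dims above.
std::vector<std::size_t> fold_batch(const std::vector<std::size_t>& lens)
{
    if(lens.size() <= 2)
        return lens;
    std::size_t batch = std::accumulate(
        lens.begin(), lens.end() - 2, std::size_t{1}, std::multiplies<std::size_t>{});
    return {batch * lens[lens.size() - 2], lens.back()};
}

int main()
{
    assert((fold_batch({6, 64, 32}) == std::vector<std::size_t>{384, 32}));
}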