Commit 722cf052 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

merge with uif2-initial branch

parent 2476750c
......@@ -81,7 +81,7 @@ cmake-build*/
build*/
# Recommended location to install rbuild dependencies from README.md
depend/
depend*/
# Python virtual environment
.venv/
......@@ -41,9 +41,10 @@ if(NOT MIGRAPHX_GENERATOR_IS_MULTI_CONFIG)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS ${CMAKE_CONFIGURATION_TYPES})
endif()
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib")
if(NOT WIN32)
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib")
endif()
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/llvm $ENV{ROCM_PATH} $ENV{HIP_PATH})
......@@ -59,15 +60,15 @@ else()
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
endif()
# By default build shared libraries
option(BUILD_SHARED_LIBS "Create shared libraries" ON)
if(WIN32) # CK is not yet ported to Windows
option(MIGRAPHX_USE_COMPOSABLEKERNEL "Enable MIGraphX to use composable kernel JIT library" OFF)
else()
option(MIGRAPHX_USE_COMPOSABLEKERNEL "Enable MIGraphX to use composable kernel JIT library" ON)
endif()
# By default build shared libraries
option(BUILD_SHARED_LIBS "Create shared libraries" ON)
if(WIN32)
add_compile_definitions($<$<COMPILE_LANGUAGE:C,CXX>:_CRT_SECURE_NO_WARNINGS>)
add_subdirectory(extern)
......@@ -111,9 +112,17 @@ set(MIGRAPHX_ENABLE_CPU Off CACHE BOOL "")
# Disable fpga backend by default
set(MIGRAPHX_ENABLE_FPGA Off CACHE BOOL "")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
if(WIN32)
add_compile_definitions(
$<$<COMPILE_LANGUAGE:C,CXX>:_CRT_SECURE_NO_WARNINGS>
$<$<COMPILE_LANGUAGE:C,CXX>:_USE_MATH_DEFINES>)
endif()
if(MSVC)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/std:c++17>)
else()
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-std=c++17>)
endif()
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings)
......
......@@ -21,10 +21,13 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
find_program(EMBED_LD ld)
find_program(EMBED_OBJCOPY objcopy)
option(EMBED_USE_LD "Use ld to embed data files" OFF)
option(EMBED_USE_BINARY "Use data file embedding to binary" ON)
if(EMBED_USE_BINARY AND NOT WIN32)
find_program(EMBED_LD ld REQUIRED)
find_program(EMBED_OBJCOPY objcopy REQUIRED)
endif()
function(wrap_string)
set(options)
......@@ -53,40 +56,76 @@ function(wrap_string)
set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE)
endfunction()
function(generate_embed_source EMBED_NAME SRC_FILE HEADER_FILE BASE_DIRECTORY)
function(generate_embed_source EMBED_NAME EMBED_DIR BASE_DIRECTORY)
set(options)
set(oneValueArgs "")
set(oneValueArgs)
set(multiValueArgs SYMBOLS FILES)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(RESOURCE_ID 100)
foreach(SYMBOL FILE IN ZIP_LISTS PARSE_SYMBOLS PARSE_FILES)
set(START_SYMBOL "_binary_${SYMBOL}_start")
set(LENGTH_SYMBOL "_binary_${SYMBOL}_length")
if(EMBED_USE_LD)
string(APPEND EXTERNS "
cmake_path(RELATIVE_PATH FILE BASE_DIRECTORY ${BASE_DIRECTORY} OUTPUT_VARIABLE BASE_NAME)
if(EMBED_USE_BINARY AND WIN32)
string(TOUPPER "${SYMBOL}" SYMBOL)
string(APPEND FILE_IDS "#define IDR_${SYMBOL} ${RESOURCE_ID}\n")
string(APPEND RC_MAPPING "IDR_${SYMBOL} TEXTFILE \"${BASE_NAME}\"\n")
string(APPEND INIT_KERNELS " {\"${BASE_NAME}\", resource::read(IDR_${SYMBOL})},\n")
math(EXPR RESOURCE_ID "${RESOURCE_ID} + 1" OUTPUT_FORMAT DECIMAL)
else()
set(START_SYMBOL "_binary_${SYMBOL}_start")
set(LENGTH_SYMBOL "_binary_${SYMBOL}_length")
if(EMBED_USE_BINARY)
string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t _binary_${SYMBOL}_size;
const auto ${LENGTH_SYMBOL} = reinterpret_cast<size_t>(&_binary_${SYMBOL}_size);
")
else()
string(APPEND EXTERNS "
")
else()
string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t ${LENGTH_SYMBOL};
")
")
endif()
string(APPEND INIT_KERNELS "
{ \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL}} },")
endif()
cmake_path(RELATIVE_PATH FILE BASE_DIRECTORY ${BASE_DIRECTORY} OUTPUT_VARIABLE BASE_NAME)
string(APPEND INIT_KERNELS "
{ \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL}} },")
endforeach()
if(EMBED_USE_BINARY AND WIN32)
file(WRITE "${EMBED_DIR}/include/resource.h" "
#define TEXTFILE 256
file(WRITE "${HEADER_FILE}" "
${FILE_IDS}
")
# Emit the Windows resource script. The per-file "IDR_<SYMBOL> TEXTFILE <path>"
# entries are accumulated in RC_MAPPING by the foreach loop above, so that is the
# variable that must be expanded here (expanding an unset variable would yield a
# resource script with no resources in it).
file(WRITE "${EMBED_DIR}/resource.rc" "
#include \"resource.h\"
${RC_MAPPING}
")
set(EXTERNS "
#include <Windows.h>
#include \"resource.h\"
namespace resource {
std::string_view read(int id)
{
HMODULE handle = GetModuleHandle(nullptr);
HRSRC rc = FindResource(handle, MAKEINTRESOURCE(id), MAKEINTRESOURCE(TEXTFILE));
HGLOBAL data = LoadResource(handle, rc);
return {static_cast<const char*>(LockResource(data)), SizeofResource(handle, rc)};
}
}
")
set(EMBED_FILES ${EMBED_DIR}/include/resource.h ${EMBED_DIR}/resource.rc)
endif()
file(WRITE "${EMBED_DIR}/include/${EMBED_NAME}.hpp" "
#include <string_view>
#include <unordered_map>
#include <utility>
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}();
")
file(WRITE "${SRC_FILE}" "
file(WRITE "${EMBED_DIR}/${EMBED_NAME}.cpp" "
#include <${EMBED_NAME}.hpp>
${EXTERNS}
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}()
......@@ -95,23 +134,28 @@ std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}()
return result;
}
")
list(APPEND EMBED_FILES ${EMBED_DIR}/${EMBED_NAME}.cpp ${EMBED_DIR}/include/${EMBED_NAME}.hpp)
set(EMBED_FILES ${EMBED_FILES} PARENT_SCOPE)
endfunction()
function(embed_file FILE BASE_DIRECTORY)
message(STATUS " ${FILE}")
cmake_path(RELATIVE_PATH FILE BASE_DIRECTORY ${BASE_DIRECTORY} OUTPUT_VARIABLE REL_FILE)
cmake_path(RELATIVE_PATH FILE BASE_DIRECTORY "${BASE_DIRECTORY}" OUTPUT_VARIABLE REL_FILE)
string(MAKE_C_IDENTIFIER "${REL_FILE}" OUTPUT_SYMBOL)
get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}")
if(EMBED_USE_LD)
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o")
add_custom_command(
OUTPUT ${OUTPUT_FILE}
COMMAND ${EMBED_LD} -r -o "${OUTPUT_FILE}" -z noexecstack --format=binary "${REL_FILE}"
COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUTPUT_FILE}"
WORKING_DIRECTORY ${BASE_DIRECTORY}
DEPENDS ${FILE}
VERBATIM)
if(EMBED_USE_BINARY)
if(NOT WIN32)
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o")
add_custom_command(
OUTPUT "${OUTPUT_FILE}"
COMMAND ${EMBED_LD} -r -o "${OUTPUT_FILE}" -z noexecstack --format=binary "${REL_FILE}"
COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUTPUT_FILE}"
WORKING_DIRECTORY "${BASE_DIRECTORY}"
DEPENDS "${FILE}"
VERBATIM)
set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE)
endif()
else()
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp")
# reads source file contents as hex string
......@@ -127,40 +171,38 @@ function(embed_file FILE BASE_DIRECTORY)
extern const char _binary_${OUTPUT_SYMBOL}_start[] = { ${ARRAY_VALUES} };
extern const size_t _binary_${OUTPUT_SYMBOL}_length = sizeof(_binary_${OUTPUT_SYMBOL}_start);
")
set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE)
endif()
set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE)
set(OUTPUT_SYMBOL ${OUTPUT_SYMBOL} PARENT_SCOPE)
endfunction()
function(add_embed_library EMBED_NAME)
set(options)
set(oneValueArgs BASE_DIRECTORY)
set(oneValueArgs RELATIVE)
set(multiValueArgs)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
file(MAKE_DIRECTORY ${EMBED_DIR})
set(SRC_FILE "${EMBED_DIR}/${EMBED_NAME}.cpp")
set(HEADER_FILE "${EMBED_DIR}/include/${EMBED_NAME}.hpp")
message(STATUS "Embedding kernel files:")
foreach(FILE ${PARSE_UNPARSED_ARGUMENTS})
embed_file(${FILE} ${PARSE_BASE_DIRECTORY})
embed_file(${FILE} ${PARSE_RELATIVE})
list(APPEND OUTPUT_FILES ${OUTPUT_FILE})
list(APPEND SYMBOLS ${OUTPUT_SYMBOL})
endforeach()
message(STATUS "Generating embedding library '${EMBED_NAME}'")
generate_embed_source(${EMBED_NAME} ${SRC_FILE} ${HEADER_FILE} "${PARSE_BASE_DIRECTORY}" SYMBOLS ${SYMBOLS} FILES ${PARSE_UNPARSED_ARGUMENTS})
add_library(embed_lib_${EMBED_NAME} OBJECT ${SRC_FILE} ${HEADER_FILE})
if(NOT EMBED_USE_LD)
target_sources(embed_lib_${EMBED_NAME} PRIVATE ${OUTPUT_FILES})
generate_embed_source(${EMBED_NAME} ${EMBED_DIR} "${PARSE_RELATIVE}" SYMBOLS ${SYMBOLS} FILES ${PARSE_UNPARSED_ARGUMENTS})
set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME})
add_library(${INTERNAL_EMBED_LIB} OBJECT ${EMBED_FILES})
if(NOT EMBED_USE_BINARY)
target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES})
endif()
target_include_directories(embed_lib_${EMBED_NAME} PUBLIC ${EMBED_DIR}/include)
target_compile_options(embed_lib_${EMBED_NAME} PRIVATE
-Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
set_target_properties(embed_lib_${EMBED_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
add_library(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:embed_lib_${EMBED_NAME}> ${OUTPUT_FILES})
target_link_libraries(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:embed_lib_${EMBED_NAME}>)
if(EMBED_USE_LD)
target_link_libraries(${EMBED_NAME} INTERFACE ${OUTPUT_FILES})
target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include")
target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}> ${OUTPUT_FILES})
if(EMBED_USE_BINARY AND WIN32)
target_link_libraries(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
endif()
target_include_directories(${EMBED_NAME} INTERFACE ${EMBED_DIR}/include)
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
endfunction()
......@@ -23,77 +23,95 @@
#####################################################################################
# - Enable warning all for gcc/clang or use /W4 for visual studio
## Strict compile options for Visual C++ compiler
set(__default_msvc_compile_options /w)
## Strict warning level
if (MSVC)
# Use the highest warning level for visual studio.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w")
# set(CMAKE_CXX_WARNING_LEVEL 4)
# if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
# else ()
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
# endif ()
## Strict compile options for GNU/Clang compilers
set(__default_compile_options
-Wall -Wextra
-Wcomment
-Wendif-labels
-Wformat
-Winit-self
-Wreturn-type
-Wsequence-point
-Wswitch
-Wtrigraphs
-Wundef
-Wuninitialized
-Wunreachable-code
-Wunused
-Wno-sign-compare
-Wno-reserved-macro-identifier)
# set(CMAKE_C_WARNING_LEVEL 4)
# if (CMAKE_C_FLAGS MATCHES "/W[0-4]")
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
# else ()
# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4")
# endif ()
## Strict compile options for Clang compilers
set(__default_clang_compile_options
-Weverything
-Wshadow
-Wno-c++98-compat
-Wno-c++98-compat-pedantic
-Wno-conversion
-Wno-double-promotion
-Wno-exit-time-destructors
-Wno-extra-semi
-Wno-extra-semi-stmt
-Wno-float-conversion
-Wno-gnu-anonymous-struct
-Wno-gnu-zero-variadic-macro-arguments
-Wno-missing-prototypes
-Wno-nested-anon-types
-Wno-option-ignored
-Wno-padded
-Wno-shorten-64-to-32
-Wno-sign-conversion
-Wno-unused-command-line-argument
-Wno-weak-vtables
-Wno-c99-extensions
-fno-sanitize=function,vptr)
if(WIN32)
list(APPEND __default_clang_compile_options
-fms-extensions
-fms-compatibility
-fdelayed-template-parsing)
endif()
set(__default_gnu_compile_options
-Wduplicated-branches
-Wduplicated-cond
-Wno-noexcept-type
-Wodr
-Wshift-negative-value
-Wshift-overflow=2
-Wno-missing-field-initializers
-Wno-maybe-uninitialized)
add_compile_options(
"$<$<OR:$<CXX_COMPILER_ID:MSVC>,$<C_COMPILER_ID:MSVC>>:${__default_msvc_compile_options}>"
"$<$<OR:$<CXX_COMPILER_ID:GNU,Clang>,$<C_COMPILER_ID:GNU,Clang>>:${__default_compile_options}>"
"$<$<OR:$<AND:$<CXX_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<CXX_COMPILER_VERSION>,7>>,$<AND:$<C_COMPILER_ID:GNU>,$<VERSION_GREATER_EQUAL:$<C_COMPILER_VERSION>,7>>>:${__default_gnu_compile_options}>"
"$<$<OR:$<CXX_COMPILER_ID:Clang>,$<C_COMPILER_ID:Clang>>:${__default_clang_compile_options}>")
unset(__default_msvc_compile_options)
unset(__default_compile_options)
unset(__default_gnu_compile_options)
unset(__default_clang_compile_options)
else()
foreach(COMPILER C CXX)
set(CMAKE_COMPILER_WARNINGS)
# use -Wall for gcc and clang
list(APPEND CMAKE_COMPILER_WARNINGS
-Wall
-Wextra
-Wcomment
-Wendif-labels
-Wformat
-Winit-self
-Wreturn-type
-Wsequence-point
# Shadow is broken on gcc when using lambdas
# -Wshadow
-Wswitch
-Wtrigraphs
-Wundef
-Wuninitialized
-Wunreachable-code
-Wunused
-Wno-sign-compare
)
# Flags for gcc 7
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "7.0")
list(APPEND CMAKE_COMPILER_WARNINGS
-Wduplicated-branches
-Wduplicated-cond
-Wno-noexcept-type
-Wodr
-Wshift-negative-value
-Wshift-overflow=2
)
endif()
endif()
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
list(APPEND CMAKE_COMPILER_WARNINGS
-Weverything
-Wno-c++98-compat
-Wno-c++98-compat-pedantic
-Wno-conversion
-Wno-double-promotion
-Wno-exit-time-destructors
-Wno-extra-semi
-Wno-extra-semi-stmt
-Wno-float-conversion
-Wno-gnu-anonymous-struct
-Wno-gnu-zero-variadic-macro-arguments
-Wno-missing-prototypes
-Wno-nested-anon-types
-Wno-option-ignored
-Wno-padded
-Wno-shorten-64-to-32
-Wno-sign-conversion
-Wno-unused-command-line-argument
-Wno-weak-vtables
-Wno-c99-extensions
# -Wno-c++2a-designator
)
else()
list(APPEND CMAKE_COMPILER_WARNINGS
-Wno-missing-field-initializers
-Wno-maybe-uninitialized
# -Wno-deprecated-declarations
)
endif()
foreach(COMPILER_WARNING ${CMAKE_COMPILER_WARNINGS})
add_compile_options($<$<COMPILE_LANGUAGE:${COMPILER}>:${COMPILER_WARNING}>)
endforeach()
endforeach()
endif ()
......@@ -188,11 +188,11 @@ ExternalProject_Add(
BUILD_COMMAND ${NMAKE_EXECUTABLE} /f ..\\sqlite3\\Makefile.msc USE_AMALGAMATION=1 NO_TCL=1 TOP=..\\sqlite3 libsqlite3.lib
INSTALL_COMMAND "")
add_library(SQLite::SQLite3 INTERFACE IMPORTED GLOBAL)
add_dependencies(SQLite::SQLite3 sqlite3)
ExternalProject_Get_Property(sqlite3 BINARY_DIR)
# For compatibility with PkgConfig on Linux
add_library(PkgConfig::SQLITE3 INTERFACE IMPORTED GLOBAL)
add_dependencies(PkgConfig::SQLITE3 sqlite3)
target_link_directories(PkgConfig::SQLITE3 INTERFACE ${BINARY_DIR})
target_link_libraries(PkgConfig::SQLITE3 INTERFACE libsqlite3.lib)
target_include_directories(PkgConfig::SQLITE3 INTERFACE ${BINARY_DIR})
target_link_directories(SQLite::SQLite3 INTERFACE ${BINARY_DIR})
target_link_libraries(SQLite::SQLite3 INTERFACE libsqlite3.lib)
target_include_directories(SQLite::SQLite3 INTERFACE ${BINARY_DIR})
......@@ -28,7 +28,11 @@ include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
include(RegisterOp)
include(CheckCXXLinkerFlag)
if(WIN32)
# Due to compilation crashing, we need to use type-erased matchers on Windows.
add_compile_definitions($<$<COMPILE_LANGUAGE:C,CXX>:MIGRAPHX_USE_TYPE_ERASED_MATCHERS=1>)
endif()
add_library(migraphx
adjust_allocation.cpp
......@@ -263,7 +267,9 @@ endif()
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx)
find_package(SQLite3 REQUIRED)
if(NOT WIN32)
find_package(SQLite3 REQUIRED)
endif()
target_link_libraries(migraphx PRIVATE SQLite::SQLite3)
if(NOT WIN32)
......
......@@ -43,7 +43,6 @@ if(NOT WIN32)
)
set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
endif()
target_compile_options(driver PRIVATE -Wno-ignored-attributes -Wno-unused-parameter)
rocm_clang_tidy_check(driver)
file(STRINGS "${CMAKE_SOURCE_DIR}/test/onnx/.onnxrt-commit" String_output)
......
......@@ -105,6 +105,8 @@ inline std::ostream& operator<<(std::ostream& os, const color& c)
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#else
(void)c;
#endif
return os;
}
......
......@@ -603,8 +603,7 @@ struct version : command<version>
void run() const
{
std::cout << "MIGraphX Version: " << MIGRAPHX_VERSION_MAJOR << "." << MIGRAPHX_VERSION_MINOR
<< "." << MIGRAPHX_VERSION_PATCH << "."
<< MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK) << std::endl;
<< "." << MIGRAPHX_VERSION_PATCH << "." MIGRAPHX_VERSION_TWEAK << std::endl;
}
};
......@@ -762,8 +761,7 @@ struct main_command
{
std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
"." + std::to_string(MIGRAPHX_VERSION_MINOR) + "." +
std::to_string(MIGRAPHX_VERSION_PATCH) + "." +
MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK);
std::to_string(MIGRAPHX_VERSION_PATCH) + "." MIGRAPHX_VERSION_TWEAK;
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
ap(nullptr,
......
......@@ -25,7 +25,11 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#define MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#if !defined(_WIN32)
#include <half/half.hpp>
#else
#include <half.hpp>
#endif
#include <migraphx/config.hpp>
namespace migraphx {
......
......@@ -56,7 +56,7 @@ add_library(migraphx_cpu
set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu)
rocm_set_soversion(migraphx_cpu ${MIGRAPHX_SO_VERSION})
option(MIGRAPHX_ENABLE_ZENDNN "MIGraphX enable ZenDNN" Off)
set(MIGRAPHX_ENABLE_ZENDNN Off CACHE BOOL "")
if(MIGRAPHX_ENABLE_ZENDNN)
find_path(ZENDNN_INC_PATH zendnn.hpp)
......@@ -67,7 +67,6 @@ elseif(NOT WIN32)
endif()
rocm_clang_tidy_check(migraphx_cpu)
if(MIGRAPHX_ENABLE_ZENDNN)
target_compile_definitions(migraphx_cpu PRIVATE -DMIGRAPHX_ENABLE_ZENDNN)
target_include_directories(migraphx_cpu PRIVATE ${ZENDNN_INC_PATH})
......
......@@ -22,8 +22,7 @@
# THE SOFTWARE.
# ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm $ENV{HIP_PATH})
find_package(hip)
find_package(hip REQUIRED)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
endif()
......@@ -55,7 +54,7 @@ if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} BASE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
......@@ -156,8 +155,6 @@ add_library(migraphx_gpu
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
migraphx_generate_export_header(migraphx_gpu)
target_link_options(migraphx_gpu PUBLIC -Wno-option-ignored)
function(register_migraphx_gpu_ops PREFIX)
foreach(OP ${ARGN})
register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
......
......@@ -361,7 +361,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.compiler = MIGRAPHX_HIP_COMPILER;
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
......
......@@ -22,17 +22,10 @@
# THE SOFTWARE.
#####################################################################################
add_executable(migraphx_gpu_driver
action.cpp
compile_op.cpp
main.cpp
parser.cpp
run_op.cpp)
set_target_properties(migraphx_gpu_driver PROPERTIES OUTPUT_NAME migraphx-gpu-driver)
rocm_clang_tidy_check(migraphx_gpu_driver)
target_include_directories(migraphx_gpu_driver PRIVATE include)
target_link_libraries(migraphx_gpu_driver PRIVATE migraphx MIOpen migraphx_gpu hip::device)
target_compile_options(migraphx_gpu_driver PRIVATE -Wno-ignored-attributes -Wno-option-ignored)
file(GLOB GPU_DRIVER_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(gpu-driver
${GPU_DRIVER_SRCS}
)
rocm_clang_tidy_check(gpu-driver)
target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
......@@ -22,15 +22,12 @@
# THE SOFTWARE.
#####################################################################################
add_executable(migraphx_hiprtc_driver
add_executable(migraphx-hiprtc-driver
main.cpp
)
rocm_clang_tidy_check(migraphx_hiprtc_driver)
target_link_libraries(migraphx_hiprtc_driver PRIVATE migraphx migraphx_gpu)
add_dependencies(migraphx_all_targets migraphx_hiprtc_driver)
set_target_properties(migraphx_hiprtc_driver PROPERTIES OUTPUT_NAME migraphx-hiprtc-driver)
rocm_clang_tidy_check(migraphx-hiprtc-driver)
target_link_libraries(migraphx-hiprtc-driver PRIVATE migraphx_gpu)
add_dependencies(migraphx_all_targets migraphx-hiprtc-driver)
rocm_install_targets(
TARGETS migraphx_hiprtc_driver
TARGETS migraphx-hiprtc-driver
)
target_compile_options(migraphx_hiprtc_driver PRIVATE -Wno-ignored-attributes)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <Windows.h>
#include "migraphx_kernels.hpp"
#include "resource.h"
namespace {
std::string_view __read(int id)
{
HMODULE handle{::GetModuleHandle(nullptr)};
HRSRC rc{::FindResource(handle, MAKEINTRESOURCE(id), MAKEINTRESOURCE(MIGRAPHX_TEXTFILE))};
HGLOBAL data{::LoadResource(handle, rc)};
return {static_cast<char const*>(::LockResource(data)), ::SizeofResource(handle, rc)};
}
}
std::unordered_map<std::string_view, std::string_view> migraphx_kernels()
{
static std::unordered_map<std::string_view, std::string_view> kernels = {
{"migraphx/kernels/algorithm.hpp", __read(MIGRAPHX_IDR_ALGORITHM_HPP)},
{"migraphx/kernels/args.hpp", __read(MIGRAPHX_IDR_ARGS_HPP)},
{"migraphx/kernels/array.hpp", __read(MIGRAPHX_IDR_ARRAY_HPP)},
{"migraphx/kernels/concat.hpp", __read(MIGRAPHX_IDR_CONCAT_HPP)},
{"migraphx/kernels/debug.hpp", __read(MIGRAPHX_IDR_DEBUG_HPP)},
{"migraphx/kernels/dfor.hpp", __read(MIGRAPHX_IDR_DFOR_HPP)},
{"migraphx/kernels/dpp.hpp", __read(MIGRAPHX_IDR_DPP_HPP)},
{"migraphx/kernels/functional.hpp", __read(MIGRAPHX_IDR_FUNCTIONAL_HPP)},
{"migraphx/kernels/gather.hpp", __read(MIGRAPHX_IDR_GATHER_HPP)},
{"migraphx/kernels/gathernd.hpp", __read(MIGRAPHX_IDR_GATHERND_HPP)},
{"migraphx/kernels/generic_constant.hpp", __read(MIGRAPHX_IDR_GENERIC_CONSTANT_HPP)},
{"migraphx/kernels/hip.hpp", __read(MIGRAPHX_IDR_HIP_HPP)},
{"migraphx/kernels/index.hpp", __read(MIGRAPHX_IDR_INDEX_HPP)},
{"migraphx/kernels/integral_constant.hpp", __read(MIGRAPHX_IDR_INTEGRAL_CONSTANT_HPP)},
{"migraphx/kernels/iota_iterator.hpp", __read(MIGRAPHX_IDR_IOTA_ITERATOR_HPP)},
{"migraphx/kernels/layernorm.hpp", __read(MIGRAPHX_IDR_LAYERNORM_HPP)},
{"migraphx/kernels/math.hpp", __read(MIGRAPHX_IDR_MATH_HPP)},
{"migraphx/kernels/ops.hpp", __read(MIGRAPHX_IDR_OPS_HPP)},
{"migraphx/kernels/pad.hpp", __read(MIGRAPHX_IDR_PAD_HPP)},
{"migraphx/kernels/pointwise.hpp", __read(MIGRAPHX_IDR_POINTWISE_HPP)},
{"migraphx/kernels/preload.hpp", __read(MIGRAPHX_IDR_PRELOAD_HPP)},
{"migraphx/kernels/print.hpp", __read(MIGRAPHX_IDR_PRINT_HPP)},
{"migraphx/kernels/ranges.hpp", __read(MIGRAPHX_IDR_RANGES_HPP)},
{"migraphx/kernels/reduce.hpp", __read(MIGRAPHX_IDR_REDUCE_HPP)},
{"migraphx/kernels/roialign.hpp", __read(MIGRAPHX_IDR_ROIALIGN_HPP)},
{"migraphx/kernels/scatternd.hpp", __read(MIGRAPHX_IDR_SCATTERND_HPP)},
{"migraphx/kernels/shape.hpp", __read(MIGRAPHX_IDR_SHAPE_HPP)},
{"migraphx/kernels/softmax.hpp", __read(MIGRAPHX_IDR_SOFTMAX_HPP)},
{"migraphx/kernels/tensor_view.hpp", __read(MIGRAPHX_IDR_TENSOR_VIEW_HPP)},
{"migraphx/kernels/type_traits.hpp", __read(MIGRAPHX_IDR_TYPE_TRAITS_HPP)},
{"migraphx/kernels/types.hpp", __read(MIGRAPHX_IDR_TYPES_HPP)},
{"migraphx/kernels/vec.hpp", __read(MIGRAPHX_IDR_VEC_HPP)},
{"migraphx/kernels/vectorize.hpp", __read(MIGRAPHX_IDR_VECTORIZE_HPP)}};
return kernels;
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_KERNELS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_KERNELS_HPP
#include <string_view>
#include <unordered_map>
// Returns the table of embedded kernel headers: each key is a header path and
// each value is a view over that header's embedded contents.
std::unordered_map<std::string_view, std::string_view> migraphx_kernels();
#endif // MIGRAPHX_GUARD_MIGRAPHX_KERNELS_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_KERNELS_RESOURCE_H
#define MIGRAPHX_GUARD_MIGRAPHX_KERNELS_RESOURCE_H
// Numeric resource *type* under which the kernel headers are stored in the
// resource script and looked up at runtime.
#define MIGRAPHX_TEXTFILE 256
// One numeric resource id per embedded kernel header. Ids must be unique.
#define MIGRAPHX_IDR_ALGORITHM_HPP 101
#define MIGRAPHX_IDR_ARGS_HPP 102
#define MIGRAPHX_IDR_ARRAY_HPP 103
#define MIGRAPHX_IDR_CONCAT_HPP 104
#define MIGRAPHX_IDR_DEBUG_HPP 105
#define MIGRAPHX_IDR_DFOR_HPP 106
#define MIGRAPHX_IDR_DPP_HPP 107
#define MIGRAPHX_IDR_FUNCTIONAL_HPP 108
#define MIGRAPHX_IDR_GATHER_HPP 109
#define MIGRAPHX_IDR_GATHERND_HPP 110
#define MIGRAPHX_IDR_GENERIC_CONSTANT_HPP 111
#define MIGRAPHX_IDR_HIP_HPP 112
#define MIGRAPHX_IDR_INDEX_HPP 113
#define MIGRAPHX_IDR_INTEGRAL_CONSTANT_HPP 114
#define MIGRAPHX_IDR_IOTA_ITERATOR_HPP 115
#define MIGRAPHX_IDR_LAYERNORM_HPP 116
#define MIGRAPHX_IDR_MATH_HPP 117
#define MIGRAPHX_IDR_OPS_HPP 118
#define MIGRAPHX_IDR_PAD_HPP 119
#define MIGRAPHX_IDR_POINTWISE_HPP 120
#define MIGRAPHX_IDR_PRELOAD_HPP 121
#define MIGRAPHX_IDR_PRINT_HPP 122
#define MIGRAPHX_IDR_RANGES_HPP 123
#define MIGRAPHX_IDR_REDUCE_HPP 124
#define MIGRAPHX_IDR_ROIALIGN_HPP 125
#define MIGRAPHX_IDR_SCATTERND_HPP 126
#define MIGRAPHX_IDR_SHAPE_HPP 127
#define MIGRAPHX_IDR_SOFTMAX_HPP 128
#define MIGRAPHX_IDR_TENSOR_VIEW_HPP 129
#define MIGRAPHX_IDR_TYPE_TRAITS_HPP 130
#define MIGRAPHX_IDR_TYPES_HPP 131
#define MIGRAPHX_IDR_VEC_HPP 132
// NOTE(review): id 133 is skipped here; harmless as long as ids stay unique,
// but confirm the gap is intentional.
#define MIGRAPHX_IDR_VECTORIZE_HPP 134
#endif // MIGRAPHX_GUARD_MIGRAPHX_KERNELS_RESOURCE_H
//
// The MIT License (MIT)
//
// Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
//
#include "resource.h"
MIGRAPHX_IDR_ALGORITHM_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/algorithm.hpp"
MIGRAPHX_IDR_ARGS_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/args.hpp"
MIGRAPHX_IDR_ARRAY_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/array.hpp"
MIGRAPHX_IDR_CONCAT_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/concat.hpp"
MIGRAPHX_IDR_DEBUG_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/debug.hpp"
MIGRAPHX_IDR_DFOR_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/dfor.hpp"
MIGRAPHX_IDR_DPP_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/dpp.hpp"
MIGRAPHX_IDR_FUNCTIONAL_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/functional.hpp"
MIGRAPHX_IDR_GATHER_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/gather.hpp"
MIGRAPHX_IDR_GATHERND_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/gathernd.hpp"
MIGRAPHX_IDR_GENERIC_CONSTANT_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/generic_constant.hpp"
MIGRAPHX_IDR_HIP_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/hip.hpp"
MIGRAPHX_IDR_INDEX_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/index.hpp"
MIGRAPHX_IDR_INTEGRAL_CONSTANT_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/integral_constant.hpp"
MIGRAPHX_IDR_IOTA_ITERATOR_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/iota_iterator.hpp"
MIGRAPHX_IDR_LAYERNORM_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/layernorm.hpp"
MIGRAPHX_IDR_MATH_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/math.hpp"
MIGRAPHX_IDR_OPS_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/ops.hpp"
MIGRAPHX_IDR_PAD_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/pad.hpp"
MIGRAPHX_IDR_POINTWISE_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/pointwise.hpp"
MIGRAPHX_IDR_PRELOAD_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/preload.hpp"
MIGRAPHX_IDR_PRINT_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/print.hpp"
MIGRAPHX_IDR_RANGES_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/ranges.hpp"
MIGRAPHX_IDR_REDUCE_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/reduce.hpp"
MIGRAPHX_IDR_ROIALIGN_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/roialign.hpp"
MIGRAPHX_IDR_SCATTERND_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/scatternd.hpp"
MIGRAPHX_IDR_SHAPE_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/shape.hpp"
MIGRAPHX_IDR_SOFTMAX_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/softmax.hpp"
MIGRAPHX_IDR_TENSOR_VIEW_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/tensor_view.hpp"
MIGRAPHX_IDR_TYPE_TRAITS_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/type_traits.hpp"
MIGRAPHX_IDR_TYPES_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/types.hpp"
MIGRAPHX_IDR_VEC_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/vec.hpp"
MIGRAPHX_IDR_VECTORIZE_HPP MIGRAPHX_TEXTFILE "include/migraphx/kernels/vectorize.hpp"
......@@ -28,7 +28,9 @@
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -122,6 +124,8 @@ struct find_add_layernorm
}
};
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
......@@ -175,6 +179,8 @@ struct find_gemm_softmax_gemm
}
};
#endif
} // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const
......@@ -182,8 +188,10 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
// Run the gemm-softmax-gemm matcher only when composable-kernel support is
// compiled in. The guard must use the same macro that guards the definition of
// find_gemm_softmax_gemm earlier in this file; a misspelled macro name would
// silently compile this pass out on every platform.
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
#endif
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment