Commit 546a764e authored by Artur Wojcik

Merge branch 'migraphx' into uif2-migraphx

parents 8da3dfff 57cdd70b
@@ -48,6 +48,13 @@ build*
 .gdb_history
 install.dir*
+# directories containing generated documentation
+docs/source/_build/
+docs/docBin/
+# Generated source
+library/src/jit_library/solution_instances/
 # documentation artifacts
 _build/
 _images/
@@ -57,6 +64,9 @@ _toc.yml
 docBin/
 _doxygen/
+# pycache
+__pycache__/
 # JetBrains IDE
 .idea/
 cmake-build*/
......
@@ -145,88 +145,91 @@ if(GPU_TARGETS)
 else()
 message("Building CK for the following targets: ${AMDGPU_TARGETS}")
 endif()
-find_package(hip)
-# No assumption that HIP kernels are launched with uniform block size for backward compatibility
-# SWDEV-413293 and https://reviews.llvm.org/D155213
-math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
-message("hip_version_flat=${hip_VERSION_FLAT}")
-if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302)
-message("Adding the fno-offload-uniform-block compiler flag")
-add_compile_options(-fno-offload-uniform-block)
-endif()
-option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
-option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
-if(USE_BITINT_EXTENSION_INT4)
-add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
-add_compile_options(-Wno-bit-int-extension)
-message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
-endif()
-if(USE_OPT_NAVI3X)
-add_compile_options(-mcumode)
-add_compile_options(-mno-wavefrontsize64)
-message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
-endif()
-## Threads
-if(NOT WIN32)
-set(THREADS_PREFER_PTHREAD_FLAG ON)
-endif()
-find_package(Threads REQUIRED)
-link_libraries(Threads::Threads)
 ## C++
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
 message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
-## OpenMP
-if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
-# workaround issue hipcc in rocm3.5 cannot find openmp
-set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
-set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument")
-set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5")
-set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
-set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
-set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES})
-else()
-find_package(OpenMP REQUIRED)
-endif()
-message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
-message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
-message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
-message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
-link_libraries(${OpenMP_gomp_LIBRARY})
-link_libraries(${OpenMP_pthread_LIBRARY})
-## HIP
-find_package(HIP REQUIRED)
-# Override HIP version in config.h, if necessary.
-# The variables set by find_package() can't be overwritten,
-# therefore let's use intermediate variables.
-set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}")
-set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}")
-set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}")
-if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR )
-set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}")
-message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}")
-endif()
-if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR )
-set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}")
-message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}")
-endif()
-if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH )
-set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}")
-message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}")
-endif()
-message(STATUS "Build with HIP ${HIP_VERSION}")
-link_libraries(hip::device)
-add_compile_definitions(__HIP_PLATFORM_HCC__=1)
+option(CK_BUILD_JIT_LIB "Only build the CK JIT Helper Library" OFF)
+if (NOT CK_BUILD_JIT_LIB)
+find_package(hip)
+# No assumption that HIP kernels are launched with uniform block size for backward compatibility
+# SWDEV-413293 and https://reviews.llvm.org/D155213
+math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
+message("hip_version_flat=${hip_VERSION_FLAT}")
+if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302)
+message("Adding the fno-offload-uniform-block compiler flag")
+add_compile_options(-fno-offload-uniform-block)
+endif()
+option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
+option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
+if(USE_BITINT_EXTENSION_INT4)
+add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
+add_compile_options(-Wno-bit-int-extension)
+message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
+endif()
+if(USE_OPT_NAVI3X)
+add_compile_options(-mcumode)
+add_compile_options(-mno-wavefrontsize64)
+message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
+endif()
+## Threads
+if(NOT WIN32)
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+endif()
+find_package(Threads REQUIRED)
+link_libraries(Threads::Threads)
+## OpenMP
+if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+# workaround issue hipcc in rocm3.5 cannot find openmp
+set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
+set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument")
+set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5")
+set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+else()
+find_package(OpenMP REQUIRED)
+endif()
+message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
+message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
+message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
+message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
+link_libraries(${OpenMP_gomp_LIBRARY})
+link_libraries(${OpenMP_pthread_LIBRARY})
+## HIP
+find_package(HIP REQUIRED)
+# Override HIP version in config.h, if necessary.
+# The variables set by find_package() can't be overwritten,
+# therefore let's use intermediate variables.
+set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}")
+set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}")
+set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}")
+if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR )
+set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}")
+message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}")
+endif()
+if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR )
+set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}")
+message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}")
+endif()
+if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH )
+set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}")
+message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}")
+endif()
+message(STATUS "Build with HIP ${HIP_VERSION}")
+link_libraries(hip::device)
+add_compile_definitions(__HIP_PLATFORM_HCC__=1)
+endif()
 ## tidy
 include(EnableCompilerWarnings)
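The hunk above also carries the HIP version gate: math(EXPR) flattens the three version components into one integer so a single GREATER comparison suffices. A worked example of the same arithmetic, with illustrative values not taken from the commit:

# Illustrative only: HIP 5.7.23302 flattens to the exact threshold used above.
set(hip_VERSION_MAJOR 5)
set(hip_VERSION_MINOR 7)
set(hip_VERSION_PATCH 23302)
# (5 * 1000 + 7) * 100000 + 23302 = 500723302, so -fno-offload-uniform-block
# is only added for HIP versions strictly newer than 5.7.23302.
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
message(STATUS "hip_VERSION_FLAT=${hip_VERSION_FLAT}")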
@@ -381,89 +384,98 @@ include_directories(BEFORE
 ${HIP_INCLUDE_DIRS}
 )
-SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
 add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
+if (NOT CK_BUILD_JIT_LIB)
+SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
 file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
 file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
 set(CK_DEVICE_INSTANCES)
 FOREACH(subdir_path ${dir_list})
 set(target_dir)
 IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
 set(cmake_instance)
 file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
 set(add_inst 0)
 if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
 #message("fp8 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8")
 #message("bf8 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
 #message("fp16 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
 #message("fp32 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
 #message("fp64 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
 #message("bf16 instance found!")
 set(add_inst 1)
 endif()
 if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
 #message("int8 instance found!")
 set(add_inst 1)
 endif()
 if(NOT "${cmake_instance}" MATCHES "DTYPES")
 #message("instance should be built for all types!")
 set(add_inst 1)
 endif()
 if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
 list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
 endif()
 ENDIF()
 ENDFOREACH()
 add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
 add_subdirectory(library)
 if(NOT DEFINED INSTANCES_ONLY)
 if(NOT DEFINED PROFILER_ONLY)
 rocm_package_setup_component(tests
 LIBRARY_NAME composablekernel
 PACKAGE_NAME tests # Prevent -static suffix on package name
 )
 rocm_package_setup_component(examples
 LIBRARY_NAME composablekernel
 PACKAGE_NAME examples
 )
 add_subdirectory(example)
 if(BUILD_TESTING)
 add_subdirectory(test)
 endif()
 rocm_package_setup_component(profiler
 LIBRARY_NAME composablekernel
-PACKAGE_NAME ckprofiler
+PACKAGE_NAME ckProfiler
 )
 add_subdirectory(profiler)
 else()
 #When building PROFILER_ONLY, label the package with GPU_ARCH
 rocm_package_setup_component(profiler
 LIBRARY_NAME composablekernel
-PACKAGE_NAME ckprofiler_${GPU_ARCH}
+PACKAGE_NAME ckProfiler_${GPU_ARCH}
 )
 add_subdirectory(profiler)
 endif()
 endif()
+else()
+rocm_package_setup_component(jit_library
+LIBRARY_NAME composablekernel
+PACKAGE_NAME jit_library
+)
+add_subdirectory(library)
+add_subdirectory(test)
+endif()
 #Create an interface target for the include only files and call it "composablekernels"
......
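Taken together, the two CMakeLists.txt hunks above turn CK_BUILD_JIT_LIB into an either/or switch: OFF (the default) builds the instance libraries, examples, tests and profiler as before, while ON builds only the JIT helper library plus its tests. A hypothetical configure line for the JIT-only build (paths are illustrative):

cmake -D CK_BUILD_JIT_LIB=ON -D CMAKE_PREFIX_PATH=/opt/rocm ..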
 @PACKAGE_INIT@
-set(_composable_kernel_supported_components device_operations utility)
+set(_composable_kernel_supported_components device_operations utility jit_library)
 foreach(_comp ${composable_kernel_FIND_COMPONENTS})
 if(NOT _comp IN_LIST _composable_kernel_supported_components)
 set(composable_kernel_FOUND False)
 set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
 endif()
-include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
+if(EXISTS "${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
+include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
+else()
+set(composable_kernel_FOUND False)
+set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
+endif()
 endforeach()
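The package-config hunk above registers the new jit_library component and, instead of failing inside include(), now reports a missing component's Targets file through composable_kernel_NOT_FOUND_MESSAGE. A hedged consumer sketch; the imported target name is a guess, only the component name comes from the diff:

# Hypothetical downstream project:
find_package(composable_kernel REQUIRED COMPONENTS jit_library)
# loads composable_kerneljit_libraryTargets.cmake per the loop above
target_link_libraries(my_jit_tool PRIVATE composable_kernel::jit_library)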
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
find_program(EMBED_LD ld)
find_program(EMBED_OBJCOPY objcopy)
option(EMBED_USE_LD "Use ld to embed data files" OFF)
function(wrap_string)
set(options)
set(oneValueArgs VARIABLE AT_COLUMN)
set(multiValueArgs)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
string(LENGTH ${${PARSE_VARIABLE}} string_length)
math(EXPR offset "0")
while(string_length GREATER 0)
if(string_length GREATER ${PARSE_AT_COLUMN})
math(EXPR length "${PARSE_AT_COLUMN}")
else()
math(EXPR length "${string_length}")
endif()
string(SUBSTRING ${${PARSE_VARIABLE}} ${offset} ${length} line)
set(lines "${lines}\n${line}")
math(EXPR string_length "${string_length} - ${length}")
math(EXPR offset "${offset} + ${length}")
endwhile()
set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE)
endfunction()
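wrap_string rewrites the named variable in place (via PARENT_SCOPE), inserting a newline before every AT_COLUMN-sized chunk. A tiny illustration with made-up data:

# Illustrative only:
set(payload "deadbeefcafef00d")
wrap_string(VARIABLE payload AT_COLUMN 8)
# payload is now "\ndeadbeef\ncafef00d": a leading newline, then 8-character
# chunks, which keeps the generated byte arrays below at readable line lengths.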
function(generate_embed_source EMBED_NAME)
set(options)
set(oneValueArgs SRC HEADER RELATIVE)
set(multiValueArgs OBJECTS SYMBOLS FILES)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(EXTERNS)
set(INIT_KERNELS)
list(LENGTH PARSE_SYMBOLS SYMBOLS_LEN)
list(LENGTH PARSE_OBJECTS OBJECTS_LEN)
if(NOT ${SYMBOLS_LEN} EQUAL ${OBJECTS_LEN})
message(FATAL_ERROR "Symbols and objects dont match: ${SYMBOLS_LEN} != ${OBJECTS_LEN}")
endif()
math(EXPR LEN "${SYMBOLS_LEN} - 1")
foreach(idx RANGE ${LEN})
list(GET PARSE_SYMBOLS ${idx} SYMBOL)
list(GET PARSE_OBJECTS ${idx} OBJECT)
list(GET PARSE_FILES ${idx} FILE)
set(START_SYMBOL "_binary_${SYMBOL}_start")
set(LENGTH_SYMBOL "_binary_${SYMBOL}_length")
if(EMBED_USE_LD)
string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t _binary_${SYMBOL}_size;
const auto ${LENGTH_SYMBOL} = reinterpret_cast<size_t>(&_binary_${SYMBOL}_size);
")
else()
string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t ${LENGTH_SYMBOL};
")
endif()
if(PARSE_RELATIVE)
file(RELATIVE_PATH BASE_NAME ${PARSE_RELATIVE} "${FILE}")
else()
get_filename_component(BASE_NAME "${FILE}" NAME)
endif()
string(APPEND INIT_KERNELS "
{ \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL}} },")
endforeach()
file(WRITE "${PARSE_HEADER}" "
#include <string_view>
#include <unordered_map>
#include <utility>
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}();
")
file(WRITE "${PARSE_SRC}" "
#include <${EMBED_NAME}.hpp>
${EXTERNS}
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}()
{
static std::unordered_map<std::string_view, std::string_view> result = {${INIT_KERNELS}};
return result;
}
")
endfunction()
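For concreteness, here is roughly what generate_embed_source emits in the default (non-ld) mode for a single embedded file kernels/add.hpp with EMBED_NAME ck_headers; the file name and symbol are invented, the shape follows the file(WRITE ...) templates above:

// ck_headers.hpp
#include <string_view>
#include <unordered_map>
#include <utility>
std::unordered_map<std::string_view, std::string_view> ck_headers();

// ck_headers.cpp
#include <ck_headers.hpp>
extern const char _binary_kernels_add_hpp_start[];
extern const size_t _binary_kernels_add_hpp_length;
std::unordered_map<std::string_view, std::string_view> ck_headers()
{
    // maps "kernels/add.hpp" to a string_view over the embedded bytes
    static std::unordered_map<std::string_view, std::string_view> result = {
        { "kernels/add.hpp", { _binary_kernels_add_hpp_start, _binary_kernels_add_hpp_length } },
    };
    return result;
}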
function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE)
set(WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
# Glob is used to compute the relative path
file(GLOB FILES RELATIVE ${WORKING_DIRECTORY} ${FILE})
foreach(REL_FILE ${FILES})
string(MAKE_C_IDENTIFIER "${REL_FILE}" SYMBOL)
get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}")
if(EMBED_USE_LD)
set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o")
else()
set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp")
endif()
set(${OUTPUT_SYMBOL} ${SYMBOL} PARENT_SCOPE)
set(${OUTPUT_FILE} "${OUT_FILE}" PARENT_SCOPE)
if(EMBED_USE_LD)
add_custom_command(
OUTPUT "${OUT_FILE}"
COMMAND ${EMBED_LD} -r -o "${OUT_FILE}" -z noexecstack --format=binary "${REL_FILE}"
COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUT_FILE}"
WORKING_DIRECTORY ${WORKING_DIRECTORY}
DEPENDS ${FILE}
VERBATIM
)
else()
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${FILE})
# reads source file contents as hex string
file(READ ${FILE} HEX_STRING HEX)
# wraps the hex string into multiple lines
wrap_string(VARIABLE HEX_STRING AT_COLUMN 80)
# adds '0x' prefix and comma suffix before and after every byte respectively
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1, " ARRAY_VALUES ${HEX_STRING})
# removes trailing comma
string(REGEX REPLACE ", $" "" ARRAY_VALUES ${ARRAY_VALUES})
file(WRITE "${OUT_FILE}" "
#include <cstddef>
extern const char _binary_${SYMBOL}_start[] = { ${ARRAY_VALUES} };
extern const size_t _binary_${SYMBOL}_length = sizeof(_binary_${SYMBOL}_start);
")
endif()
endforeach()
endfunction()
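Whichever branch runs, the contract is the same pair of symbols per embedded file, so consuming code does not care whether ld or the generated C++ array produced them. A hedged fragment (symbol name invented):

// Illustrative only: for a file whose C identifier is my_blob, both modes provide
extern const char _binary_my_blob_start[];
extern const size_t _binary_my_blob_length;
// so the data can be viewed as std::string_view(_binary_my_blob_start, _binary_my_blob_length).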
function(add_embed_library EMBED_NAME)
set(options)
set(oneValueArgs RELATIVE)
set(multiValueArgs)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed)
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
set(SRC_FILE "${EMBED_DIR}/${EMBED_NAME}.cpp")
set(HEADER_FILE "${EMBED_DIR}/include/${EMBED_NAME}.hpp")
set(WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set(OUTPUT_FILES)
set(SYMBOLS)
message(STATUS "Embedding files")
foreach(FILE ${PARSE_UNPARSED_ARGUMENTS})
embed_file(OUTPUT_FILE OUTPUT_SYMBOL ${FILE})
list(APPEND OUTPUT_FILES ${OUTPUT_FILE})
list(APPEND SYMBOLS ${OUTPUT_SYMBOL})
endforeach()
message(STATUS "Generating embedding library ${EMBED_NAME}")
generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS} RELATIVE ${PARSE_RELATIVE} FILES ${PARSE_UNPARSED_ARGUMENTS})
set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME})
add_library(${INTERNAL_EMBED_LIB} OBJECT "${SRC_FILE}")
target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include")
target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(${EMBED_NAME} INTERFACE)
if(EMBED_USE_LD)
target_sources(${EMBED_NAME} INTERFACE ${OUTPUT_FILES})
else()
target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES})
endif()
target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
endfunction()
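A hypothetical use of the helper defined above; every name and path here is invented for illustration:

# Embed all CK headers into an interface target "ck_headers"; map keys are
# paths relative to include/, per the RELATIVE handling above.
file(GLOB_RECURSE CK_EMBED_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/ck/*.hpp)
add_embed_library(ck_headers ${CK_EMBED_HEADERS} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include)
# A JIT front end can then link ck_headers and call ck_headers() to obtain a
# name -> contents map to hand to hipRTC as in-memory header sources.
target_link_libraries(ck_jit_helper PRIVATE ck_headers)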
@@ -4,11 +4,12 @@
 #pragma once
 #include "ck/config.h"
+#ifndef __HIPCC_RTC__
 #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
 #endif
+#endif
 #define CK_TIME_KERNEL 1
......
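The recurring guard introduced here exists because hipRTC compiles device code at runtime with (as I understand it) __HIPCC_RTC__ predefined and no standard-library or HIP runtime headers available; host builds never define the macro and are unaffected. The pattern, schematically:

// Illustrative pattern only:
#ifndef __HIPCC_RTC__
#include <string>              // host-only dependency, invisible to hipRTC
#endif
// ...declarations usable by both the host compiler and hipRTC...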
@@ -3,6 +3,7 @@
 #pragma once
+#ifndef __HIPCC_RTC__
 #include <string>
 #include <map>
 #include <hip/hip_runtime.h>
@@ -59,3 +60,4 @@ inline bool is_xdl_supported()
 }
 } // namespace ck
+#endif
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#ifndef __HIPCC_RTC__
 #include <hip/hip_runtime.h>
 #include "ck/ck.hpp"
@@ -150,3 +150,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
 return 0;
 #endif
 }
+#endif
@@ -2,16 +2,17 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#ifndef __HIPCC_RTC__
 #include <string>
 #include <sstream>
 #include "ck/stream_config.hpp"
+#endif
 namespace ck {
 namespace tensor_operation {
 namespace device {
+#ifndef __HIPCC_RTC__
 struct BaseArgument
 {
 BaseArgument() = default;
@@ -36,6 +37,7 @@ struct BaseInvoker
 virtual ~BaseInvoker() {}
 };
+#endif
 struct BaseOperator
 {
@@ -43,7 +45,9 @@ struct BaseOperator
 BaseOperator(const BaseOperator&) = default;
 BaseOperator& operator=(const BaseOperator&) = default;
+#ifndef __HIPCC_RTC__
 virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
 virtual std::string GetTypeString() const { return ""; }
 virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
@@ -56,7 +60,6 @@ struct BaseOperator
 return oss.str();
 };
 virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
-
 virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
@@ -64,7 +67,7 @@ struct BaseOperator
 assert(p_arg);
 p_arg->p_workspace_ = p_workspace;
 }
+#endif
 virtual ~BaseOperator() {}
 };
......
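With BaseArgument/BaseInvoker and the std::string-returning virtuals fenced off, device code sees only a trimmed BaseOperator, while host code keeps the usual flow. A hedged host-side sketch of that flow; the helper and its error convention are mine, not the library's:

// Minimal host-side sketch against the interface above, assuming a concrete
// CK device op type and matching MakeArgumentPointer(...) parameters.
#include <memory>
#include <utility>
template <typename DeviceOp, typename... CallArgs>
float run_if_supported(DeviceOp& op, CallArgs&&... args)
{
    auto arg     = op.MakeArgumentPointer(std::forward<CallArgs>(args)...);
    auto invoker = op.MakeInvokerPointer();
    if(!op.IsSupportedArgument(arg.get()))
        return -1.0f; // this instance cannot handle the given shapes/layouts
    return invoker->Run(arg.get()); // kernel time in ms when timing is enabled
}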
@@ -2,9 +2,10 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <vector>
+#endif
 #include "device_base.hpp"
@@ -28,6 +29,7 @@ template <typename ALayout,
 bool MaskOutUpperTriangle> // TODO: enum for mask type
 struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
 {
+#ifndef __HIPCC_RTC__
 virtual std::unique_ptr<BaseArgument>
 MakeArgumentPointer(const void* p_a,
 const void* p_b0,
@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
 CElementwiseOperation c_element_op) = 0;
 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+#endif
 };
 } // namespace device
......
@@ -2,9 +2,11 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#ifndef __HIPCC_RTC__
 #include <array>
+#endif
+#include "ck/utility/array.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 namespace ck {
@@ -34,23 +36,24 @@ struct DeviceGemmMultipleD : public BaseOperator
 {
 static constexpr index_t NumDTensor = DsDataType::Size();
+#ifndef __HIPCC_RTC__
 virtual std::unique_ptr<BaseArgument>
 MakeArgumentPointer(const void* p_a,
 const void* p_b,
-std::array<const void*, NumDTensor> p_ds,
+ck::Array<const void*, NumDTensor> p_ds,
 void* p_e,
 ck::index_t M,
 ck::index_t N,
 ck::index_t K,
 ck::index_t StrideA,
 ck::index_t StrideB,
-std::array<ck::index_t, NumDTensor> StrideDs,
+ck::Array<ck::index_t, NumDTensor> StrideDs,
 ck::index_t StrideE,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
 CDEElementwiseOperation cde_element_op) = 0;
 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+#endif
 };
 } // namespace device
......
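std::array is a host-library type that hipRTC cannot see, which is why the interface above switches to ck::Array from ck/utility/array.hpp; call sites change only in the container spelling. A hedged fragment (pointer and stride values invented, assuming NumDTensor == 2):

// before: std::array<const void*, 2> p_ds = {d0_ptr, d1_ptr};
ck::Array<const void*, 2> p_ds     = {d0_ptr, d1_ptr};
ck::Array<ck::index_t, 2> StrideDs = {stride_d0, stride_d1};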
@@ -28,7 +28,7 @@ enum struct GemmSpecialization
 NKOPadding,
 MNKOPadding,
 };
+#ifndef __HIPCC_RTC__
 inline std::string getGemmSpecializationString(const GemmSpecialization& s)
 {
 switch(s)
@@ -52,6 +52,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
 default: return "Unrecognized specialization!";
 }
 }
+#endif
 } // namespace device
 } // namespace tensor_operation
......
@@ -3,8 +3,12 @@
 #pragma once
+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <sstream>
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#endif
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -15,8 +19,6 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -126,7 +128,6 @@
 // else
 // AccElement = -INFINITY
 // Otherwise, result may be wrong.
-
 template <typename ALayout,
 typename BLayout, // B0Layout
 typename B1Layout,
@@ -430,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
 matrix_padder.PadN,
 MaskOutUpperTriangle>;
+#ifndef __HIPCC_RTC__
 // Argument
 struct Argument : public BaseArgument
 {
@@ -604,13 +606,103 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
 }
 };
+#endif
 static constexpr bool IsValidCompilationParameter()
 {
 // TODO: properly implement this check
 return true;
 }
static constexpr bool
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
{
// check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row>)
{
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col>)
{
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B
if constexpr(is_same_v<BLayout, Row>)
{
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Col>)
{
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B1
if constexpr(is_same_v<B1Layout, Row>)
{
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<B1Layout, Col>)
{
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of C
if constexpr(is_same_v<CLayout, Row>)
{
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else if constexpr(is_same_v<CLayout, Col>)
{
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
return true;
}
+#ifndef __HIPCC_RTC__
 static bool IsSupportedArgument(const Argument& arg)
 {
 if(!ck::is_xdl_supported())
@@ -625,29 +717,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
 const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
 const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
-// Check scalar per vector requirement
-const auto a_extent_lowest =
-is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-const auto b_extent_lowest =
-is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-const auto b1_extent_lowest =
-is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-const auto c_extent_lowest =
-is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
-if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-{
-return false;
-}
 return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
 arg.b_grid_desc_bk0_n_bk1_,
 arg.b1_grid_desc_bk0_n_bk1_,
 arg.c_grid_desc_m_n_,
-arg.block_2_ctile_map_);
+arg.block_2_ctile_map_) and
+IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
 }
 // polymorphic
@@ -685,7 +760,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
 BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
 b1_element_op, c_element_op};
 }
-
 static auto MakeInvoker() { return Invoker{}; }
 // polymorphic
@@ -765,6 +839,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
 return str.str();
 }
+#endif
template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor
{
template <class AGridDescriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
{
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class BGridDescriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
{
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
{
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class CGridDescriptor>
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
{
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
C0MatrixMask c0_matrix_mask;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n)},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
c0_matrix_mask{c.GetLength(I1)},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map) and
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
b_grid_desc_bk0_n_bk1.GetLength(I1),
a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2),
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
{
}
constexpr bool IsValid() const { return is_valid; }
};
template <class ADesc, class BDesc, class B1Desc, class CDesc>
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
}
template <class Desc>
__device__ static void Run(const Desc& desc,
const float scale,
const ADataType* __restrict__ p_a_grid,
const ADataType* __restrict__ p_b_grid,
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
#ifndef __HIPCC_RTC__
assert(desc.is_valid);
#endif
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale};
if(desc.has_main_k_block_loop)
{
Desc::GridwiseGemm::template Run<true>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
else
{
Desc::GridwiseGemm::template Run<false>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
}
 };
 } // namespace device
......
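The added Descriptor machinery is the RTC-side mirror of Argument/Invoker: descriptors are built at compile time from static tensor descriptors, validity folds into a constexpr flag, and Run is called straight from a __global__ kernel. A hedged sketch of a JIT-generated wrapper; the kernel name and the a_desc/b_desc/b1_desc/c_desc descriptor objects are invented, and DeviceOp stands for a concrete DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle instantiation:

// Illustrative hipRTC translation unit:
extern "C" __global__ void gemm_softmax_gemm_wrapper(const ADataType* a,
                                                     const ADataType* b0,
                                                     const ADataType* b1,
                                                     CDataType* c,
                                                     float scale)
{
    constexpr auto desc = DeviceOp::make_descriptor(a_desc, b_desc, b1_desc, c_desc);
    static_assert(desc.IsValid(), "tile parameters do not fit these shapes/layouts");
    DeviceOp::Run(desc, scale, a, b0, b1, c);
}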
@@ -2,20 +2,22 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <sstream>
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#endif
-#include "ck/utility/common_header.hpp"
+#include "ck/utility/array.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/utility/common_header.hpp"
 namespace ck {
@@ -225,9 +227,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
 }
-static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-const std::array<index_t, NumDTensor>& NRaws,
-const std::array<index_t, NumDTensor>& DsStride)
+static auto MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
+const ck::Array<index_t, NumDTensor>& NRaws,
+const ck::Array<index_t, NumDTensor>& DsStride)
 {
 return generate_tuple(
 [&](auto i) {
@@ -308,20 +310,20 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 // block-to-e-tile map
 using Block2ETileMap =
 remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
+#ifndef __HIPCC_RTC__
 // Argument
 struct Argument : public BaseArgument
 {
 Argument(const void* p_a_grid,
 const void* p_b_grid,
-std::array<const void*, NumDTensor> p_ds_grid,
+ck::Array<const void*, NumDTensor> p_ds_grid,
 void* p_e_grid,
 index_t MRaw,
 index_t NRaw,
 index_t KRaw,
 index_t StrideA,
 index_t StrideB,
-std::array<index_t, NumDTensor> StrideDs,
+ck::Array<index_t, NumDTensor> StrideDs,
 index_t StrideE,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
@@ -420,7 +422,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 index_t NRaw_;
 index_t KRaw_;
 };
-
 // Invoker
 struct Invoker : public BaseInvoker
 {
@@ -497,95 +498,100 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
 }
 };
+#endif
-static bool IsSupportedArgument(const Argument& arg)
+static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
 {
-if(!ck::is_xdl_supported())
-{
-return false;
-}
 // check vector load/store
-{
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
 // check vector load of A
 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
 {
-if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
 {
 return false;
 }
 }
 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
 {
 // FIXME: not rigorous
-if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
 {
 return false;
 }
 }
 else
 {
 return false;
 }
 // check vector laod of B
 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
 {
-if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
 {
 return false;
 }
 }
 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
 {
 // FIXME: not rigorous
-if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
 {
 return false;
 }
 }
 else
 {
 return false;
 }
 // check vector load of Ds
 // only support RowMajor for now
 bool all_valid = true;
 static_for<0, NumDTensor, 1>{}([&](auto i) {
 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
 if constexpr(!is_same_v<DLayout, Row>)
 {
 all_valid = false;
 }
 });
 if(!all_valid)
 {
 return false;
 }
 // check vector store of E
 // only support RowMajor for now
 if constexpr(is_same_v<ELayout, Row>)
 {
-if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
 {
 return false;
 }
 }
 else
 {
 return false;
 }
-}
+return true;
+}
+#ifndef __HIPCC_RTC__
+static bool IsSupportedArgument(const Argument& arg)
+{
+if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+ck::get_device_name() == "gfx942"))
+{
+return false;
+}
-return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
+GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
 arg.b_grid_desc_n_k_,
 arg.ds_grid_desc_m_n_,
 arg.e_grid_desc_m_n_,
@@ -597,17 +603,16 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 {
 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
 }
 static auto MakeArgument(const void* p_a,
 const void* p_b,
-std::array<const void*, NumDTensor> p_ds,
+ck::Array<const void*, NumDTensor> p_ds,
 void* p_e,
 index_t MRaw,
 index_t NRaw,
 index_t KRaw,
 index_t StrideA,
 index_t StrideB,
-std::array<index_t, NumDTensor> StrideDs,
+ck::Array<index_t, NumDTensor> StrideDs,
 index_t StrideE,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
@@ -635,14 +640,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 std::unique_ptr<BaseArgument>
 MakeArgumentPointer(const void* p_a,
 const void* p_b,
-std::array<const void*, NumDTensor> p_ds,
+ck::Array<const void*, NumDTensor> p_ds,
 void* p_e,
 index_t MRaw,
 index_t NRaw,
 index_t KRaw,
 index_t StrideA,
 index_t StrideB,
-std::array<ck::index_t, NumDTensor> StrideDs,
+ck::Array<ck::index_t, NumDTensor> StrideDs,
 index_t StrideE,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
@@ -675,11 +680,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 {
 auto str = std::stringstream();
-std::map<LoopScheduler, std::string> LoopSchedToString{
-{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
+std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
+{LoopScheduler::Interwave, "Interwave"}};
-std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
-{PipelineVersion::v2, "v2"}};
+std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
+{PipelineVersion::v2, "v2"}};
 // clang-format off
 str << "DeviceGemmMultipleD_Xdl_CShuffle"
@@ -708,6 +715,149 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 return str.str();
 }
+#endif
template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor
{
static constexpr auto ds_tuple()
{
return transform_tuples(
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
DsDesc{});
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_tuple()))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
Block2ETileMap block_2_etile_map;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CDEElementwiseOperation cde_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a,
BDesc b,
DsDesc ds,
EDesc e,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_)
: a_grid_desc_ak0_m_ak1{GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(a))},
b_grid_desc_bk0_n_bk1{GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
DeviceOp::matrix_padder.PadBDescriptor_N_K(b))},
ds_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
transform_tuples(
[&](auto d) constexpr {
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
},
ds))},
e_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DeviceOp::matrix_padder.PadCDescriptor_M_N(e))},
block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(
DeviceOp::matrix_padder.PadCDescriptor_M_N(e))},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
cde_element_op{cde_element_op_},
is_valid{GridwiseGemm::CheckValidity(
(DeviceOp::matrix_padder.PadADescriptor_M_K(a)),
DeviceOp::matrix_padder.PadBDescriptor_N_K(b),
transform_tuples(
[&](auto d) constexpr {
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
},
ds),
DeviceOp::matrix_padder.PadCDescriptor_M_N(e),
block_2_etile_map) and
IsSupported(e.GetLength(I0), e.GetLength(I1), a.GetLength(I1))}
{
}
constexpr bool IsValid() const { return is_valid; }
};
template <class ADesc, class BDesc, class DsDesc, class EDesc>
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
DsDesc ds,
EDesc e,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
a, b, ds, e, a_element_op, b_element_op, cde_element_op);
}
template <class Desc, class DsPointer>
__device__ static void Run(const Desc& desc,
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid)
{
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
#ifndef __HIPCC_RTC__
assert(desc.is_valid);
#endif
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
desc.cde_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_etile_map);
}
else
{
GridwiseGemm::template Run<false>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
desc.cde_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_etile_map);
}
}
}; };
} // namespace device } // namespace device
......
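The hunk above adds a compile-time `Descriptor` plus a static `__device__ Run` entry point, so a hipRTC-compiled kernel can build, validate, and launch the GEMM without any host-side argument object. Below is a minimal host-compilable sketch of that pattern; `GemmShape`, the tile-divisibility rules, and every other name are illustrative stand-ins, not the CK API.

```cpp
#include <cstdio>

// Illustrative shape and validity rules -- stand-ins for the grid
// descriptors and GridwiseGemm::CheckValidity used in the real Descriptor.
template <int M, int N, int K>
struct GemmShape
{
    static constexpr int m = M, n = N, k = K;
};

template <class Shape>
struct Descriptor
{
    bool is_valid = false;

    // mirror of the diff's pattern: validity is computed in a constexpr
    // constructor, so a bad configuration can be rejected at compile time
    constexpr Descriptor()
        : is_valid{Shape::m % 16 == 0 && Shape::n % 16 == 0 && Shape::k % 8 == 0}
    {
    }

    constexpr bool IsValid() const { return is_valid; }
};

template <class Shape>
constexpr auto make_descriptor(Shape)
{
    return Descriptor<Shape>{};
}

int main()
{
    constexpr auto desc = make_descriptor(GemmShape<256, 128, 64>{});
    static_assert(desc.IsValid(), "problem shape rejected at compile time");
    std::printf("descriptor valid: %d\n", desc.IsValid());
    return 0;
}
```

In the real header the same descriptor is then consumed by the `__device__ Run` shown above, with `assert(desc.is_valid)` compiled out under `__HIPCC_RTC__`.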
...@@ -13,6 +13,7 @@ enum struct MaskingSpecialization ...@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
MaskOutUpperTriangle MaskOutUpperTriangle
}; };
#ifndef __HIPCC_RTC__
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{ {
switch(s) switch(s)
...@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s ...@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
#endif
struct MaskDisabledPredicate struct MaskDisabledPredicate
{ {
...@@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate ...@@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate> template <typename MaskOutPredicate>
struct C0MatrixMask_impl struct C0MatrixMask_impl
{ {
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{ {
......
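Both changes in this file serve hipRTC: the `std::string` helper is fenced behind `#ifndef __HIPCC_RTC__` (the RTC pass has no host standard library), and the mask constructor becomes `constexpr` so it can participate in fully compile-time descriptor construction. A hedged sketch of the guard pattern, with illustrative names:

```cpp
// Illustrative sketch (not the CK source) of the RTC guard pattern used in
// this hunk: host-only string helpers vanish under hipRTC, while constexpr,
// device-usable code stays visible to both compile paths.
enum struct MaskKind { Disabled, OutUpperTriangle };

#ifndef __HIPCC_RTC__
#include <string>
inline std::string to_string(MaskKind k) // host-only convenience
{
    return k == MaskKind::Disabled ? "Disabled" : "MaskOutUpperTriangle";
}
#endif

// visible to hipRTC as well; constexpr code needs no <string>/<iostream>
constexpr bool masks_anything(MaskKind k) { return k != MaskKind::Disabled; }
```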
...@@ -406,7 +406,7 @@ struct G_NDHW : public BaseTensorLayout ...@@ -406,7 +406,7 @@ struct G_NDHW : public BaseTensorLayout
template < template <
typename Layout, typename Layout,
typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false> typename ck::enable_if<ck::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
std::ostream& operator<<(std::ostream& os, const Layout&) std::ostream& operator<<(std::ostream& os, const Layout&)
{ {
os << Layout::name; os << Layout::name;
......
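`std::enable_if` and `std::is_base_of` come from `<type_traits>`, which hipRTC does not ship, hence the swap to CK's own trait clones. Roughly what such a clone looks like (a sketch; the real definitions live in CK's utility headers, and `dimension_of` is a made-up usage example):

```cpp
// Minimal re-implementation sketch of ck::enable_if; behavior matches
// std::enable_if but carries no <type_traits> dependency.
namespace ck {
template <bool B, class T = void> struct enable_if {};
template <class T> struct enable_if<true, T> { using type = T; };
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
} // namespace ck

// usage in the same shape as the operator<< constraint above
template <class T, typename ck::enable_if<(sizeof(T) > 0), bool>::type = false>
constexpr int dimension_of(const T&) { return 1; }

static_assert(dimension_of(42) == 1);
```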
...@@ -354,6 +354,7 @@ struct FastGelu ...@@ -354,6 +354,7 @@ struct FastGelu
template <typename Y, typename X> template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const; __device__ void operator()(Y& y, const X& x) const;
#ifndef __HIPCC_RTC__
template <> template <>
__host__ void operator()<float, float>(float& y, const float& x) const __host__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -363,7 +364,7 @@ struct FastGelu ...@@ -363,7 +364,7 @@ struct FastGelu
y = x * cdf; y = x * cdf;
} }
#endif
// device code, use lower precision "__expf" and "rcp" // device code, use lower precision "__expf" and "rcp"
template <> template <>
__device__ void operator()<float, float>(float& y, const float& x) const __device__ void operator()<float, float>(float& y, const float& x) const
......
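With this guard the `__host__` specialization of `FastGelu` disappears under hipRTC while the `__device__` one, built on the faster `__expf` and `rcp`, remains. Both paths evaluate the same tanh-style GELU approximation; here is a plain-C++ reference of that formula as a sketch for checking, not CK's actual code (constants follow the usual GELU tanh approximation):

```cpp
#include <cmath>
#include <cstdio>

// Reference FastGelu: y = x * cdf, where
// cdf = 0.5 * (1 + tanh(0.797885 * (x + 0.044715 * x^3)))
// expressed via exp() so the device path can substitute __expf.
float fast_gelu_ref(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u); // device code would use __expf
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); // 0.5*(1+tanh(u/2))
    return x * cdf;
}

int main()
{
    std::printf("%f\n", fast_gelu_ref(1.0f)); // ~0.8412, close to exact GELU(1)
    return 0;
}
```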
...@@ -5,10 +5,13 @@ ...@@ -5,10 +5,13 @@
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#ifndef __HIPCC_RTC__
#include <limits> #include <limits>
#include <stdlib.h> #include <stdlib.h>
#endif
namespace ck { namespace ck {
...@@ -86,16 +89,16 @@ struct BlockToCTileMap_M00_N0_M01 ...@@ -86,16 +89,16 @@ struct BlockToCTileMap_M00_N0_M01
const auto M00 = math::integer_divide_ceil(M0, M01); const auto M00 = math::integer_divide_ceil(M0, M01);
const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor( const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), ck::make_tuple(make_insert_transform(1),
make_unmerge_transform(make_tuple(M00, M01)), make_unmerge_transform(ck::make_tuple(M00, M01)),
make_pass_through_transform(make_tuple(N0))), make_pass_through_transform(ck::make_tuple(N0))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), ck::make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor( const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))), ck::make_tuple(make_merge_transform(ck::make_tuple(1, M00, N0, M01))),
make_tuple(Sequence<0, 1, 2, 3>{}), ck::make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{})); ck::make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor, chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
...@@ -120,31 +123,33 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -120,31 +123,33 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01) : M_(M), N_(N), M01_(M01)
{ {
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8) __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
: BlockToCTileMap_M00_N0_M01Adapt( : BlockToCTileMap_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{ {
} }
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock);
...@@ -153,13 +158,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -153,13 +158,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const __host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{ {
return true; return true;
} }
...@@ -227,13 +234,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -227,13 +234,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
* output {1, 2} * output {1, 2}
*/ */
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt); idx_N0_M01_local / M01_adapt);
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const const CTileDim& /* c_tile_dim */) const
{ {
return true; // always valid provided that user gets grid size from CalculateGridSize() return true; // always valid provided that user gets grid size from CalculateGridSize()
} }
...@@ -303,9 +310,9 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt ...@@ -303,9 +310,9 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
index_t idx_M01 = idx_M0 % M01_; index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_ksplit, return ck::make_tuple(idx_ksplit,
idx_N0_M01_local % M01_adapt + idx_M00 * M01_, idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt); idx_N0_M01_local / M01_adapt);
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
...@@ -402,17 +409,17 @@ struct BlockToCTileMap_M00_N00_M01_N01 ...@@ -402,17 +409,17 @@ struct BlockToCTileMap_M00_N00_M01_N01
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions ck::make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
make_unmerge_transform(make_tuple(M00, M01)), make_unmerge_transform(ck::make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))), make_unmerge_transform(ck::make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), ck::make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))), ck::make_tuple(make_merge_transform(ck::make_tuple(1, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}), ck::make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{})); ck::make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor = const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
...@@ -521,17 +528,17 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -521,17 +528,17 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(KSplit), ck::make_tuple(make_pass_through_transform(KSplit),
make_unmerge_transform(make_tuple(M00, M01)), make_unmerge_transform(ck::make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))), make_unmerge_transform(ck::make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), ck::make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor = const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))), ck::make_tuple(make_merge_transform(ck::make_tuple(KSplit, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}), ck::make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{})); ck::make_tuple(Sequence<0>{}));
const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor = const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
...@@ -649,13 +656,13 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -649,13 +656,13 @@ struct BlockToCTileMap_3DGrid_KSplit
const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return std::make_tuple(N0, M0, k_split); return ck::make_tuple(N0, M0, k_split);
} }
template <typename TopIdx> template <typename TopIdx>
__device__ constexpr auto CalculateBottomIndex(const TopIdx&) const __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
{ {
return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x); return ck::make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
...@@ -773,7 +780,7 @@ struct BlockToCTileMap_GemmStreamK ...@@ -773,7 +780,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t dp_for_sk_iters = k_iters_per_tile.get(); uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score = uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters ck::NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++) tentative_sk_blocks++)
{ {
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
#pragma once #pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#ifndef __HIPCC_RTC__
#include <iostream>
#endif
namespace ck { namespace ck {
...@@ -38,7 +39,9 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -38,7 +39,9 @@ constexpr auto GridwiseGemmPipeline_Selector()
} }
else else
{ {
#ifndef __HIPCC_RTC__
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
#endif
} }
} }
......
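The same guard idea applies to diagnostics: a host-side `std::cerr` inside the fallback branch is fenced off for hipRTC, where `<iostream>` does not exist. A sketch of the pattern with illustrative names; compiled with hipcc the fallback prints, compiled with hiprtc the branch is silently empty:

```cpp
#ifndef __HIPCC_RTC__
#include <iostream>
#endif

template <int Version>
constexpr int pipeline_selector()
{
    if constexpr(Version == 1 || Version == 2) { return Version; }
    else
    {
#ifndef __HIPCC_RTC__
        // host-only diagnostic; compiled out entirely in the RTC pass
        std::cerr << "GridwiseGemmPipeline configuration is not available\n";
#endif
        return -1;
    }
}

static_assert(pipeline_selector<2>() == 2); // valid configs stay constexpr
```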
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -7,10 +7,12 @@ ...@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp" #include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#ifndef __HIPCC_RTC__
#include <array> #include <array>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
#endif
namespace ck { namespace ck {
namespace detail { namespace detail {
...@@ -37,7 +39,7 @@ struct get_carrier<3> ...@@ -37,7 +39,7 @@ struct get_carrier<3>
{ {
using value_type = uint32_t; using value_type = uint32_t;
std::array<std::byte, 3> bytes; ck::byte bytes[3];
static_assert(sizeof(bytes) <= sizeof(value_type)); static_assert(sizeof(bytes) <= sizeof(value_type));
// replacement of host std::copy_n() // replacement of host std::copy_n()
...@@ -59,24 +61,21 @@ struct get_carrier<3> ...@@ -59,24 +61,21 @@ struct get_carrier<3>
} }
// method to trigger template substitution failure // method to trigger template substitution failure
__device__ carrier(const carrier& other) noexcept { copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); } __device__ carrier(const carrier& other) noexcept { copy_n(&other.bytes[0], 3, &bytes[0]); }
public: public:
__device__ carrier& operator=(value_type value) noexcept __device__ carrier& operator=(value_type value) noexcept
{ {
copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin()); copy_n(reinterpret_cast<const ck::byte*>(&value), 3, &bytes[0]);
return *this; return *this;
} }
__device__ operator value_type() const noexcept __device__ operator value_type() const noexcept
{ {
std::byte result[sizeof(value_type)]; ck::byte result[sizeof(value_type)];
copy_n(bytes.begin(), bytes.size(), result); copy_n(&bytes[0], 3, result);
return *reinterpret_cast<const value_type*>(result); return *reinterpret_cast<const value_type*>(result);
} }
...@@ -100,17 +99,17 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value) ...@@ -100,17 +99,17 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value)
return __builtin_amdgcn_readfirstlane(value); return __builtin_amdgcn_readfirstlane(value);
} }
template <typename Object, typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>> template <typename Object, typename = ck::enable_if_t<ck::is_class<Object>::value && ck::is_trivially_copyable<Object>::value>>
__device__ auto amd_wave_read_first_lane(const Object& obj) __device__ auto amd_wave_read_first_lane(const Object& obj)
{ {
using Size = unsigned; using Size = unsigned;
constexpr Size SgprSize = 4; constexpr Size SgprSize = 4;
constexpr Size ObjectSize = sizeof(Object); constexpr Size ObjectSize = sizeof(Object);
auto* const from_obj = reinterpret_cast<const std::byte*>(&obj); auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
alignas(Object) std::byte to_obj[ObjectSize]; alignas(Object) ck::byte to_obj[ObjectSize];
constexpr Size RemainedSize = ObjectSize % SgprSize; constexpr Size RemainedSize = ObjectSize % SgprSize;
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
......
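This overload broadcasts any trivially copyable object through the wave's first lane by pushing it through 4-byte SGPR chunks, with the 3-byte `carrier` above handling the sub-word tail; the diff only swaps `std::` bytes and traits for `ck::` equivalents so the header survives hipRTC. A host-side model of the chunked copy, with `readfirstlane` replaced by the identity (illustrative, not the device code):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <type_traits>

std::uint32_t read_first_lane(std::uint32_t v) { return v; } // identity stand-in

template <class Object>
Object broadcast_object(const Object& obj)
{
    static_assert(std::is_class_v<Object> && std::is_trivially_copyable_v<Object>);
    constexpr unsigned SgprSize   = 4;
    constexpr unsigned ObjectSize = sizeof(Object);
    constexpr unsigned Tail       = ObjectSize % SgprSize;
    constexpr unsigned Boundary   = ObjectSize - Tail;

    const auto* from = reinterpret_cast<const unsigned char*>(&obj);
    alignas(Object) unsigned char to[ObjectSize];

    for(unsigned off = 0; off < Boundary; off += SgprSize) // whole SGPRs
    {
        std::uint32_t chunk;
        std::memcpy(&chunk, from + off, SgprSize);
        chunk = read_first_lane(chunk);
        std::memcpy(to + off, &chunk, SgprSize);
    }
    if constexpr(Tail != 0) // remainder goes through a narrower carrier
    {
        std::uint32_t chunk = 0;
        std::memcpy(&chunk, from + Boundary, Tail);
        chunk = read_first_lane(chunk);
        std::memcpy(to + Boundary, &chunk, Tail);
    }
    Object out;
    std::memcpy(&out, to, ObjectSize);
    return out;
}

struct Tile { std::uint16_t m, n, k; }; // 6 bytes -> one SGPR + 2-byte tail

int main()
{
    Tile t = broadcast_object(Tile{16, 32, 8});
    std::printf("%d %d %d\n", t.m, t.n, t.k);
    return 0;
}
```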
...@@ -52,7 +52,7 @@ template <typename X, typename... Xs> ...@@ -52,7 +52,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{ {
using data_type = remove_cvref_t<X>; using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...}; return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
} }
// make empty array // make empty array
......
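`std::forward` lives in `<utility>`, which is unavailable under hipRTC, so `make_array` is routed through CK's own `ck::forward`. A minimal equivalent, assuming it mirrors the standard definition (the real one is in CK's utility headers):

```cpp
// Sketch of a ck::forward equivalent: functionally identical to std::forward,
// but with no <utility> dependency so it compiles in the RTC pass.
namespace ck {
template <class T> struct remove_reference      { using type = T; };
template <class T> struct remove_reference<T&>  { using type = T; };
template <class T> struct remove_reference<T&&> { using type = T; };

template <class T>
constexpr T&& forward(typename remove_reference<T>::type& t) noexcept
{
    return static_cast<T&&>(t); // forwards lvalues as lvalues or rvalues per T
}
template <class T>
constexpr T&& forward(typename remove_reference<T>::type&& t) noexcept
{
    return static_cast<T&&>(t); // forwards rvalues as rvalues
}
} // namespace ck
```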