gaoqiong / composable_kernel — commit 546a764e

Authored Oct 24, 2023 by Artur Wojcik

    Merge branch 'migraphx' into uif2-migraphx

Parents: 8da3dfff, 57cdd70b

Changes: 47 files in total; showing the first 20 changed files, with 1079 additions and 338 deletions (+1079 / -338).
.gitignore (+10, -0)
CMakeLists.txt (+160, -148)
Config.cmake.in (+7, -2)
cmake/Embed.cmake (+202, -0)
include/ck/ck.hpp (+2, -1)
include/ck/host_utility/device_prop.hpp (+2, -0)
include/ck/host_utility/kernel_launch.hpp (+2, -1)
include/ck/tensor_operation/gpu/device/device_base.hpp (+7, -4)
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp (+4, -1)
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp (+8, -5)
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp (+2, -1)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp (+360, -24)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp (+234, -84)
include/ck/tensor_operation/gpu/device/masking_specialization.hpp (+3, -1)
include/ck/tensor_operation/gpu/device/tensor_layout.hpp (+1, -1)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+2, -1)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+54, -47)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp (+5, -2)
include/ck/utility/amd_wave_read_first_lane.hpp (+13, -14)
include/ck/utility/array.hpp (+1, -1)
.gitignore

@@ -48,6 +48,13 @@ build*
 .gdb_history
 install.dir*
+# directories containing generated documentation
+docs/source/_build/
+docs/docBin/
+
+# Generated source
+library/src/jit_library/solution_instances/
+
 # documentation artifacts
 _build/
 _images/
@@ -57,6 +64,9 @@ _toc.yml
 docBin/
 _doxygen/
+
+# pycache
+__pycache__/
+
 # JetBrains IDE
 .idea/
 cmake-build*/
CMakeLists.txt

@@ -145,88 +145,91 @@ if(GPU_TARGETS) — the region reads as follows after the merge:

else()
    message("Building CK for the following targets: ${AMDGPU_TARGETS}")
endif()

find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
message("hip_version_flat=${hip_VERSION_FLAT}")
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302)
    message("Adding the fno-offload-uniform-block compiler flag")
    add_compile_options(-fno-offload-uniform-block)
endif()

option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)

if(USE_BITINT_EXTENSION_INT4)
    add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
    add_compile_options(-Wno-bit-int-extension)
    message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
endif()

if(USE_OPT_NAVI3X)
    add_compile_options(-mcumode)
    add_compile_options(-mno-wavefrontsize64)
    message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
endif()

## Threads
if(NOT WIN32)
    set(THREADS_PREFER_PTHREAD_FLAG ON)
endif()
find_package(Threads REQUIRED)
link_libraries(Threads::Threads)

## C++
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")

## OpenMP
option(CK_BUILD_JIT_LIB "Only build the CK JIT Helper Library" OFF)
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
    if(NOT CK_BUILD_JIT_LIB)
        # workaround issue hipcc in rocm3.5 cannot find openmp
        set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
        set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument")
        set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5")
        set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
        set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
        set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES})
    else()
        find_package(OpenMP REQUIRED)
    endif()

    message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
    message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
    message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
    message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
    link_libraries(${OpenMP_gomp_LIBRARY})
    link_libraries(${OpenMP_pthread_LIBRARY})

    ## HIP
    find_package(HIP REQUIRED)
    # Override HIP version in config.h, if necessary.
    # The variables set by find_package() can't be overwritten,
    # therefore let's use intermediate variables.
    set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}")
    set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}")
    set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}")
    if(DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR)
        set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}")
        message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}")
    endif()
    if(DEFINED CK_OVERRIDE_HIP_VERSION_MINOR)
        set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}")
        message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}")
    endif()
    if(DEFINED CK_OVERRIDE_HIP_VERSION_PATCH)
        set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}")
        message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}")
    endif()
endif()
message(STATUS "Build with HIP ${HIP_VERSION}")
link_libraries(hip::device)
add_compile_definitions(__HIP_PLATFORM_HCC__=1)

## tidy
include(EnableCompilerWarnings)
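The 500723302 threshold above is simply a flattened HIP release number. A minimal sketch of the same arithmetic, runnable with cmake -P (the sample version numbers are illustrative, not taken from this commit):

    # Flatten a hypothetical HIP 5.7.23302 exactly as the build script does:
    # (major * 1000 + minor) * 100000 + patch
    set(hip_VERSION_MAJOR 5)
    set(hip_VERSION_MINOR 7)
    set(hip_VERSION_PATCH 23302)
    math(EXPR hip_VERSION_FLAT
         "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
    # Prints 500723302, the exact GREATER threshold used above, so
    # -fno-offload-uniform-block is added only for releases newer than this one.
    message(STATUS "hip_VERSION_FLAT=${hip_VERSION_FLAT}")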
@@ -381,89 +384,98 @@ include_directories(BEFORE ${HIP_INCLUDE_DIRS}) — the region reads as follows after the merge:

SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})

if(NOT CK_BUILD_JIT_LIB)
    file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
    file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
    set(CK_DEVICE_INSTANCES)
    FOREACH(subdir_path ${dir_list})
        set(target_dir)
        IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
            set(cmake_instance)
            file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
            set(add_inst 0)
            if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
                #message("fp8 instance found!")
                set(add_inst 1)
            endif()
            if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8")
                #message("bf8 instance found!")
                set(add_inst 1)
            endif()
            if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
                #message("fp16 instance found!")
                set(add_inst 1)
            endif()
            if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
                #message("fp32 instance found!")
                set(add_inst 1)
            endif()
            if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
                #message("fp64 instance found!")
                set(add_inst 1)
            endif()
            if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
                #message("bf16 instance found!")
                set(add_inst 1)
            endif()
            if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
                #message("int8 instance found!")
                set(add_inst 1)
            endif()
            if(NOT "${cmake_instance}" MATCHES "DTYPES")
                #message("instance should be built for all types!")
                set(add_inst 1)
            endif()
            if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
                list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
            endif()
        ENDIF()
    ENDFOREACH()

    add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
    add_subdirectory(library)

    if(NOT DEFINED INSTANCES_ONLY)
        if(NOT DEFINED PROFILER_ONLY)
            rocm_package_setup_component(tests
                LIBRARY_NAME composablekernel
                PACKAGE_NAME tests # Prevent -static suffix on package name
            )
            rocm_package_setup_component(examples
                LIBRARY_NAME composablekernel
                PACKAGE_NAME examples
            )
            add_subdirectory(example)
            if(BUILD_TESTING)
                add_subdirectory(test)
            endif()
            rocm_package_setup_component(profiler
                LIBRARY_NAME composablekernel
                PACKAGE_NAME ckProfiler
            )
            add_subdirectory(profiler)
        else()
            #When building PROFILER_ONLY, label the package with GPU_ARCH
            rocm_package_setup_component(profiler
                LIBRARY_NAME composablekernel
                PACKAGE_NAME ckProfiler_${GPU_ARCH}
            )
            add_subdirectory(profiler)
        endif()
    endif()
else()
    rocm_package_setup_component(jit_library
        LIBRARY_NAME composablekernel
        PACKAGE_NAME jit_library
    )
    add_subdirectory(library)
    add_subdirectory(test)
endif()

#Create an interface target for the include only files and call it "composablekernels"
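The dtype gating above keys off two strings: the text of each instance subdirectory's CMakeLists.txt and the user-supplied DTYPES list. A standalone sketch of that matching logic, runnable with cmake -P (the instance text here is invented for illustration):

    set(DTYPES "fp16;fp8")
    # Pretend this is the content of one instance subdirectory's CMakeLists.txt
    set(cmake_instance "list(APPEND device_gemm_instances device_gemm_xdl_f16_instance.cpp)")
    set(add_inst 0)
    if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
        set(add_inst 1)
    endif()
    # An instance that never mentions DTYPES is built unconditionally.
    if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
        message(STATUS "instance would be built")
    endif()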
Config.cmake.in

 @PACKAGE_INIT@

-set(_composable_kernel_supported_components device_operations utility)
+set(_composable_kernel_supported_components device_operations utility jit_library)

 foreach(_comp ${composable_kernel_FIND_COMPONENTS})
   if(NOT _comp IN_LIST _composable_kernel_supported_components)
     set(composable_kernel_FOUND False)
     set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
   endif()
-  include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
+  if(EXISTS "${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
+    include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
+  else()
+    set(composable_kernel_FOUND False)
+    set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
+  endif()
 endforeach()
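With jit_library registered as a supported component, a downstream project can request it through the usual component syntax. A hypothetical consumer, using only the variables the Config file above actually sets:

    # Hypothetical downstream usage; after the change, an unknown component or a
    # component whose Targets file was not installed sets composable_kernel_FOUND
    # to False instead of failing inside include().
    find_package(composable_kernel COMPONENTS device_operations utility jit_library)
    if(NOT composable_kernel_FOUND)
        message(FATAL_ERROR "${composable_kernel_NOT_FOUND_MESSAGE}")
    endif()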
cmake/Embed.cmake (new file, mode 100644)
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
find_program(EMBED_LD ld)
find_program(EMBED_OBJCOPY objcopy)

option(EMBED_USE_LD "Use ld to embed data files" OFF)

function(wrap_string)
    set(options)
    set(oneValueArgs VARIABLE AT_COLUMN)
    set(multiValueArgs)

    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})

    string(LENGTH ${${PARSE_VARIABLE}} string_length)
    math(EXPR offset "0")

    while(string_length GREATER 0)
        if(string_length GREATER ${PARSE_AT_COLUMN})
            math(EXPR length "${PARSE_AT_COLUMN}")
        else()
            math(EXPR length "${string_length}")
        endif()

        string(SUBSTRING ${${PARSE_VARIABLE}} ${offset} ${length} line)
        set(lines "${lines}\n${line}")

        math(EXPR string_length "${string_length} - ${length}")
        math(EXPR offset "${offset} + ${length}")
    endwhile()

    set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE)
endfunction()

function(generate_embed_source EMBED_NAME)
    set(options)
    set(oneValueArgs SRC HEADER RELATIVE)
    set(multiValueArgs OBJECTS SYMBOLS FILES)

    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    set(EXTERNS)
    set(INIT_KERNELS)

    list(LENGTH PARSE_SYMBOLS SYMBOLS_LEN)
    list(LENGTH PARSE_OBJECTS OBJECTS_LEN)
    if(NOT ${SYMBOLS_LEN} EQUAL ${OBJECTS_LEN})
        message(FATAL_ERROR "Symbols and objects dont match: ${SYMBOLS_LEN} != ${OBJECTS_LEN}")
    endif()
    math(EXPR LEN "${SYMBOLS_LEN} - 1")

    foreach(idx RANGE ${LEN})
        list(GET PARSE_SYMBOLS ${idx} SYMBOL)
        list(GET PARSE_OBJECTS ${idx} OBJECT)
        list(GET PARSE_FILES ${idx} FILE)

        set(START_SYMBOL "_binary_${SYMBOL}_start")
        set(LENGTH_SYMBOL "_binary_${SYMBOL}_length")
        if(EMBED_USE_LD)
            string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t _binary_${SYMBOL}_size;
const auto ${LENGTH_SYMBOL} = reinterpret_cast<size_t>(&_binary_${SYMBOL}_size);
")
        else()
            string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t ${LENGTH_SYMBOL};
")
        endif()

        if(PARSE_RELATIVE)
            file(RELATIVE_PATH BASE_NAME ${PARSE_RELATIVE} "${FILE}")
        else()
            get_filename_component(BASE_NAME "${FILE}" NAME)
        endif()

        string(APPEND INIT_KERNELS "
    { \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL} } },")
    endforeach()

    file(WRITE "${PARSE_HEADER}" "
#include <string_view>
#include <unordered_map>
#include <utility>
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}();
")

    file(WRITE "${PARSE_SRC}" "
#include <${EMBED_NAME}.hpp>
${EXTERNS}
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}()
{
    static std::unordered_map<std::string_view, std::string_view> result = {${INIT_KERNELS}};
    return result;
}
")
endfunction()

function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE)
    set(WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
    # Glob is used to compute the relative path
    file(GLOB FILES RELATIVE ${WORKING_DIRECTORY} ${FILE})
    foreach(REL_FILE ${FILES})
        string(MAKE_C_IDENTIFIER "${REL_FILE}" SYMBOL)
        get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY)
        file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}")
        if(EMBED_USE_LD)
            set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o")
        else()
            set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp")
        endif()
        set(${OUTPUT_SYMBOL} ${SYMBOL} PARENT_SCOPE)
        set(${OUTPUT_FILE} "${OUT_FILE}" PARENT_SCOPE)
        if(EMBED_USE_LD)
            add_custom_command(
                OUTPUT "${OUT_FILE}"
                COMMAND ${EMBED_LD} -r -o "${OUT_FILE}" -z noexecstack --format=binary "${REL_FILE}"
                COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUT_FILE}"
                WORKING_DIRECTORY ${WORKING_DIRECTORY}
                DEPENDS ${FILE}
                VERBATIM)
        else()
            set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${FILE})
            # reads source file contents as hex string
            file(READ ${FILE} HEX_STRING HEX)
            # wraps the hex string into multiple lines
            wrap_string(VARIABLE HEX_STRING AT_COLUMN 80)
            # adds '0x' prefix and comma suffix before and after every byte respectively
            string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1, " ARRAY_VALUES ${HEX_STRING})
            # removes trailing comma
            string(REGEX REPLACE ", $" "" ARRAY_VALUES ${ARRAY_VALUES})
            file(WRITE "${OUT_FILE}" "
#include <cstddef>
extern const char _binary_${SYMBOL}_start[] = { ${ARRAY_VALUES} };
extern const size_t _binary_${SYMBOL}_length = sizeof(_binary_${SYMBOL}_start);
")
        endif()
    endforeach()
endfunction()

function(add_embed_library EMBED_NAME)
    set(options)
    set(oneValueArgs RELATIVE)
    set(multiValueArgs)
    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed)
    file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
    set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
    set(SRC_FILE "${EMBED_DIR}/${EMBED_NAME}.cpp")
    set(HEADER_FILE "${EMBED_DIR}/include/${EMBED_NAME}.hpp")
    set(WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
    set(OUTPUT_FILES)
    set(SYMBOLS)
    message(STATUS "Embedding files")
    foreach(FILE ${PARSE_UNPARSED_ARGUMENTS})
        embed_file(OUTPUT_FILE OUTPUT_SYMBOL ${FILE})
        list(APPEND OUTPUT_FILES ${OUTPUT_FILE})
        list(APPEND SYMBOLS ${OUTPUT_SYMBOL})
    endforeach()
    message(STATUS "Generating embedding library ${EMBED_NAME}")
    generate_embed_source(${EMBED_NAME}
        SRC ${SRC_FILE}
        HEADER ${HEADER_FILE}
        OBJECTS ${OUTPUT_FILES}
        SYMBOLS ${SYMBOLS}
        RELATIVE ${PARSE_RELATIVE}
        FILES ${PARSE_UNPARSED_ARGUMENTS})
    set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME})
    add_library(${INTERNAL_EMBED_LIB} OBJECT "${SRC_FILE}")
    target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include")
    target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
    set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On)
    add_library(${EMBED_NAME} INTERFACE)
    if(EMBED_USE_LD)
        target_sources(${EMBED_NAME} INTERFACE ${OUTPUT_FILES})
    else()
        target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES})
    endif()
    target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
    target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
endfunction()
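Taken together, add_embed_library() turns a set of data files into a linkable target that exposes a name-to-contents map. A minimal hypothetical use (the target and file names are invented for illustration):

    # Embed all CK headers below include/ into a target named ck_headers.
    # C++ code that links ck_headers can then call
    #   std::unordered_map<std::string_view, std::string_view> ck_headers();
    # to look up each header's text by its include/-relative path, e.g. for
    # handing the sources to a runtime (hipRTC-style) compiler.
    file(GLOB_RECURSE KERNEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/ck/*.hpp)
    add_embed_library(ck_headers ${KERNEL_HEADERS} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include)
    target_link_libraries(my_jit_lib PRIVATE ck_headers)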
include/ck/ck.hpp

@@ -4,11 +4,12 @@
 #pragma once

 #include "ck/config.h"

+#ifndef __HIPCC_RTC__
 #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
 #endif
+#endif

 #define CK_TIME_KERNEL 1
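The new guard means the HIP runtime headers drop out automatically under hipRTC, which predefines __HIPCC_RTC__. For an offline build that wants the same effect, the existing CK_DONT_USE_HIP_RUNTIME_HEADERS macro can be set from CMake; a hypothetical sketch (target and source names assumed):

    # Hypothetical: strip hip_runtime.h/hip_fp16.h out of ck.hpp for a
    # translation unit that supplies its own HIP type definitions,
    # e.g. sources bundled for JIT compilation.
    add_library(ck_rtc_sources OBJECT jit_kernel_sources.cpp)
    target_compile_definitions(ck_rtc_sources PRIVATE CK_DONT_USE_HIP_RUNTIME_HEADERS)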
include/ck/host_utility/device_prop.hpp

@@ -3,6 +3,7 @@
 #pragma once

+#ifndef __HIPCC_RTC__
 #include <string>
 #include <map>

 #include <hip/hip_runtime.h>
@@ -59,3 +60,4 @@ inline bool is_xdl_supported()
 }

 } // namespace ck
+#endif
include/ck/host_utility/kernel_launch.hpp

@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <hip/hip_runtime.h>

 #include "ck/ck.hpp"
@@ -150,3 +150,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
     return 0;
 #endif
 }
+#endif
include/ck/tensor_operation/gpu/device/device_base.hpp

@@ -2,16 +2,17 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <string>
 #include <sstream>

 #include "ck/stream_config.hpp"
+#endif

 namespace ck {
 namespace tensor_operation {
 namespace device {

+#ifndef __HIPCC_RTC__
 struct BaseArgument
 {
     BaseArgument() = default;
@@ -36,6 +37,7 @@ struct BaseInvoker
     virtual ~BaseInvoker() {}
 };
+#endif

 struct BaseOperator
 {
@@ -43,7 +45,9 @@ struct BaseOperator
     BaseOperator(const BaseOperator&) = default;
     BaseOperator& operator=(const BaseOperator&) = default;

+#ifndef __HIPCC_RTC__
     virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
     virtual std::string GetTypeString() const { return ""; }
     virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
@@ -56,7 +60,6 @@ struct BaseOperator
         return oss.str();
     };

     virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }

     virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
@@ -64,7 +67,7 @@ struct BaseOperator
         assert(p_arg);
         p_arg->p_workspace_ = p_workspace;
     }
+#endif

     virtual ~BaseOperator() {}
 };
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp

@@ -2,9 +2,10 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <vector>
+#endif

 #include "device_base.hpp"
@@ -28,6 +29,7 @@ template <typename ALayout,
           bool MaskOutUpperTriangle> // TODO: enum for mask type
 struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
 {
+#ifndef __HIPCC_RTC__
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b0,
@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
                         CElementwiseOperation c_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+#endif
 };

 } // namespace device
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp

@@ -2,9 +2,11 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef __HIPCC_RTC__
 #include <array>
+#endif

+#include "ck/utility/array.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"

 namespace ck {
@@ -34,23 +36,24 @@ struct DeviceGemmMultipleD : public BaseOperator
 {
     static constexpr index_t NumDTensor = DsDataType::Size();

+#ifndef __HIPCC_RTC__
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const void* p_a,
                         const void* p_b,
-                        std::array<const void*, NumDTensor> p_ds,
+                        ck::Array<const void*, NumDTensor> p_ds,
                         void* p_e,
                         ck::index_t M,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t StrideA,
                         ck::index_t StrideB,
-                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        ck::Array<ck::index_t, NumDTensor> StrideDs,
                         ck::index_t StrideE,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CDEElementwiseOperation cde_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+#endif
 };

 } // namespace device
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp

@@ -28,7 +28,7 @@ enum struct GemmSpecialization
     NKOPadding,
     MNKOPadding,
 };

+#ifndef __HIPCC_RTC__
 inline std::string getGemmSpecializationString(const GemmSpecialization& s)
 {
     switch(s)
@@ -52,6 +52,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
         default: return "Unrecognized specialization!";
     }
 }
+#endif

 } // namespace device
 } // namespace tensor_operation
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp

@@ -3,8 +3,12 @@
 #pragma once

+#ifndef __HIPCC_RTC__
 #include <iostream>
 #include <sstream>
+
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#endif

 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
@@ -15,8 +19,6 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -126,7 +128,6 @@ __global__ void
 //   else
 //     AccElement = -INFINITY
 // Otherwise, result may be wrong.
-
 template <typename ALayout,
           typename BLayout, // B0Layout
           typename B1Layout,
@@ -430,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                             matrix_padder.PadN,
                                             MaskOutUpperTriangle>;

+#ifndef __HIPCC_RTC__
     // Argument
     struct Argument : public BaseArgument
     {
@@ -604,13 +606,103 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
         }
     };
+#endif

     static constexpr bool IsValidCompilationParameter()
     {
         // TODO: properly implement this check
         return true;
     }

+    static constexpr bool
+    IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
+    {
+        // check vector load/store
+        using Row = ck::tensor_layout::gemm::RowMajor;
+        using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+        // check vector load of A
+        if constexpr(is_same_v<ALayout, Row>)
+        {
+            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<ALayout, Col>)
+        {
+            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of B
+        if constexpr(is_same_v<BLayout, Row>)
+        {
+            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<BLayout, Col>)
+        {
+            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of B1
+        if constexpr(is_same_v<B1Layout, Row>)
+        {
+            if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<B1Layout, Col>)
+        {
+            if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of C
+        if constexpr(is_same_v<CLayout, Row>)
+        {
+            if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<CLayout, Col>)
+        {
+            if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        return true;
+    }
+
+#ifndef __HIPCC_RTC__
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!ck::is_xdl_supported())
@@ -625,29 +717,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
         const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];

-        // Check scalar per vector requirement
-        const auto a_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-        const auto b_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-        const auto b1_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-        const auto c_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
-
-        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-        {
-            return false;
-        }
-
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+                                           arg.block_2_ctile_map_) and
+               IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
     }

     // polymorphic
@@ -685,7 +760,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                         BatchStrideB1,
                         BatchStrideC,
                         a_element_op,
                         b_element_op,
                         acc_element_op,
                         b1_element_op,
                         c_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
@@ -765,6 +839,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return str.str();
     }
+#endif

+    template <class ADesc, class BDesc, class B1Desc, class CDesc>
+    struct Descriptor
+    {
+        template <class AGridDescriptor>
+        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
+        {
+            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
+            const auto M   = a_grid_desc_m_k.GetLength(I0);
+            const auto K   = a_grid_desc_m_k.GetLength(I1);
+            const auto AK0 = K / AK1;
+
+            return transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                           make_pass_through_transform(M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+
+        template <class BGridDescriptor>
+        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
+        {
+            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
+            const auto N   = b_grid_desc_n_k.GetLength(I0);
+            const auto K   = b_grid_desc_n_k.GetLength(I1);
+            const auto BK0 = K / BK1;
+
+            return transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+
+        template <class B1GridDescriptor>
+        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
+        {
+            const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
+            const auto N    = b1_grid_desc_n_k.GetLength(I0);
+            const auto K    = b1_grid_desc_n_k.GetLength(I1);
+            const auto B1K0 = K / B1K1;
+
+            return transform_tensor_descriptor(
+                b1_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+
+        template <class CGridDescriptor>
+        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
+        {
+            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
+        }
+
+        using AGridDesc_AK0_M_AK1  = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
+        using BGridDesc_BK0_N_BK1  = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
+        using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
+        using CGridDesc_M_N        = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
+
+        // GridwiseGemm
+        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
+            ADataType, // TODO: distinguish A/B datatype
+            GemmAccDataType,
+            CShuffleDataType,
+            CDataType,
+            AElementwiseOperation,
+            BElementwiseOperation,
+            AccElementwiseOperation,
+            B1ElementwiseOperation,
+            CElementwiseOperation,
+            InMemoryDataOperationEnum::Set,
+            AGridDesc_AK0_M_AK1,
+            BGridDesc_BK0_N_BK1,
+            B1GridDesc_BK0_N_BK1,
+            CGridDesc_M_N,
+            NumGemmKPrefetchStage,
+            BlockSize,
+            MPerBlock,
+            NPerBlock,
+            KPerBlock,
+            Gemm1NPerBlock,
+            Gemm1KPerBlock,
+            AK1,
+            BK1,
+            B1K1,
+            MPerXDL,
+            NPerXDL,
+            MXdlPerWave,
+            NXdlPerWave,
+            Gemm1NXdlPerWave,
+            ABlockTransferThreadClusterLengths_AK0_M_AK1,
+            ABlockTransferThreadClusterArrangeOrder,
+            ABlockTransferSrcAccessOrder,
+            ABlockTransferSrcVectorDim,
+            ABlockTransferSrcScalarPerVector,
+            ABlockTransferDstScalarPerVector_AK1,
+            true,
+            ABlockLdsExtraM,
+            BBlockTransferThreadClusterLengths_BK0_N_BK1,
+            BBlockTransferThreadClusterArrangeOrder,
+            BBlockTransferSrcAccessOrder,
+            BBlockTransferSrcVectorDim,
+            BBlockTransferSrcScalarPerVector,
+            BBlockTransferDstScalarPerVector_BK1,
+            true,
+            BBlockLdsExtraN,
+            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+            B1BlockTransferThreadClusterArrangeOrder,
+            B1BlockTransferSrcAccessOrder,
+            B1BlockTransferSrcVectorDim,
+            B1BlockTransferSrcScalarPerVector,
+            B1BlockTransferDstScalarPerVector_BK1,
+            false,
+            B1BlockLdsExtraN,
+            CShuffleMXdlPerWavePerShuffle,
+            CShuffleNXdlPerWavePerShuffle,
+            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            LoopSched,
+            matrix_padder.PadN,
+            MaskOutUpperTriangle>;
+
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
+        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
+        CGridDesc_M_N c_grid_desc_m_n;
+        C0MatrixMask c0_matrix_mask;
+        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
+        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_descriptor_mblock_mperblock_nblock_nperblock;
+
+        // element-wise op
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        B1ElementwiseOperation b1_element_op;
+        CElementwiseOperation c_element_op;
+
+        bool has_main_k_block_loop = true;
+        bool is_valid              = false;
+
+        constexpr Descriptor(ADesc a,
+                             BDesc b,
+                             B1Desc b1,
+                             CDesc c,
+                             AElementwiseOperation a_element_op_,
+                             BElementwiseOperation b_element_op_,
+                             B1ElementwiseOperation b1_element_op_,
+                             CElementwiseOperation c_element_op_)
+            : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
+              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
+              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
+              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
+              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
+              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      c_grid_desc_m_n)},
+              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
+                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
+              c0_matrix_mask{c.GetLength(I1)},
+              a_element_op{a_element_op_},
+              b_element_op{b_element_op_},
+              b1_element_op{b1_element_op_},
+              c_element_op{c_element_op_},
+              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
+                                                   b_grid_desc_bk0_n_bk1,
+                                                   b1_grid_desc_bk0_n_bk1,
+                                                   c_grid_desc_m_n,
+                                                   block_2_ctile_map) and
+                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
+                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
+                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
+                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
+        {
+        }
+
+        constexpr bool IsValid() const { return is_valid; }
+    };
+
+    template <class ADesc, class BDesc, class B1Desc, class CDesc>
+    static constexpr auto
+    make_descriptor(ADesc a,
+                    BDesc b,
+                    B1Desc b1,
+                    CDesc c,
+                    AElementwiseOperation a_element_op   = AElementwiseOperation{},
+                    BElementwiseOperation b_element_op   = BElementwiseOperation{},
+                    B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
+                    CElementwiseOperation c_element_op   = CElementwiseOperation{})
+    {
+        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
+            a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
+    }
+
+    template <class Desc>
+    __device__ static void Run(const Desc& desc,
+                               const float scale,
+                               const ADataType* __restrict__ p_a_grid,
+                               const ADataType* __restrict__ p_b_grid,
+                               const ADataType* __restrict__ p_b1_grid,
+                               CDataType* __restrict__ p_c_grid)
+    {
+#ifndef __HIPCC_RTC__
+        assert(desc.is_valid);
+#endif
+        __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+        AccElementwiseOperation acc_element_op{scale};
+
+        if(desc.has_main_k_block_loop)
+        {
+            Desc::GridwiseGemm::template Run<true>(
+                p_a_grid,
+                p_b_grid,
+                p_b1_grid,
+                p_c_grid,
+                p_shared_block,
+                desc.a_element_op,
+                desc.b_element_op,
+                acc_element_op,
+                desc.b1_element_op,
+                desc.c_element_op,
+                desc.a_grid_desc_ak0_m_ak1,
+                desc.b_grid_desc_bk0_n_bk1,
+                desc.b1_grid_desc_bk0_n_bk1,
+                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
+                desc.block_2_ctile_map,
+                desc.c0_matrix_mask);
+        }
+        else
+        {
+            Desc::GridwiseGemm::template Run<false>(
+                p_a_grid,
+                p_b_grid,
+                p_b1_grid,
+                p_c_grid,
+                p_shared_block,
+                desc.a_element_op,
+                desc.b_element_op,
+                acc_element_op,
+                desc.b1_element_op,
+                desc.c_element_op,
+                desc.a_grid_desc_ak0_m_ak1,
+                desc.b_grid_desc_bk0_n_bk1,
+                desc.b1_grid_desc_bk0_n_bk1,
+                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
+                desc.block_2_ctile_map,
+                desc.c0_matrix_mask);
+        }
+    }
 };

 } // namespace device
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
546a764e
...
@@ -2,20 +2,22 @@
...
@@ -2,20 +2,22 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/
common_header
.hpp"
#include "ck/utility/
array
.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -225,9 +227,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -225,9 +227,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
MRaws
,
static
auto
MakeDsGridDescriptor_M_N
(
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
MRaws
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
NRaws
,
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
NRaws
,
const
std
::
a
rray
<
index_t
,
NumDTensor
>&
DsStride
)
const
ck
::
A
rray
<
index_t
,
NumDTensor
>&
DsStride
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -308,20 +310,20 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -308,20 +310,20 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
// block-to-e-tile map
// block-to-e-tile map
using
Block2ETileMap
=
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
#ifndef __HIPCC_RTC__
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
void
*
p_a_grid
,
Argument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
const
void
*
p_b_grid
,
std
::
a
rray
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
ck
::
A
rray
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
void
*
p_e_grid
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
std
::
a
rray
<
index_t
,
NumDTensor
>
StrideDs
,
ck
::
A
rray
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
...
@@ -420,7 +422,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -420,7 +422,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
index_t
NRaw_
;
index_t
NRaw_
;
index_t
KRaw_
;
index_t
KRaw_
;
};
};
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
...
@@ -497,95 +498,100 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -497,95 +498,100 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
}
};
};
+#endif

-    static bool IsSupportedArgument(const Argument& arg)
+    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
    {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
+
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        // check vector load of A
        if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
        {
-            if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
        {
            // FIXME: not rigorous
-            if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B
        if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
        {
-            if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
        {
            // FIXME: not rigorous
-            if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of Ds
        // only support RowMajor for now
        bool all_valid = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

            if constexpr(!is_same_v<DLayout, Row>)
            {
                all_valid = false;
            }
        });

        if(!all_valid)
        {
            return false;
        }

        // check vector store of E
        // only support RowMajor for now
        if constexpr(is_same_v<ELayout, Row>)
        {
-            if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        return true;
    }
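Because the shape checks above no longer read from an Argument object, they can run in a constant expression. A minimal sketch of that compile-time use, assuming DeviceOp names a fully instantiated DeviceGemmMultipleD_Xdl_CShuffle specialization and the problem sizes are known when the kernel source is generated (none of these names come from the header itself):

// Sketch only: DeviceOp and the sizes below are assumptions for illustration.
constexpr ck::index_t M = 1024;
constexpr ck::index_t N = 1024;
constexpr ck::index_t K = 64;

// Rejected at compile time if a vector-access constraint is violated,
// instead of failing at runtime inside IsSupportedArgument().
static_assert(DeviceOp::IsSupported(M, N, K),
              "tile/vector configuration does not support this (M, N, K)");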
+#ifndef __HIPCC_RTC__
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
             ck::get_device_name() == "gfx942"))
        {
            return false;
        }

-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+        return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
+               GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
...
@@ -597,17 +603,16 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
    static auto MakeArgument(const void* p_a,
                             const void* p_b,
-                            std::array<const void*, NumDTensor> p_ds,
+                            ck::Array<const void*, NumDTensor> p_ds,
                             void* p_e,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t StrideA,
                             index_t StrideB,
-                            std::array<index_t, NumDTensor> StrideDs,
+                            ck::Array<index_t, NumDTensor> StrideDs,
                             index_t StrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
...
@@ -635,14 +640,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
-                       std::array<const void*, NumDTensor> p_ds,
+                       ck::Array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        index_t MRaw,
                        index_t NRaw,
                        index_t KRaw,
                        index_t StrideA,
                        index_t StrideB,
-                       std::array<ck::index_t, NumDTensor> StrideDs,
+                       ck::Array<ck::index_t, NumDTensor> StrideDs,
                        index_t StrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
...
@@ -675,11 +680,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceGemmMultipleD_Xdl_CShuffle"
...
@@ -708,6 +715,149 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
        return str.str();
    }
+#endif
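On the host side the factory functions keep their shape; only the container type changes. A hedged sketch of a call site after this change, where the device-op instance, pointers, sizes and strides are all placeholders (NumDTensor == 1 assumed):

// Sketch only: device_op, p_a, p_b, p_d0, p_e and the sizes/strides are
// hypothetical. The point illustrated is the std::array -> ck::Array switch.
auto argument_ptr = device_op.MakeArgumentPointer(p_a,
                                                  p_b,
                                                  ck::Array<const void*, 1>{p_d0},
                                                  p_e,
                                                  M, N, K,
                                                  StrideA,
                                                  StrideB,
                                                  ck::Array<ck::index_t, 1>{StrideD0},
                                                  StrideE,
                                                  AElementwiseOperation{},
                                                  BElementwiseOperation{},
                                                  CDEElementwiseOperation{});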
+    template <class ADesc, class BDesc, class DsDesc, class EDesc>
+    struct Descriptor
+    {
+        static constexpr auto ds_tuple()
+        {
+            return transform_tuples(
+                [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
+                DsDesc{});
+        }
+
+        using AGridDesc_AK0_M_AK1 =
+            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+                DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
+        using BGridDesc_BK0_N_BK1 =
+            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+                DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+            decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                ds_tuple()))>;
+        using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+            decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
+        using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
+            DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
+
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
+        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
+        Block2ETileMap block_2_etile_map;
+
+        // element-wise op
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        CDEElementwiseOperation cde_element_op;
+
+        bool has_main_k_block_loop = true;
+        bool is_valid              = false;
+
+        constexpr Descriptor(ADesc a,
+                             BDesc b,
+                             DsDesc ds,
+                             EDesc e,
+                             AElementwiseOperation a_element_op_,
+                             BElementwiseOperation b_element_op_,
+                             CDEElementwiseOperation cde_element_op_)
+            : a_grid_desc_ak0_m_ak1{GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+                  DeviceOp::matrix_padder.PadADescriptor_M_K(a))},
+              b_grid_desc_bk0_n_bk1{GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+                  DeviceOp::matrix_padder.PadBDescriptor_N_K(b))},
+              ds_grid_desc_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      transform_tuples(
+                          [&](auto d) constexpr {
+                              return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
+                          },
+                          ds))},
+              e_grid_desc_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      DeviceOp::matrix_padder.PadCDescriptor_M_N(e))},
+              block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(
+                  DeviceOp::matrix_padder.PadCDescriptor_M_N(e))},
+              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
+                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
+              a_element_op{a_element_op_},
+              b_element_op{b_element_op_},
+              cde_element_op{cde_element_op_},
+              is_valid{GridwiseGemm::CheckValidity(
+                           (DeviceOp::matrix_padder.PadADescriptor_M_K(a)),
+                           DeviceOp::matrix_padder.PadBDescriptor_N_K(b),
+                           transform_tuples(
+                               [&](auto d) constexpr {
+                                   return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
+                               },
+                               ds),
+                           DeviceOp::matrix_padder.PadCDescriptor_M_N(e),
+                           block_2_etile_map) and
+                       IsSupported(e.GetLength(I0), e.GetLength(I1), a.GetLength(I1))}
+        {
+        }
+
+        constexpr bool IsValid() const { return is_valid; }
+    };
+
+    template <class ADesc, class BDesc, class DsDesc, class EDesc>
+    static constexpr auto make_descriptor(ADesc a,
+                                          BDesc b,
+                                          DsDesc ds,
+                                          EDesc e,
+                                          AElementwiseOperation a_element_op     = AElementwiseOperation{},
+                                          BElementwiseOperation b_element_op     = BElementwiseOperation{},
+                                          CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
+    {
+        return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
+            a, b, ds, e, a_element_op, b_element_op, cde_element_op);
+    }
+    template <class Desc, class DsPointer>
+    __device__ static void Run(const Desc& desc,
+                               const ADataType* __restrict__ p_a_grid,
+                               const BDataType* __restrict__ p_b_grid,
+                               DsPointer p_ds_grid,
+                               EDataType* __restrict__ p_e_grid)
+    {
+        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+#ifndef __HIPCC_RTC__
+        assert(desc.is_valid);
+#endif
+        if(desc.has_main_k_block_loop)
+        {
+            GridwiseGemm::template Run<true>(p_a_grid,
+                                             p_b_grid,
+                                             p_ds_grid,
+                                             p_e_grid,
+                                             p_shared_block,
+                                             desc.a_element_op,
+                                             desc.b_element_op,
+                                             desc.cde_element_op,
+                                             desc.a_grid_desc_ak0_m_ak1,
+                                             desc.b_grid_desc_bk0_n_bk1,
+                                             desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                             desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                             desc.block_2_etile_map);
+        }
+        else
+        {
+            GridwiseGemm::template Run<false>(p_a_grid,
+                                              p_b_grid,
+                                              p_ds_grid,
+                                              p_e_grid,
+                                              p_shared_block,
+                                              desc.a_element_op,
+                                              desc.b_element_op,
+                                              desc.cde_element_op,
+                                              desc.a_grid_desc_ak0_m_ak1,
+                                              desc.b_grid_desc_bk0_n_bk1,
+                                              desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                              desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                              desc.block_2_etile_map);
+        }
+    }
};
} // namespace device
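Taken together, make_descriptor and the static Run give a runtime-compilation client a complete device-side entry point: the padded grid descriptors, tile map, element-wise ops and validity flag all fold into one constexpr object. A hedged sketch of a generated kernel built on top of it, where every name outside the header (DeviceOp, a_desc, b_desc, ds_descs, e_desc, DsGridPointer) is an assumption of the code generator, not something this commit defines:

// Sketch only: DeviceOp is an instantiated DeviceGemmMultipleD_Xdl_CShuffle;
// a_desc/b_desc/ds_descs/e_desc are compile-time tensor descriptors the
// generator already holds; DsGridPointer stands in for the tuple of D
// pointers the gridwise GEMM expects.
template <class DsGridPointer>
__global__ void ck_gemm_multiple_d_kernel(const ADataType* __restrict__ p_a,
                                          const BDataType* __restrict__ p_b,
                                          DsGridPointer p_ds,
                                          EDataType* __restrict__ p_e)
{
    // built entirely at compile time; is_valid was checked when this source
    // was generated
    constexpr auto desc = DeviceOp::make_descriptor(a_desc, b_desc, ds_descs, e_desc);
    DeviceOp::Run(desc, p_a, p_b, p_ds, p_e);
}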
...
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
View file @ 546a764e
...
@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
    MaskOutUpperTriangle
};

+#ifndef __HIPCC_RTC__
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
    switch(s)
...
@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
        default: return "Unrecognized specialization!";
    }
}
+#endif

struct MaskDisabledPredicate
{
...
@@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
-    C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
+    constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}

    __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
    {
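With the constructor constexpr, a mask object can be materialized as a compile-time constant inside a generated kernel, while the std::string helper, which needs the host standard library, stays behind the guard. A minimal sketch; the namespace qualification and the sequence length N are assumptions, not taken from the diff:

// Sketch only: builds a mask in a constant expression. Namespace assumed
// from this header's location; N is a hypothetical sequence length.
constexpr ck::index_t N = 256;
constexpr ck::tensor_operation::device::C0MatrixMask_impl<
    ck::tensor_operation::device::MaskDisabledPredicate>
    mask{N};
// usable from host and device code alike, e.g. mask.IsNOutOfBound(n)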
...
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @ 546a764e
...
@@ -406,7 +406,7 @@ struct G_NDHW : public BaseTensorLayout
template <
    typename Layout,
-    typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
+    typename ck::enable_if<ck::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
std::ostream& operator<<(std::ostream& os, const Layout&)
{
    os << Layout::name;
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @ 546a764e
...
@@ -354,6 +354,7 @@ struct FastGelu
    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const;

+#ifndef __HIPCC_RTC__
    template <>
    __host__ void operator()<float, float>(float& y, const float& x) const
    {
...
@@ -363,7 +364,7 @@ struct FastGelu
        y = x * cdf;
    }
+#endif

    // device code, use lower precision "__expf" and "rcp"
    template <>
    __device__ void operator()<float, float>(float& y, const float& x) const
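The host float specialization, which relies on standard-library math, is now compiled only by the host compiler, while the __expf/rcp device path remains available to kernels; both are reached through the same functor, which keeps host-side verification simple. A hedged usage sketch (namespace assumed from this header's location):

// Sketch only: the same functor picks the __host__ or __device__ float
// specialization depending on where operator() is invoked.
ck::tensor_operation::element_wise::FastGelu fast_gelu{};

float y_ref = 0.f;
fast_gelu(y_ref, 1.5f); // host path: reference-precision GELU approximation

// in device code the identical call site resolves to the __expf + rcp path:
//     fast_gelu(y, x);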
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @ 546a764e
...
@@ -5,10 +5,13 @@
#include "ck/utility/math.hpp"
#include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

+#ifndef __HIPCC_RTC__
#include <limits>
#include <stdlib.h>
+#endif

namespace ck {
...
@@ -86,16 +89,16 @@ struct BlockToCTileMap_M00_N0_M01
        const auto M00 = math::integer_divide_ceil(M0, M01);

        const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_insert_transform(1),
-                       make_unmerge_transform(make_tuple(M00, M01)),
-                       make_pass_through_transform(make_tuple(N0))),
-            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
+            ck::make_tuple(make_insert_transform(1),
+                           make_unmerge_transform(ck::make_tuple(M00, M01)),
+                           make_pass_through_transform(ck::make_tuple(N0))),
+            ck::make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
+            ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));

        const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
-            make_tuple(Sequence<0, 1, 2, 3>{}),
-            make_tuple(Sequence<0>{}));
+            ck::make_tuple(make_merge_transform(ck::make_tuple(1, M00, N0, M01))),
+            ck::make_tuple(Sequence<0, 1, 2, 3>{}),
+            ck::make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
...
@@ -120,31 +123,33 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
    }

    template <typename CGridDesc_M_N>
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                        index_t M01 = 8)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                                  index_t M01 = 8)
        : BlockToCTileMap_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
    {
    }

-    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
...
@@ -153,13 +158,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
    }

    template <typename CGridDesc_M_N>
-    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+    __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    __host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }
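Marking CalculateGridSize __device__ as well as __host__ lets the same arithmetic run during kernel-side setup, and because it is constexpr it also folds at compile time. A worked sketch with hypothetical tile sizes, assuming (as the elided lines suggest) that the function returns M0 * N0:

// Sketch only: 256x128 tiles are an arbitrary choice for illustration.
// For M = 1000, N = 500: M0 = ceil(1000/256) = 4, N0 = ceil(500/128) = 4,
// so the 1D grid has 4 * 4 = 16 workgroups.
using TileMap = ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, void>;
constexpr ck::index_t grid_size = TileMap::CalculateGridSize(1000, 500);
static_assert(grid_size == 16, "4 M-tiles x 4 N-tiles");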
...
@@ -227,13 +234,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
         * output {1, 2}
         */
-        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                          idx_N0_M01_local / M01_adapt);
+        return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                              idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                             const CTileDim& /* c_tile_dim */) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }
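The index math is easiest to see with concrete numbers. The hand trace below reproduces the {1, 2} output mentioned in the source comment, assuming the elided preceding lines compute idx_N0 = id % N0, idx_M0 = id / N0, idx_M00 = idx_M0 / M01_ and idx_M01 = idx_M0 % M01_ as in the upstream file:

// Worked trace (sketch): M0 = 4, N0 = 3, M01_ = 2, block_1d_id = 5.
//   idx_N0  = 5 % 3 = 2          idx_M0  = 5 / 3 = 1
//   idx_M00 = 1 / 2 = 0          idx_M01 = 1 % 2 = 1
//   M01_adapt = min(2, 4 - 0 * 2) = 2
//   idx_N0_M01_local = 2 + 1 * 3 = 5
//   output = {5 % 2 + 0 * 2, 5 / 2} = {1, 2}
// Blocks 0..5 therefore cover a 2 x 3 group of C tiles, M-fastest:
//   {0,0}, {1,0}, {0,1}, {1,1}, {0,2}, {1,2}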
...
@@ -303,9 +310,9 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

-        return make_tuple(idx_ksplit,
-                          idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                          idx_N0_M01_local / M01_adapt);
+        return ck::make_tuple(idx_ksplit,
+                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                              idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
...
@@ -402,17 +409,17 @@ struct BlockToCTileMap_M00_N00_M01_N01
        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
-                make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
-                           make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
+                ck::make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
+                               make_unmerge_transform(ck::make_tuple(M00, M01)),
+                               make_unmerge_transform(ck::make_tuple(N00, N01))),
+                ck::make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
+                ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                make_tuple(Sequence<0>{}));
+                ck::make_tuple(make_merge_transform(ck::make_tuple(1, M00, N00, M01, N01))),
+                ck::make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                ck::make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
...
@@ -521,17 +528,17 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
        const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
-                make_tuple(make_pass_through_transform(KSplit),
-                           make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
+                ck::make_tuple(make_pass_through_transform(KSplit),
+                               make_unmerge_transform(ck::make_tuple(M00, M01)),
+                               make_unmerge_transform(ck::make_tuple(N00, N01))),
+                ck::make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                ck::make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                make_tuple(Sequence<0>{}));
+                ck::make_tuple(make_merge_transform(ck::make_tuple(KSplit, M00, N00, M01, N01))),
+                ck::make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                ck::make_tuple(Sequence<0>{}));

        const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
@@ -649,13 +656,13 @@ struct BlockToCTileMap_3DGrid_KSplit
...
@@ -649,13 +656,13 @@ struct BlockToCTileMap_3DGrid_KSplit
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
return
std
::
make_tuple
(
N0
,
M0
,
k_split
);
return
ck
::
make_tuple
(
N0
,
M0
,
k_split
);
}
}
template
<
typename
TopIdx
>
template
<
typename
TopIdx
>
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
)
const
__device__
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
)
const
{
{
return
make_tuple
(
blockIdx
.
z
,
blockIdx
.
y
,
blockIdx
.
x
);
return
ck
::
make_tuple
(
blockIdx
.
z
,
blockIdx
.
y
,
blockIdx
.
x
);
}
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
template
<
typename
CTileIdx
,
typename
CTileDim
>
...
@@ -773,7 +780,7 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -773,7 +780,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t
dp_for_sk_iters
=
k_iters_per_tile
.
get
();
uint32_t
dp_for_sk_iters
=
k_iters_per_tile
.
get
();
uint32_t
best_sk_score
=
uint32_t
best_sk_score
=
std
::
n
umeric
_l
imits
<
int
>::
m
ax
();
// we need to find the smallest sk iters
ck
::
N
umeric
L
imits
<
int
32_t
>::
M
ax
();
// we need to find the smallest sk iters
for
(
uint32_t
tentative_sk_blocks
=
min_sk_tiles
;
tentative_sk_blocks
<
max_sk_tiles
;
for
(
uint32_t
tentative_sk_blocks
=
min_sk_tiles
;
tentative_sk_blocks
<
max_sk_tiles
;
tentative_sk_blocks
++
)
tentative_sk_blocks
++
)
{
{
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @ 546a764e
...
@@ -3,10 +3,11 @@
#pragma once

-#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"

+#ifndef __HIPCC_RTC__
+#include <iostream>
+#endif

namespace ck {
...
@@ -38,7 +39,9 @@ constexpr auto GridwiseGemmPipeline_Selector()
    }
    else
    {
+#ifndef __HIPCC_RTC__
        std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
+#endif
    }
}
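This guarded-iostream change is the template for the whole commit: hipRTC compiles translation units without the host standard library, so host-only headers and diagnostics are fenced off while the device-usable core stays common to both compilers. The idiom, sketched for a downstream header (the function name is hypothetical, not code from this repository):

// Sketch only: the pattern, not repository code.
#ifndef __HIPCC_RTC__
#include <iostream> // host-only dependency
#endif

inline void report_unavailable_pipeline()
{
#ifndef __HIPCC_RTC__
    std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
#endif
    // under hipRTC this body intentionally compiles to nothing
}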
...
include/ck/utility/amd_wave_read_first_lane.hpp
View file @ 546a764e
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"

+#ifndef __HIPCC_RTC__
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
+#endif

namespace ck {
namespace detail {
...
@@ -37,7 +39,7 @@ struct get_carrier<3>
{
    using value_type = uint32_t;

-    std::array<std::byte, 3> bytes;
+    ck::byte bytes[3];

    static_assert(sizeof(bytes) <= sizeof(value_type));

    // replacement of host std::copy_n()
...
@@ -59,24 +61,21 @@ struct get_carrier<3>
    }

    // method to trigger template substitution failure
    __device__ carrier(const carrier& other) noexcept
    {
-        copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
+        copy_n(&other.bytes[0], 3, &bytes[0]);
    }

    public:
    __device__ carrier& operator=(value_type value) noexcept
    {
-        copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
+        copy_n(reinterpret_cast<const ck::byte*>(&value), 3, &bytes[0]);
        return *this;
    }

    __device__ operator value_type() const noexcept
    {
-        std::byte result[sizeof(value_type)];
-        copy_n(bytes.begin(), bytes.size(), result);
+        ck::byte result[sizeof(value_type)];
+        copy_n(&bytes[0], 3, result);
        return *reinterpret_cast<const value_type*>(result);
    }
...
@@ -100,17 +99,17 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value)
    return __builtin_amdgcn_readfirstlane(value);
}

template <typename Object,
-          typename = std::enable_if_t<std::is_class_v<Object> &&
-                                      std::is_trivially_copyable_v<Object>>>
+          typename = ck::enable_if_t<ck::is_class<Object>::value &&
+                                     ck::is_trivially_copyable<Object>::value>>
__device__ auto amd_wave_read_first_lane(const Object& obj)
{
    using Size = unsigned;

    constexpr Size SgprSize   = 4;
    constexpr Size ObjectSize = sizeof(Object);

-    auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
-    alignas(Object) std::byte to_obj[ObjectSize];
+    auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
+    alignas(Object) ck::byte to_obj[ObjectSize];

    constexpr Size RemainedSize             = ObjectSize % SgprSize;
    constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
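The overload rewritten above broadcasts any trivially copyable class object from the first active lane, one 4-byte SGPR chunk at a time, with the get_carrier machinery handling the sub-word tail; the ck:: traits keep the SFINAE usable under hipRTC. A usage sketch with a hypothetical 6-byte payload:

// Sketch only: TileCoord is hypothetical. sizeof(TileCoord) == 6, so the
// broadcast moves one full 4-byte SGPR chunk plus a 2-byte remainder.
struct TileCoord
{
    int16_t row;
    int16_t col;
    int16_t layer;
};

__device__ void broadcast_example(TileCoord divergent)
{
    // afterwards every lane holds lane 0's value, so the compiler may keep
    // the result in scalar registers
    const TileCoord uniform = ck::amd_wave_read_first_lane(divergent);
    (void)uniform;
}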
...
include/ck/utility/array.hpp
View file @ 546a764e
...
@@ -52,7 +52,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
    using data_type = remove_cvref_t<X>;
-    return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
+    return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
}

// make empty array
...
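With ck::forward in place of std::forward, make_array is fully self-contained for hipRTC. A usage sketch (the Size() accessor is assumed from ck::Array's interface):

// Sketch only: deduces ck::Array<int, 3> from the first argument.
constexpr auto offsets = ck::make_array(1, 2, 3);
static_assert(offsets.Size() == 3, "three elements");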