Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
74284dd2
Unverified
Commit
74284dd2
authored
Jul 25, 2023
by
rocking
Committed by
GitHub
Jul 25, 2023
Browse files
Merge branch 'develop' into avgpool_bwd
parents
01aeb901
50643dd5
Changes
72
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1455 additions
and
727 deletions
+1455
-727
CMakeLists.txt
CMakeLists.txt
+89
-14
Dockerfile
Dockerfile
+5
-0
Jenkinsfile
Jenkinsfile
+16
-0
client_example/09_quantization/CMakeLists.txt
client_example/09_quantization/CMakeLists.txt
+2
-0
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+25
-0
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-0
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+10
-4
example/14_gemm_quantization/CMakeLists.txt
example/14_gemm_quantization/CMakeLists.txt
+3
-1
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
..._gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
+3
-1
example/40_conv2d_fwd_quantization/CMakeLists.txt
example/40_conv2d_fwd_quantization/CMakeLists.txt
+3
-1
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
...n/gpu/device/convolution_backward_data_specialization.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
..._batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+156
-106
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+186
-234
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
...tched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
+53
-65
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+719
-252
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+0
-18
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
...nsor_operation_instance/gpu/convolution_backward_data.hpp
+14
-7
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+24
-22
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
...nv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
+141
-0
No files found.
CMakeLists.txt
View file @
74284dd2
...
@@ -5,6 +5,31 @@ project(composable_kernel)
...
@@ -5,6 +5,31 @@ project(composable_kernel)
list
(
APPEND CMAKE_MODULE_PATH
"
${
PROJECT_SOURCE_DIR
}
/cmake"
)
list
(
APPEND CMAKE_MODULE_PATH
"
${
PROJECT_SOURCE_DIR
}
/cmake"
)
if
(
DTYPES
)
add_definitions
(
-DDTYPES
)
if
(
DTYPES MATCHES
"int8"
)
add_definitions
(
-D__int8__
)
endif
()
if
(
DTYPES MATCHES
"fp8"
)
add_definitions
(
-D__fp8__
)
endif
()
if
(
DTYPES MATCHES
"fp16"
)
add_definitions
(
-D__fp16__
)
endif
()
if
(
DTYPES MATCHES
"fp32"
)
add_definitions
(
-D__fp32__
)
endif
()
if
(
DTYPES MATCHES
"fp64"
)
add_definitions
(
-D__fp64__
)
endif
()
if
(
DTYPES MATCHES
"bf16"
)
add_definitions
(
-D__bf16__
)
endif
()
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
else
()
add_definitions
(
-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__
)
endif
()
enable_testing
()
enable_testing
()
set
(
ROCM_SYMLINK_LIBS OFF
)
set
(
ROCM_SYMLINK_LIBS OFF
)
...
@@ -16,11 +41,24 @@ include(ROCMSetupVersion)
...
@@ -16,11 +41,24 @@ include(ROCMSetupVersion)
include
(
ROCMInstallSymlinks
)
include
(
ROCMInstallSymlinks
)
include
(
ROCMCreatePackage
)
include
(
ROCMCreatePackage
)
include
(
CheckCXXCompilerFlag
)
include
(
CheckCXXCompilerFlag
)
include
(
ROCMCheckTargetIds
)
rocm_setup_version
(
VERSION 0.2.0
)
rocm_setup_version
(
VERSION 0.2.0
)
include
(
TargetFlags
)
include
(
TargetFlags
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
message
(
"GPU_TARGETS=
${
GPU_TARGETS
}
"
)
message
(
"checking which targets are supported"
)
#This is the list of targets to be used in case GPU_TARGETS is not set on command line
#These targets will be filtered and only supported ones will be used
#Setting GPU_TARGETS on command line will override this list
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
TARGETS
"gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
)
message
(
"Supported GPU_TARGETS=
${
DEFAULT_GPU_TARGETS
}
"
)
set
(
AMDGPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
)
find_package
(
hip
)
option
(
USE_BITINT_EXTENSION_INT4,
"Whether to enable clang's BitInt extension to provide int4 data type."
OFF
)
option
(
USE_BITINT_EXTENSION_INT4,
"Whether to enable clang's BitInt extension to provide int4 data type."
OFF
)
option
(
USE_OPT_NAVI3X,
"Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons."
OFF
)
option
(
USE_OPT_NAVI3X,
"Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons."
OFF
)
...
@@ -258,31 +296,68 @@ file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp"
...
@@ -258,31 +296,68 @@ file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp"
file
(
GLOB dir_list RELATIVE
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu/*
)
file
(
GLOB dir_list RELATIVE
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu/*
)
set
(
CK_DEVICE_INSTANCES
)
set
(
CK_DEVICE_INSTANCES
)
FOREACH
(
subdir_path
${
dir_list
}
)
FOREACH
(
subdir_path
${
dir_list
}
)
IF
(
IS_DIRECTORY
"
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu/
${
subdir_path
}
"
)
set
(
target_dir
)
list
(
APPEND CK_DEVICE_INSTANCES device_
${
subdir_path
}
_instance
)
IF
(
IS_DIRECTORY
"
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu/
${
subdir_path
}
"
)
ENDIF
()
set
(
cmake_instance
)
file
(
READ
"
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu/
${
subdir_path
}
/CMakeLists.txt"
cmake_instance
)
set
(
add_inst 0
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp8
\"
"
AND DTYPES MATCHES
"fp8"
)
#message("fp8 instance found!")
set
(
add_inst 1
)
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp16
\"
"
AND DTYPES MATCHES
"fp16"
)
#message("fp16 instance found!")
set
(
add_inst 1
)
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp32
\"
"
AND DTYPES MATCHES
"fp32"
)
#message("fp32 instance found!")
set
(
add_inst 1
)
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp64
\"
"
AND DTYPES MATCHES
"fp64"
)
#message("fp64 instance found!")
set
(
add_inst 1
)
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
bf16
\"
"
AND DTYPES MATCHES
"bf16"
)
#message("bf16 instance found!")
set
(
add_inst 1
)
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
int8
\"
"
AND DTYPES MATCHES
"int8"
)
#message("int8 instance found!")
set
(
add_inst 1
)
endif
()
if
(
NOT
"
${
cmake_instance
}
"
MATCHES
"DTYPES"
)
#message("instance should be built for all types!")
set
(
add_inst 1
)
endif
()
if
(
add_inst EQUAL 1 OR NOT DEFINED DTYPES
)
list
(
APPEND CK_DEVICE_INSTANCES device_
${
subdir_path
}
_instance
)
endif
()
ENDIF
()
ENDFOREACH
()
ENDFOREACH
()
add_custom_target
(
instances DEPENDS utility;
${
CK_DEVICE_INSTANCES
}
SOURCES
${
INSTANCE_FILES
}
)
add_custom_target
(
instances DEPENDS utility;
${
CK_DEVICE_INSTANCES
}
SOURCES
${
INSTANCE_FILES
}
)
add_subdirectory
(
library
)
rocm_package_setup_component
(
tests
if
(
NOT DEFINED INSTANCES_ONLY
)
rocm_package_setup_component
(
tests
LIBRARY_NAME composablekernel
LIBRARY_NAME composablekernel
PACKAGE_NAME tests
# Prevent -static suffix on package name
PACKAGE_NAME tests
# Prevent -static suffix on package name
)
)
rocm_package_setup_component
(
examples
rocm_package_setup_component
(
examples
LIBRARY_NAME composablekernel
LIBRARY_NAME composablekernel
PACKAGE_NAME examples
PACKAGE_NAME examples
)
)
rocm_package_setup_component
(
profiler
rocm_package_setup_component
(
profiler
LIBRARY_NAME composablekernel
LIBRARY_NAME composablekernel
PACKAGE_NAME ckProfiler
PACKAGE_NAME ckProfiler
)
)
add_subdirectory
(
library
)
add_subdirectory
(
example
)
add_subdirectory
(
example
)
add_subdirectory
(
test
)
add_subdirectory
(
test
)
add_subdirectory
(
profiler
)
add_subdirectory
(
profiler
)
endif
(
)
#Create an interface target for the include only files and call it "composablekernels"
#Create an interface target for the include only files and call it "composablekernels"
include
(
CMakePackageConfigHelpers
)
include
(
CMakePackageConfigHelpers
)
...
...
Dockerfile
View file @
74284dd2
...
@@ -48,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
...
@@ -48,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libpthread-stubs0-dev
\
libpthread-stubs0-dev
\
llvm-amdgpu
\
llvm-amdgpu
\
pkg-config
\
pkg-config
\
python
\
python3
\
python3
\
python3-dev
\
python3-dev
\
python3-pip
\
python3-pip
\
...
@@ -63,6 +64,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
...
@@ -63,6 +64,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
rm
-rf
/var/lib/apt/lists/
*
rm
-rf
/var/lib/apt/lists/
*
#Install latest version of cmake
#Install latest version of cmake
RUN
wget
-qO
/usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip
RUN
gunzip
/usr/local/bin/ninja.gz
RUN
chmod
a+x /usr/local/bin/ninja
RUN
git clone https://github.com/nico/ninjatracing.git
RUN
apt purge
--auto-remove
-y
cmake
RUN
apt purge
--auto-remove
-y
cmake
RUN
apt update
RUN
apt update
RUN
apt
install
-y
software-properties-common lsb-release
RUN
apt
install
-y
software-properties-common lsb-release
...
...
Jenkinsfile
View file @
74284dd2
...
@@ -749,6 +749,22 @@ pipeline {
...
@@ -749,6 +749,22 @@ pipeline {
Build_CK_and_Reboot
(
setup_args:
setup_args
,
config_targets:
"install"
,
no_reboot:
true
,
build_type:
'Release'
,
execute_cmd:
execute_args
,
prefixpath:
'/usr/local'
)
Build_CK_and_Reboot
(
setup_args:
setup_args
,
config_targets:
"install"
,
no_reboot:
true
,
build_type:
'Release'
,
execute_cmd:
execute_args
,
prefixpath:
'/usr/local'
)
}
}
}
}
stage
(
"Build CK and run Tests on Navi32"
)
{
when
{
beforeAgent
true
expression
{
!
params
.
RUN_FULL_QA
.
toBoolean
()
}
}
agent
{
label
rocmnode
(
"navi32"
)
}
environment
{
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
steps
{
Build_CK_and_Reboot
(
setup_args:
setup_args
,
config_targets:
"install"
,
no_reboot:
true
,
build_type:
'Release'
,
execute_cmd:
execute_args
,
prefixpath:
'/usr/local'
)
}
}
}
}
}
}
...
...
client_example/09_quantization/CMakeLists.txt
View file @
74284dd2
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_executable
(
client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations
)
...
@@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable
...
@@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable
add_executable
(
client_gemm_quantization gemm_quantization.cpp
)
add_executable
(
client_gemm_quantization gemm_quantization.cpp
)
target_link_libraries
(
client_gemm_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_quantization PRIVATE composable_kernel::device_operations
)
endif
()
client_example/CMakeLists.txt
View file @
74284dd2
...
@@ -2,6 +2,31 @@ cmake_minimum_required(VERSION 3.15)
...
@@ -2,6 +2,31 @@ cmake_minimum_required(VERSION 3.15)
project
(
ck_app
)
project
(
ck_app
)
add_compile_options
(
-std=c++17
)
add_compile_options
(
-std=c++17
)
if
(
DTYPES
)
add_definitions
(
-DDTYPES
)
if
(
DTYPES MATCHES
"int8"
)
add_definitions
(
-D__int8__
)
endif
()
if
(
DTYPES MATCHES
"fp8"
)
add_definitions
(
-D__fp8__
)
endif
()
if
(
DTYPES MATCHES
"fp16"
)
add_definitions
(
-D__fp16__
)
endif
()
if
(
DTYPES MATCHES
"fp32"
)
add_definitions
(
-D__fp32__
)
endif
()
if
(
DTYPES MATCHES
"fp64"
)
add_definitions
(
-D__fp64__
)
endif
()
if
(
DTYPES MATCHES
"bf16"
)
add_definitions
(
-D__bf16__
)
endif
()
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
else
()
add_definitions
(
-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__
)
endif
()
find_package
(
composable_kernel 1.0.0 COMPONENTS device_operations
)
find_package
(
composable_kernel 1.0.0 COMPONENTS device_operations
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
...
...
cmake/EnableCompilerWarnings.cmake
View file @
74284dd2
...
@@ -67,6 +67,7 @@ else()
...
@@ -67,6 +67,7 @@ else()
-Wunused
-Wunused
-Wno-reserved-identifier
-Wno-reserved-identifier
-Werror
-Werror
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
)
)
...
...
example/01_gemm/CMakeLists.txt
View file @
74284dd2
...
@@ -2,11 +2,14 @@ add_custom_target(example_gemm_dl)
...
@@ -2,11 +2,14 @@ add_custom_target(example_gemm_dl)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp32
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp32
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int8
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_dl_int4 gemm_dl_int4.cpp
)
add_example_executable
(
example_gemm_dl_int4 gemm_dl_int4.cpp
)
...
@@ -19,13 +22,16 @@ add_custom_target(example_gemm_xdl)
...
@@ -19,13 +22,16 @@ add_custom_target(example_gemm_xdl)
add_example_executable
(
example_gemm_xdl_fp16 gemm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_fp16 gemm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int4
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int4
)
...
...
example/14_gemm_quantization/CMakeLists.txt
View file @
74284dd2
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
# dlops
# dlops
add_example_executable
(
example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp
)
add_example_executable
(
example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp
)
...
@@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable
(
example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp
)
add_example_executable
(
example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
endif
()
\ No newline at end of file
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
View file @
74284dd2
...
@@ -173,6 +173,8 @@ using DeviceGemmInstance =
...
@@ -173,6 +173,8 @@ using DeviceGemmInstance =
8
,
8
,
8
,
8
,
true
,
true
,
9
,
// D0sTransferSrcVectorDim
4
,
// D0sTransferSrcScalaerPerVector
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
...
@@ -189,7 +191,7 @@ int main(int argc, char* argv[])
...
@@ -189,7 +191,7 @@ int main(int argc, char* argv[])
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
bool
time_kernel
=
fals
e
;
bool
time_kernel
=
tru
e
;
// GEMM shape
// GEMM shape
ck
::
index_t
M
=
1024
;
ck
::
index_t
M
=
1024
;
...
...
example/40_conv2d_fwd_quantization/CMakeLists.txt
View file @
74284dd2
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
@@ -25,4 +26,5 @@ add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_i
...
@@ -25,4 +26,5 @@ add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_i
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp
)
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp
)
# Conv + bias + tanh perchannel quantization
# Conv + bias + tanh perchannel quantization
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp
)
add_example_executable
(
example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp
)
\ No newline at end of file
endif
()
\ No newline at end of file
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
View file @
74284dd2
...
@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
...
@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
switch
(
s
)
switch
(
s
)
{
{
case
ConvolutionBackwardDataSpecialization
::
Default
:
return
"Default"
;
case
ConvolutionBackwardDataSpecialization
::
Default
:
return
"Default"
;
case
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
:
case
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
:
return
"Filter1x1Stride1Pad0"
;
return
"FFilter1x1Stride1Pad0"
;
default:
return
"Unrecognized specialization!"
;
default:
return
"Unrecognized specialization!"
;
}
}
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
View file @
74284dd2
...
@@ -196,6 +196,8 @@ template <typename A0Layout,
...
@@ -196,6 +196,8 @@ template <typename A0Layout,
index_t
B0BlockTransferSrcScalarPerVector
,
index_t
B0BlockTransferSrcScalarPerVector
,
index_t
B0BlockTransferDstScalarPerVector_BK1
,
index_t
B0BlockTransferDstScalarPerVector_BK1
,
bool
B0BlockLdsExtraN
,
bool
B0BlockLdsExtraN
,
index_t
CDE0BlockTransferSrcVectorDim
,
index_t
CDE0BlockTransferSrcScalaerPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
@@ -492,6 +494,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -492,6 +494,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
B0BlockTransferDstScalarPerVector_BK1
,
B0BlockTransferDstScalarPerVector_BK1
,
true
,
true
,
B0BlockLdsExtraN
,
B0BlockLdsExtraN
,
CDE0BlockTransferSrcVectorDim
,
CDE0BlockTransferSrcScalaerPerVector
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcAccessOrder
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
74284dd2
...
@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEElementwiseOp
>
CDEElementwiseOp
>
{
{
// FIXME
// FIXME
static_assert
(
NDimSpatial
==
2
,
"wrong! only implemented for 2D now"
);
static_assert
(
NDimSpatial
==
2
||
NDimSpatial
==
3
,
"wrong! only implemented for 2D and 3D now"
);
using
DeviceOp
=
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
;
using
DeviceOp
=
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
;
...
@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_c_wis_strides
[
i
][
0
];
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_c_wis_strides
[
i
][
0
];
});
});
static
constexpr
auto
NonSpatialDimsNum
=
Number
<
3
>
{};
static
constexpr
auto
DIdx
=
Number
<
NonSpatialDimsNum
>
{};
static
constexpr
auto
HIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
>
{}
:
Number
<
NonSpatialDimsNum
+
1
>
{};
static
constexpr
auto
WIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
static
constexpr
auto
ZIdx
=
Number
<
NonSpatialDimsNum
>
{};
static
constexpr
auto
YIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
>
{}
:
Number
<
NonSpatialDimsNum
+
1
>
{};
static
constexpr
auto
XIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
// problem definition
// problem definition
const
index_t
Y
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
Z
=
b_g_k_c_xs_lengths
[
ZIdx
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
4
];
const
index_t
Y
=
b_g_k_c_xs_lengths
[
YIdx
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
XIdx
];
const
index_t
ConvStrideH
=
conv_filter_strides_
[
0
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideW
=
conv_filter_strides_
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationH
=
conv_filter_dilations_
[
0
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
NDimSpatial
==
3
?
ConvStrideD
/
GcdStrideDilationD
:
1
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
for
(
index_t
i_
y
tilde
=
0
;
i_
y
tilde
<
Y
Tilde
;
++
i_
y
tilde
)
for
(
index_t
i_
z
tilde
=
0
;
i_
z
tilde
<
Z
Tilde
;
++
i_
z
tilde
)
{
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
YDotSlice
*
XDotSlice
<=
0
)
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
{
continue
;
// check slice is valid
}
const
auto
ZDotSlice
=
NDimSpatial
==
3
?
math
::
integer_divide_ceil
(
Z
-
i_ztilde
,
ZTilde
)
:
1
;
const
auto
a_grid_desc_ak0_m_ak1
=
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
transform_conv_to_gemm
.
template
MakeADescriptor_AK0_M_AK1
<
ALayout
>(
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
if
(
YDotSlice
*
XDotSlice
*
ZDotSlice
<=
0
)
b_g_k_c_xs_lengths
,
{
b_g_k_c_xs_strides
,
continue
;
e_g_n_c_wis_lengths
,
}
e_g_n_c_wis_strides
,
conv_filter_strides
,
std
::
array
<
index_t
,
NDimSpatial
>
tildes
;
conv_filter_dilations
,
if
constexpr
(
NDimSpatial
==
2
)
input_left_pads
,
{
input_right_pads
,
tildes
=
{
i_ytilde
,
i_xtilde
};
{
i_ytilde
,
i_xtilde
});
}
else
if
constexpr
(
NDimSpatial
==
3
)
const
auto
b_grid_desc_bk0_n_bk1
=
{
transform_conv_to_gemm
.
template
MakeBDescriptor_BK0_N_BK1
<
BLayout
>(
tildes
=
{
i_ztilde
,
i_ytilde
,
i_xtilde
};
a_g_n_k_wos_lengths
,
}
a_g_n_k_wos_strides
,
else
b_g_k_c_xs_lengths
,
{
b_g_k_c_xs_strides
,
throw
std
::
runtime_error
(
"wrong! only implemented for 2D and 3D now"
);
e_g_n_c_wis_lengths
,
}
e_g_n_c_wis_strides
,
conv_filter_strides
,
const
auto
a_grid_desc_ak0_m_ak1
=
conv_filter_dilations
,
transform_conv_to_gemm
.
template
MakeADescriptor_AK0_M_AK1
<
ALayout
>(
input_left_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
DsGridDesc_M_N
ds_grid_desc_m_n
;
// populate Ds desc
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
ds_grid_desc_m_n
(
i
)
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
DLayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
b_g_k_c_xs_strides
,
ds
_g_n_c_wis_lengths
[
i
]
,
e
_g_n_c_wis_lengths
,
ds
_g_n_c_wis_strides
[
i
]
,
e
_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
tildes
);
});
const
auto
e_grid_desc_m_n
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
ELayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ytilde
,
i_xtilde
});
// desc for problem definition
const
auto
a_grid_desc_m_k
=
transform_k0_m_k1_to_m_k
(
a_grid_desc_ak0_m_ak1
);
const
auto
b_grid_desc_n_k
=
transform_k0_m_k1_to_m_k
(
b_grid_desc_bk0_n_bk1
);
a_grid_desc_m_k_container_
.
push_back
(
a_grid_desc_m_k
);
b_grid_desc_n_k_container_
.
push_back
(
b_grid_desc_n_k
);
ds_grid_desc_m_n_container_
.
push_back
(
ds_grid_desc_m_n
);
e_grid_desc_m_n_container_
.
push_back
(
e_grid_desc_m_n
);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_
.
push_back
(
a_grid_desc_ak0_m_ak1
);
b_grid_desc_bk0_n_bk1_container_
.
push_back
(
b_grid_desc_bk0_n_bk1
);
// block-to-e-tile-map
auto
block_2_etile_map
=
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n
);
block_2_etile_map_container_
.
push_back
(
block_2_etile_map
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
block_2_etile_map
))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
const
auto
b_grid_desc_bk0_n_bk1
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
transform_conv_to_gemm
.
template
MakeBDescriptor_BK0_N_BK1
<
BLayout
>(
e_grid_desc_m_n
));
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
tildes
);
DsGridDesc_M_N
ds_grid_desc_m_n
;
// populate Ds desc
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
ds_grid_desc_m_n
(
i
)
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
DLayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_c_wis_lengths
[
i
],
ds_g_n_c_wis_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
tildes
);
});
const
auto
e_grid_desc_m_n
=
transform_conv_to_gemm
.
template
MakeCDescriptor_M_N
<
ELayout
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_c_wis_lengths
,
e_g_n_c_wis_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
tildes
);
// desc for problem definition
const
auto
a_grid_desc_m_k
=
transform_k0_m_k1_to_m_k
(
a_grid_desc_ak0_m_ak1
);
const
auto
b_grid_desc_n_k
=
transform_k0_m_k1_to_m_k
(
b_grid_desc_bk0_n_bk1
);
a_grid_desc_m_k_container_
.
push_back
(
a_grid_desc_m_k
);
b_grid_desc_n_k_container_
.
push_back
(
b_grid_desc_n_k
);
ds_grid_desc_m_n_container_
.
push_back
(
ds_grid_desc_m_n
);
e_grid_desc_m_n_container_
.
push_back
(
e_grid_desc_m_n
);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_
.
push_back
(
a_grid_desc_ak0_m_ak1
);
b_grid_desc_bk0_n_bk1_container_
.
push_back
(
b_grid_desc_bk0_n_bk1
);
// block-to-e-tile-map
auto
block_2_etile_map
=
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n
);
block_2_etile_map_container_
.
push_back
(
block_2_etile_map
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
block_2_etile_map
))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_
.
push_back
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
));
}
}
}
}
}
}
}
...
@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector load for A matrix from global memory to LDS
// vector load for A matrix from global memory to LDS
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
||
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
)
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGK
>
)
{
{
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
ConvK
%
ABlockTransferSrcScalarPerVector
==
0
))
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
ConvK
%
ABlockTransferSrcScalarPerVector
==
0
))
{
{
...
@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
// vector load for B matrix from global memory to LDS
// vector load for B matrix from global memory to LDS
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
)
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
)
{
{
if
(
!
(
BBlockTransferSrcVectorDim
==
1
&&
ConvC
%
BBlockTransferSrcScalarPerVector
==
0
))
if
(
!
(
BBlockTransferSrcVectorDim
==
1
&&
ConvC
%
BBlockTransferSrcScalarPerVector
==
0
))
{
{
...
@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
if
constexpr
(
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
GC
>
||
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_C
>
)
is_same_v
<
DLayout
,
tensor_layout
::
convolution
::
G_C
>
)
...
@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector store for E
// vector store for E
if
constexpr
(
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
if
constexpr
(
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
NHWGC
>
)
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNDHWC
>
||
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
NDHWGC
>
)
{
{
// vector store C matrix into global memory
// vector store C matrix into global memory
if
(
!
(
ConvC
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
if
(
!
(
ConvC
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
74284dd2
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
View file @
74284dd2
...
@@ -67,6 +67,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same dat
...
@@ -67,6 +67,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same dat
index_t
B0BlockTransferDstScalarPerVector_BK1
,
index_t
B0BlockTransferDstScalarPerVector_BK1
,
bool
B0ThreadTransferSrcResetCoordinateAfterRun
,
// ignored
bool
B0ThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
B0BlockLdsExtraN
,
index_t
B0BlockLdsExtraN
,
index_t
CDE0BlockTransferSrcVectorDim
,
index_t
CDE0BlockTransferSrcScalarPerVector
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
typename
B1BlockTransferSrcAccessOrder
,
...
@@ -710,13 +712,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -710,13 +712,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
constexpr
auto
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
constexpr
auto
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
I1
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
I
1
,
// MWaveId
m
1
,
// MWaveId
I
1
,
// NWaveId
n
1
,
// NWaveId
I1
,
// MPerXdl
m2
,
// MPerXdl
I1
,
// NGroupNum
n2
,
// NGroupNum
I1
,
// NInputNum
n3
,
// NInputNum
n4
));
// registerNum
n4
));
// registerNum
auto
d0s_thread_buf
=
generate_tuple
(
auto
d0s_thread_buf
=
generate_tuple
(
...
@@ -732,8 +734,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -732,8 +734,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
constexpr
auto
acc0_thread_desc
=
make_naive_tensor_descriptor_packed
(
static_assert
(
CDE0BlockTransferSrcScalarPerVector
<=
n4
,
make_tuple
(
Number
<
Gemm0MXdlPerWave
>
{},
Number
<
Gemm0NXdlPerWave
>
{},
n2
,
n4
));
"vector load must be not greater than n4"
);
static_assert
(
n4
%
CDE0BlockTransferSrcScalarPerVector
==
0
);
auto
d0s_threadwise_copy
=
generate_tuple
(
auto
d0s_threadwise_copy
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -742,10 +745,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -742,10 +745,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
A0B0B1DataType
,
A0B0B1DataType
,
decltype
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
]),
decltype
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
]),
decltype
(
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
decltype
(
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
),
Sequence
<
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
I1
,
n4
>
,
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
9
,
9
,
// CDE0BlockTransferSrcVectorDim
n4
,
CDE0BlockTransferSrcScalarPerVector
,
1
,
1
,
false
>
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
false
>
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
make_multi_index
(
block_work_idx
[
I0
],
// MBlockId
...
@@ -898,66 +910,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -898,66 +910,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
blockwise_gemm0
,
blockwise_gemm0
,
acc0_thread_buf
,
acc0_thread_buf
,
num_k_block_main_loop
);
num_k_block_main_loop
);
// bias+gelu
// multiple d
if
constexpr
(
NumD0Tensor
)
{
{
static_for
<
0
,
Gemm0MXdlPerWave
,
1
>
{}([
&
](
auto
mr
)
{
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
Gemm0NXdlPerWave
,
1
>
{}([
&
](
auto
nr
)
{
d0s_threadwise_copy
(
i
).
Run
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
groupid
)
{
d0s_grid_buf
[
i
],
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
d0s_threadwise_copy
(
i
).
Run
(
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
d0s_thread_buf
(
i
));
d0s_grid_buf
[
i
],
});
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
static_for
<
0
,
m0
*
n0
*
n2
*
n4
,
1
>
{}([
&
](
auto
i
)
{
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
// get reference to src data
d0s_thread_buf
(
i
));
const
auto
src_data_refs
=
generate_tie
(
});
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
return
d0s_thread_buf
[
iSrc
][
i
];
},
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
i
)
{
Number
<
NumD0Tensor
>
{});
constexpr
index_t
c_offset
=
acc0_thread_desc
.
CalculateOffset
(
make_tuple
(
mr
,
nr
,
groupid
,
i
));
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// get reference to src data
// return type should be lvalue
const
auto
src_data_refs
=
generate_tie
(
[
&
](
auto
)
->
auto
&
{
return
acc0_thread_buf
(
i
);
},
// return type should be lvalue
Number
<
2
>
{});
[
&
](
auto
iSrc
)
->
const
auto
&
{
return
d0s_thread_buf
[
iSrc
][
i
];
unpack2
(
cde0_element_op
,
dst_data_refs
,
src_data_refs
);
},
Number
<
NumD0Tensor
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
)
->
auto
&
{
return
acc0_thread_buf
(
Number
<
c_offset
>
{});
},
Number
<
2
>
{});
unpack2
(
cde0_element_op
,
dst_data_refs
,
src_data_refs
);
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
-
n2
.
value
,
0
,
0
));
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
0
,
1
,
-
Gemm0NXdlPerWave
,
0
,
0
,
0
,
0
,
0
,
0
));
});
});
});
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_threadwise_copy
(
i
).
MoveSrcSliceWindow
(
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
[
i
],
make_multi_index
(
0
,
1
,
-
Gemm0MXdlPerWave
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
});
});
}
}
else
{
static_for
<
0
,
acc0_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
cde0_element_op
(
acc_thread_buf
(
i
),
acc0_thread_buf
[
i
]);
});
}
// gemm1
// gemm1
{
{
// TODO: explore using dynamic buffer for a1 thread buffer
// TODO: explore using dynamic buffer for a1 thread buffer
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
74284dd2
This diff is collapsed.
Click to expand it.
include/ck/utility/type_convert.hpp
View file @
74284dd2
...
@@ -62,24 +62,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_
...
@@ -62,24 +62,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
// convert bfp16 to int32 via fp32
template
<
>
inline
__host__
__device__
constexpr
int32_t
type_convert
<
int32_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
return
static_cast
<
int32_t
>
(
x_fp32
);
}
// convert int32 to bfp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
int32_t
>
(
int32_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// convert bfp16 to int8 via fp32
// convert bfp16 to int8 via fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
int8_t
type_convert
<
int8_t
,
bhalf_t
>
(
bhalf_t
x
)
inline
__host__
__device__
constexpr
int8_t
type_convert
<
int8_t
,
bhalf_t
>
(
bhalf_t
x
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
View file @
74284dd2
...
@@ -39,7 +39,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
...
@@ -39,7 +39,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#ifdef __int8__
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
NWC
,
NWC
,
...
@@ -51,7 +51,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
...
@@ -51,7 +51,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
// conv2d backward data
// conv2d backward data
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
@@ -88,7 +88,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
...
@@ -88,7 +88,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#ifdef __int8__
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
NHWC
,
...
@@ -100,7 +100,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
...
@@ -100,7 +100,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
// conv2d dl
// conv2d dl
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances
(
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
@@ -125,7 +125,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
...
@@ -125,7 +125,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#ifdef __int8__
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
NHWC
,
...
@@ -137,6 +137,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
...
@@ -137,6 +137,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
// conv3d backward data
// conv3d backward data
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
...
@@ -173,7 +174,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
...
@@ -173,7 +174,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#ifdef __int8__
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
NDHWC
,
NDHWC
,
...
@@ -185,7 +186,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
...
@@ -185,7 +186,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
template
<
ck
::
index_t
NumDimSpatial
,
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
InLayout
,
typename
WeiLayout
,
typename
WeiLayout
,
...
@@ -239,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -239,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
op_ptrs
);
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
op_ptrs
);
}
}
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
op_ptrs
);
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
op_ptrs
);
}
}
#endif
}
}
else
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWC
>
&&
else
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWC
>
&&
is_same_v
<
WeiLayout
,
KYXC
>
&&
is_same_v
<
OutLayout
,
NHWK
>
)
is_same_v
<
WeiLayout
,
KYXC
>
&&
is_same_v
<
OutLayout
,
NHWK
>
)
...
@@ -266,12 +269,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -266,12 +269,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
}
}
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
}
}
#endif
}
}
else
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWC
>
&&
else
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWC
>
&&
is_same_v
<
WeiLayout
,
KZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWK
>
)
is_same_v
<
WeiLayout
,
KZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWK
>
)
...
@@ -292,11 +297,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -292,11 +297,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
op_ptrs
);
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
op_ptrs
);
}
}
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
op_ptrs
);
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
op_ptrs
);
}
}
#endif
}
}
return
op_ptrs
;
return
op_ptrs
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
74284dd2
...
@@ -77,7 +77,7 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
...
@@ -77,7 +77,7 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#ifdef __int8__
void
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
void
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemm
<
Col
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -118,6 +118,27 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
...
@@ -118,6 +118,27 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
DeviceGemm
<
Row
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemm
<
Row
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -183,26 +204,6 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
...
@@ -183,26 +204,6 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
DeviceGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -388,6 +389,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -388,6 +389,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
int8_t
>
)
is_same_v
<
CDataType
,
int8_t
>
)
{
{
...
@@ -420,7 +422,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -420,7 +422,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances
(
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
0 → 100644
View file @
74284dd2
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment