gaoqiong / composable_kernel_ROCM · Commits · f0fd0263
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "26acdcf44e9e0c64fe0918b9cf59a61ce3339757"
Commit f0fd0263 · authored Jul 21, 2023 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: 4e911f3e, a8fafc3f
Changes: 94 files · Showing 20 changed files with 461 additions and 170 deletions (+461 −170)
CMakeLists.txt (+89 −14)
Dockerfile (+5 −0)
Jenkinsfile (+16 −0)
client_example/09_quantization/CMakeLists.txt (+2 −0)
client_example/11_grouped_conv_bwd_weight/common.hpp (+10 −4)
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp (+22 −1)
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp (+24 −2)
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp (+20 −7)
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp (+25 −11)
client_example/CMakeLists.txt (+25 −0)
cmake/EnableCompilerWarnings.cmake (+1 −0)
example/01_gemm/CMakeLists.txt (+10 −4)
example/14_gemm_quantization/CMakeLists.txt (+3 −1)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp (+15 −3)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp (+15 −3)
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc (+6 −0)
example/40_conv2d_fwd_quantization/CMakeLists.txt (+3 −1)
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp (+1 −2)
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp (+13 −11)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp (+156 −106)
CMakeLists.txt

@@ -5,6 +5,31 @@ project(composable_kernel)
 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
+if(DTYPES)
+    add_definitions(-DDTYPES)
+    if(DTYPES MATCHES "int8")
+        add_definitions(-D__int8__)
+    endif()
+    if(DTYPES MATCHES "fp8")
+        add_definitions(-D__fp8__)
+    endif()
+    if(DTYPES MATCHES "fp16")
+        add_definitions(-D__fp16__)
+    endif()
+    if(DTYPES MATCHES "fp32")
+        add_definitions(-D__fp32__)
+    endif()
+    if(DTYPES MATCHES "fp64")
+        add_definitions(-D__fp64__)
+    endif()
+    if(DTYPES MATCHES "bf16")
+        add_definitions(-D__bf16__)
+    endif()
+    message("DTYPES macro set to ${DTYPES}")
+else()
+    add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
+endif()
 enable_testing()
 set(ROCM_SYMLINK_LIBS OFF)

@@ -16,11 +41,24 @@ include(ROCMSetupVersion)
 include(ROCMInstallSymlinks)
 include(ROCMCreatePackage)
 include(CheckCXXCompilerFlag)
+include(ROCMCheckTargetIds)
 rocm_setup_version(VERSION 0.2.1)
 include(TargetFlags)
 list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
+message("GPU_TARGETS=${GPU_TARGETS}")
+message("checking which targets are supported")
+#This is the list of targets to be used in case GPU_TARGETS is not set on command line
+#These targets will be filtered and only supported ones will be used
+#Setting GPU_TARGETS on command line will override this list
+rocm_check_target_ids(DEFAULT_GPU_TARGETS
+    TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
+message("Supported GPU_TARGETS=${DEFAULT_GPU_TARGETS}")
+set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " ")
+find_package(hip)
 option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
 option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)

@@ -258,31 +296,68 @@ file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp"
 file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
 set(CK_DEVICE_INSTANCES)
 FOREACH(subdir_path ${dir_list})
-    IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
-        list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
+    set(target_dir)
+    IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
+        set(cmake_instance)
+        file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
+        set(add_inst 0)
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\"" AND DTYPES MATCHES "fp8")
+            #message("fp8 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
+            #message("fp16 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
+            #message("fp32 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
+            #message("fp64 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
+            #message("bf16 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
+            #message("int8 instance found!")
+            set(add_inst 1)
+        endif()
+        if(NOT "${cmake_instance}" MATCHES "DTYPES")
+            #message("instance should be built for all types!")
+            set(add_inst 1)
+        endif()
+        if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
+            list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
+        endif()
     ENDIF()
 ENDFOREACH()
 add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
-add_subdirectory(library)
-rocm_package_setup_component(tests
-    LIBRARY_NAME composablekernel
-    PACKAGE_NAME tests # Prevent -static suffix on package name
-)
-rocm_package_setup_component(examples
-    LIBRARY_NAME composablekernel
-    PACKAGE_NAME examples
-)
-rocm_package_setup_component(profiler
-    LIBRARY_NAME composablekernel
-    PACKAGE_NAME ckProfiler
-)
-add_subdirectory(example)
-add_subdirectory(test)
-add_subdirectory(profiler)
+if(NOT DEFINED INSTANCES_ONLY)
+    rocm_package_setup_component(tests
+        LIBRARY_NAME composablekernel
+        PACKAGE_NAME tests # Prevent -static suffix on package name
+    )
+    rocm_package_setup_component(examples
+        LIBRARY_NAME composablekernel
+        PACKAGE_NAME examples
+    )
+    rocm_package_setup_component(profiler
+        LIBRARY_NAME composablekernel
+        PACKAGE_NAME ckProfiler
+    )
+    add_subdirectory(library)
+    add_subdirectory(example)
+    add_subdirectory(test)
+    add_subdirectory(profiler)
+endif()
 #Create an interface target for the include only files and call it "composablekernels"
 include(CMakePackageConfigHelpers)
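The per-dtype macros defined above (`-D__fp16__`, `-D__bf16__`, and so on) give source files a compile-time switch matching the `DTYPES` selection. A minimal sketch of how such a guard would be consumed, assuming the macros are used as ordinary preprocessor conditionals (the guarded body is illustrative, not code from this commit):

// Hypothetical translation unit showing how the dtype macros gate code.
#include <cstdio>

int main()
{
#ifdef __fp16__
    std::puts("fp16 instances enabled");  // DTYPES included "fp16", or DTYPES
                                          // was unset and all macros defined
#else
    std::puts("fp16 instances disabled"); // "fp16" filtered out by DTYPES
#endif
    return 0;
}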
Dockerfile

@@ -48,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     libpthread-stubs0-dev \
     llvm-amdgpu \
     pkg-config \
+    python \
    python3 \
     python3-dev \
     python3-pip \

@@ -63,6 +64,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     rm -rf /var/lib/apt/lists/*
 #Install latest version of cmake
+RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip
+RUN gunzip /usr/local/bin/ninja.gz
+RUN chmod a+x /usr/local/bin/ninja
+RUN git clone https://github.com/nico/ninjatracing.git
 RUN apt purge --auto-remove -y cmake
 RUN apt update
 RUN apt install -y software-properties-common lsb-release
Jenkinsfile

@@ -749,6 +749,22 @@ pipeline {
                     Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot: true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
                 }
             }
         }
+        stage("Build CK and run Tests on Navi32")
+        {
+            when {
+                beforeAgent true
+                expression { !params.RUN_FULL_QA.toBoolean() }
+            }
+            agent{ label rocmnode("navi32") }
+            environment{
+                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """
+                execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
+            }
+            steps{
+                Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot: true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
+            }
+        }
     }
 }
client_example/09_quantization/CMakeLists.txt

+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
 add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
 target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations)

@@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable
 add_executable(client_gemm_quantization gemm_quantization.cpp)
 target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations)
+endif()
client_example/11_grouped_conv_bwd_weight/common.hpp

@@ -101,13 +101,15 @@ template <ck::index_t NumDimSpatial,
           typename WeiLayout,
           typename OutLayout>
 bool run_grouped_conv_bwd_weight(
-    ck::index_t G,
-    ck::index_t N,
-    ck::index_t K,
-    ck::index_t C,
+    const ck::index_t G,
+    const ck::index_t N,
+    const ck::index_t K,
+    const ck::index_t C,
     const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
     const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
     const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
+    const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
+    const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
     const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
     const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
     const std::array<ck::index_t, NumDimSpatial>& input_left_pads,

@@ -157,6 +159,8 @@ bool run_grouped_conv_bwd_weight(
     input_spatial_lengths,
     filter_spatial_lengths,
     output_spatial_lengths,
+    input_strides,
+    output_strides,
     conv_filter_strides,
     conv_filter_dilations,
     input_left_pads,

@@ -224,6 +228,8 @@ bool run_grouped_conv_bwd_weight(
     input_spatial_lengths,
     filter_spatial_lengths,
     output_spatial_lengths,
+    input_strides,
+    output_strides,
     conv_filter_strides,
     conv_filter_dilations,
     input_left_pads,
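The new `input_strides`/`output_strides` arguments carry NumDimSpatial + 3 entries ordered (G, N, spatial..., C), matching the packed-layout constants in the updated client examples below. A hedged sketch of how such packed strides could be derived for an arbitrary spatial rank (the helper name and plain `int` element type are illustrative, not part of this commit):

#include <array>
#include <cstddef>

// Hypothetical helper: packed strides for a (G, N, spatial..., C) tensor,
// reproducing e.g. {N * Wi * C, Wi * C, C, 1} from the 1D fp16 example below.
template <std::size_t NDim>
std::array<int, NDim + 3> make_packed_strides(int n, int c,
                                              const std::array<int, NDim>& spatial)
{
    std::array<int, NDim + 3> s{};
    s[NDim + 2] = 1; // C is innermost
    int running  = c;
    for(int d = static_cast<int>(NDim) - 1; d >= 0; --d)
    {
        s[d + 2] = running; // stride of spatial dim d
        running *= spatial[d];
    }
    s[1] = running;     // N stride = (product of spatial lengths) * C
    s[0] = n * running; // G stride
    return s;
}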
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp

@@ -22,6 +22,15 @@ static constexpr ck::index_t C = 192;
 static constexpr ck::index_t X  = 3;
 static constexpr ck::index_t Wi = 28;
 static constexpr ck::index_t Wo = 28;
+static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
+static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
+static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi * C, C, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo * K, K, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1};

 int main()
 {

@@ -31,7 +40,19 @@ int main()
                                        OutDataType,
                                        InLayout,
                                        WeiLayout,
-                                       OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1})
+                                       OutLayout>(G,
+                                                  N,
+                                                  K,
+                                                  C,
+                                                  input_spatial_lengths,
+                                                  filter_spatial_lengths,
+                                                  output_spatial_lengths,
+                                                  input_strides,
+                                                  output_strides,
+                                                  conv_filter_strides,
+                                                  conv_filter_dilations,
+                                                  input_left_pads,
+                                                  input_right_pads)
                ? EXIT_SUCCESS
                : EXIT_FAILURE;
 }
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp

@@ -25,6 +25,17 @@ static constexpr ck::index_t Hi = 28;
 static constexpr ck::index_t Wi = 28;
 static constexpr ck::index_t Ho = 28;
 static constexpr ck::index_t Wo = 28;
+static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
+static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
+static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
+    N * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
+    N * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};

 int main()
 {

@@ -34,8 +45,19 @@ int main()
                                        OutDataType,
                                        InLayout,
                                        WeiLayout,
-                                       OutLayout>(
-                                           G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1})
+                                       OutLayout>(G,
+                                                  N,
+                                                  K,
+                                                  C,
+                                                  input_spatial_lengths,
+                                                  filter_spatial_lengths,
+                                                  output_spatial_lengths,
+                                                  input_strides,
+                                                  output_strides,
+                                                  conv_filter_strides,
+                                                  conv_filter_dilations,
+                                                  input_left_pads,
+                                                  input_right_pads)
                ? EXIT_SUCCESS
                : EXIT_FAILURE;
 }
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp

@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
 static constexpr ck::index_t Do = 28;
 static constexpr ck::index_t Ho = 28;
 static constexpr ck::index_t Wo = 3;
+static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
+static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
+static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
+    N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
+    N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};

 int main()
 {

@@ -41,13 +52,15 @@ int main()
                                                   N,
                                                   K,
                                                   C,
-                                                  {Di, Hi, Wi},
-                                                  {Z, Y, X},
-                                                  {Do, Ho, Wo},
-                                                  {1, 1, 1},
-                                                  {1, 1, 1},
-                                                  {1, 1, 1},
-                                                  {1, 1, 1})
+                                                  input_spatial_lengths,
+                                                  filter_spatial_lengths,
+                                                  output_spatial_lengths,
+                                                  input_strides,
+                                                  output_strides,
+                                                  conv_filter_strides,
+                                                  conv_filter_dilations,
+                                                  input_left_pads,
+                                                  input_right_pads)
                ? EXIT_SUCCESS
                : EXIT_FAILURE;
 }
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp

@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
 static constexpr ck::index_t Do = 28;
 static constexpr ck::index_t Ho = 28;
 static constexpr ck::index_t Wo = 3;
+static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
+static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
+static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
+    N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
+    N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
+static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};

 int main()
 {

@@ -37,17 +48,20 @@ int main()
                                        OutDataType,
                                        InLayout,
                                        WeiLayout,
-                                       OutLayout>(G,
-                                                  N,
-                                                  K,
-                                                  C,
-                                                  {Di, Hi, Wi},
-                                                  {Z, Y, X},
-                                                  {Do, Ho, Wo},
-                                                  {1, 1, 1},
-                                                  {1, 1, 1},
-                                                  {1, 1, 1},
-                                                  {1, 1, 1})
+                                       OutLayout>(G,
+                                                  N,
+                                                  K,
+                                                  C,
+                                                  {Di, Hi, Wi},
+                                                  {Z, Y, X},
+                                                  {Do, Ho, Wo},
+                                                  {N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
+                                                  {N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
+                                                  {1, 1, 1},
+                                                  {1, 1, 1},
+                                                  {1, 1, 1},
+                                                  {1, 1, 1})
                ? EXIT_SUCCESS
                : EXIT_FAILURE;
 }
client_example/CMakeLists.txt

@@ -2,6 +2,31 @@ cmake_minimum_required(VERSION 3.15)
 project(ck_app)
 add_compile_options(-std=c++17)
+if(DTYPES)
+    add_definitions(-DDTYPES)
+    if(DTYPES MATCHES "int8")
+        add_definitions(-D__int8__)
+    endif()
+    if(DTYPES MATCHES "fp8")
+        add_definitions(-D__fp8__)
+    endif()
+    if(DTYPES MATCHES "fp16")
+        add_definitions(-D__fp16__)
+    endif()
+    if(DTYPES MATCHES "fp32")
+        add_definitions(-D__fp32__)
+    endif()
+    if(DTYPES MATCHES "fp64")
+        add_definitions(-D__fp64__)
+    endif()
+    if(DTYPES MATCHES "bf16")
+        add_definitions(-D__bf16__)
+    endif()
+    message("DTYPES macro set to ${DTYPES}")
+else()
+    add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
+endif()
 find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
 find_package(hip REQUIRED PATHS /opt/rocm)
 message(STATUS "Build with HIP ${hip_VERSION}")
cmake/EnableCompilerWarnings.cmake

@@ -67,6 +67,7 @@ else()
         -Wunused
         -Wno-reserved-identifier
         -Werror
+        -Wno-option-ignored
         -Wsign-compare
         -Wno-extra-semi-stmt
     )
example/01_gemm/CMakeLists.txt

@@ -2,11 +2,14 @@ add_custom_target(example_gemm_dl)
 add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
 add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
-add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
 add_dependencies(example_gemm_dl example_gemm_dl_fp32)
 add_dependencies(example_gemm_dl example_gemm_dl_fp16)
-add_dependencies(example_gemm_dl example_gemm_dl_int8)
+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+    add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
+    add_dependencies(example_gemm_dl example_gemm_dl_int8)
+endif()
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)

@@ -19,13 +22,16 @@ add_custom_target(example_gemm_xdl)
 add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
 add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
 add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
-add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
 add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
 add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
-add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
 add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+    add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
+    add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
+endif()
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
     add_dependencies(example_gemm_xdl example_gemm_xdl_int4)
example/14_gemm_quantization/CMakeLists.txt

+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
 # dlops
 add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)

@@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
         set(target 1)
     endif()
 endforeach()
+endif()
\ No newline at end of file
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp

@@ -3,7 +3,7 @@
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

 using InDataType = BF16;
 // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory

@@ -17,8 +17,20 @@ using OutElementOp = PassThrough;
 template <ck::index_t NDimSpatial>
 using DeviceConvBwdWeightInstance =
-    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
-        NDimSpatial, // NDimSpatial
+    ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
+        NDimSpatial,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                      ck::tensor_layout::convolution::GNHWC,
+                                      ck::tensor_layout::convolution::GNDHWC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                      ck::tensor_layout::convolution::GKYXC,
+                                      ck::tensor_layout::convolution::GKZYXC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                      ck::tensor_layout::convolution::GNHWK,
+                                      ck::tensor_layout::convolution::GNDHWK>>,
         InDataType,  // InDataType
         WeiDataType, // WeiDataType
         OutDataType, // OutDataType
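The `tuple_element_t` idiom above picks the rank-specific layout tag at compile time from a `Tuple` indexed by `NDimSpatial - 1`, so one instance alias covers 1D, 2D, and 3D. A self-contained sketch of the same idiom using standard-library types (the layout tags here are stand-in empty structs, not the CK definitions):

#include <tuple>
#include <type_traits>

// Stand-in layout tags; in CK these live in ck::tensor_layout::convolution.
struct GNWC {};
struct GNHWC {};
struct GNDHWC {};

// Select the input layout for a given spatial rank (1 -> GNWC, 2 -> GNHWC, 3 -> GNDHWC).
template <int NDimSpatial>
using InLayoutFor =
    std::tuple_element_t<NDimSpatial - 1, std::tuple<GNWC, GNHWC, GNDHWC>>;

static_assert(std::is_same_v<InLayoutFor<2>, GNHWC>, "2D picks GNHWC");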
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp

@@ -3,7 +3,7 @@
 #include "common.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

 using InDataType  = F16;
 using WeiDataType = F16;

@@ -16,8 +16,20 @@ using OutElementOp = PassThrough;
 template <ck::index_t NDimSpatial>
 using DeviceConvBwdWeightInstance =
-    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
-        NDimSpatial, // NDimSpatial
+    ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
+        NDimSpatial,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                      ck::tensor_layout::convolution::GNHWC,
+                                      ck::tensor_layout::convolution::GNDHWC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                      ck::tensor_layout::convolution::GKYXC,
+                                      ck::tensor_layout::convolution::GKZYXC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                      ck::tensor_layout::convolution::GNHWK,
+                                      ck::tensor_layout::convolution::GNDHWK>>,
         InDataType,  // InDataType
         WeiDataType, // WeiDataType
         OutDataType, // OutDataType
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc

@@ -75,6 +75,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
     std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
     std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
     std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> input_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> output_strides{};
     std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
     std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
     std::array<ck::index_t, NDimSpatial> input_left_pads{};

@@ -85,6 +87,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
     range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
     range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
     range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
+    range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
+    range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
     range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
     range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
     range_copy(conv_param.input_left_pads_, begin(input_left_pads));

@@ -103,6 +107,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
     input_spatial_lengths,
     filter_spatial_lengths,
     output_spatial_lengths,
+    input_strides,
+    output_strides,
     conv_filter_strides,
     conv_filter_dilations,
     input_left_pads,
example/40_conv2d_fwd_quantization/CMakeLists.txt

+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
 list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)

@@ -25,4 +26,5 @@ add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_i
 add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
 # Conv + bias + tanh perchannel quantization
 add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
+endif()
\ No newline at end of file
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp

@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
     switch(s)
     {
     case ConvolutionBackwardDataSpecialization::Default: return "Default";
-    case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0:
-        return "FFilter1x1Stride1Pad0";
+    case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
     default: return "Unrecognized specialization!";
     }
 }
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp

@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
     MakeArgumentPointer(const void* p_in,
                         void* p_wei,
                         const void* p_out,
-                        ck::index_t G,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t C,
-                        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                        std::array<ck::index_t, NDimSpatial> input_left_pads,
-                        std::array<ck::index_t, NDimSpatial> input_right_pads,
+                        const ck::index_t G,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t C,
+                        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op,
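The two new stride parameters sit between `output_spatial_lengths` and `conv_filter_strides`, and their NDimSpatial + 3 extent covers the G, N, and C strides in addition to the spatial ones (in the examples they are filled straight from `GetStrides()` on the host tensor descriptors). A hedged sketch of that sizing, with a stand-in `index_t` (not the CK type):

#include <array>
#include <cstddef>

using index_t = int; // stand-in for ck::index_t

// Shape of the new input_strides/output_strides parameters.
template <std::size_t NDimSpatial>
using StrideArray = std::array<index_t, NDimSpatial + 3>;

static_assert(StrideArray<2>{}.size() == 5, "2D: strides for G, N, H, W, C");
static_assert(StrideArray<3>{}.size() == 6, "3D: strides for G, N, D, H, W, C");

int main() { return 0; }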
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     CDEElementwiseOp>
 {
     // FIXME
-    static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now");
+    static_assert(NDimSpatial == 2 || NDimSpatial == 3,
+                  "wrong! only implemented for 2D and 3D now");

     using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;

@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
                 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
             });

+            static constexpr auto NonSpatialDimsNum = Number<3>{};
+            static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
+            static constexpr auto HIdx =
+                NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
+            static constexpr auto WIdx =
+                NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
+            static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
+            static constexpr auto YIdx =
+                NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
+            static constexpr auto XIdx =
+                NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
+
             // problem definition
-            const index_t Y = b_g_k_c_xs_lengths[3];
-            const index_t X = b_g_k_c_xs_lengths[4];
+            const index_t Z = b_g_k_c_xs_lengths[ZIdx];
+            const index_t Y = b_g_k_c_xs_lengths[YIdx];
+            const index_t X = b_g_k_c_xs_lengths[XIdx];

-            const index_t ConvStrideH = conv_filter_strides_[0];
-            const index_t ConvStrideW = conv_filter_strides_[1];
+            const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
+            const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
+            const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];

-            const index_t ConvDilationH = conv_filter_dilations_[0];
-            const index_t ConvDilationW = conv_filter_dilations_[1];
+            const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
+            const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
+            const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];

+            const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
             const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
             const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

+            const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
             const auto YTilde = ConvStrideH / GcdStrideDilationH;
             const auto XTilde = ConvStrideW / GcdStrideDilationW;

-            for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
+            for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
             {
-                for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
+                for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
                 {
-                    // check slice is valid
-                    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
-                    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
-                    if(YDotSlice * XDotSlice <= 0)
+                    for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
                     {
-                        continue;
-                    }
+                        // check slice is valid
+                        const auto ZDotSlice =
+                            NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
+                        const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
+                        const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
+                        if(YDotSlice * XDotSlice * ZDotSlice <= 0)
+                        {
+                            continue;
+                        }
+
+                        std::array<index_t, NDimSpatial> tildes;
+                        if constexpr(NDimSpatial == 2)
+                        {
+                            tildes = {i_ytilde, i_xtilde};
+                        }
+                        else if constexpr(NDimSpatial == 3)
+                        {
+                            tildes = {i_ztilde, i_ytilde, i_xtilde};
+                        }
+                        else
+                        {
+                            throw std::runtime_error("wrong! only implemented for 2D and 3D now");
+                        }

                         const auto a_grid_desc_ak0_m_ak1 =
                             transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
                                 a_g_n_k_wos_lengths,
                                 a_g_n_k_wos_strides,
                                 b_g_k_c_xs_lengths,
                                 b_g_k_c_xs_strides,
                                 e_g_n_c_wis_lengths,
                                 e_g_n_c_wis_strides,
                                 conv_filter_strides,
                                 conv_filter_dilations,
                                 input_left_pads,
                                 input_right_pads,
-                                {i_ytilde, i_xtilde});
+                                tildes);

                         const auto b_grid_desc_bk0_n_bk1 =
                             transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
                                 a_g_n_k_wos_lengths,
                                 a_g_n_k_wos_strides,
                                 b_g_k_c_xs_lengths,
                                 b_g_k_c_xs_strides,
                                 e_g_n_c_wis_lengths,
                                 e_g_n_c_wis_strides,
                                 conv_filter_strides,
                                 conv_filter_dilations,
                                 input_left_pads,
                                 input_right_pads,
-                                {i_ytilde, i_xtilde});
+                                tildes);

                         DsGridDesc_M_N ds_grid_desc_m_n;

                         // populate Ds desc
                         static_for<0, NumDTensor, 1>{}([&](auto i) {
                             using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                             ds_grid_desc_m_n(i) =
                                 transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
                                     a_g_n_k_wos_lengths,
                                     a_g_n_k_wos_strides,
                                     b_g_k_c_xs_lengths,
                                     b_g_k_c_xs_strides,
                                     ds_g_n_c_wis_lengths[i],
                                     ds_g_n_c_wis_strides[i],
                                     conv_filter_strides,
                                     conv_filter_dilations,
                                     input_left_pads,
                                     input_right_pads,
-                                    {i_ytilde, i_xtilde});
+                                    tildes);
                         });

                         const auto e_grid_desc_m_n =
                             transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
                                 a_g_n_k_wos_lengths,
                                 a_g_n_k_wos_strides,
                                 b_g_k_c_xs_lengths,
                                 b_g_k_c_xs_strides,
                                 e_g_n_c_wis_lengths,
                                 e_g_n_c_wis_strides,
                                 conv_filter_strides,
                                 conv_filter_dilations,
                                 input_left_pads,
                                 input_right_pads,
-                                {i_ytilde, i_xtilde});
+                                tildes);

                         // desc for problem definition
                         const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
                         const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);

                         a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
                         b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
                         ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
                         e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);

                         // desc for blockwise copy
                         a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
                         b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);

                         // block-to-e-tile-map
                         auto block_2_etile_map =
                             GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);

                         block_2_etile_map_container_.push_back(block_2_etile_map);

                         if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                                        b_grid_desc_n_k,
                                                        ds_grid_desc_m_n,
                                                        e_grid_desc_m_n,
                                                        block_2_etile_map))
                         {
                             ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
                                 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                                     ds_grid_desc_m_n));

                             e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
                                 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                                     e_grid_desc_m_n));
                         }
                     }
                 }
             }
         }

@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         // vector load for A matrix from global memory to LDS
-        if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
-                     is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
+        if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
+                     is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
+                     is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
+                     is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
         {
             if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
             {

@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         }
         // vector load for B matrix from global memory to LDS
-        if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
+        if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
+                     is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
         {
             if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
             {

@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

-            if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
-                         is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
-                         is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
-                         is_same_v<DLayout, tensor_layout::convolution::GC> ||
-                         is_same_v<DLayout, tensor_layout::convolution::G_C>)
+            if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
+                         is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
+                         is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
+                         is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
+                         is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
+                         is_same_v<DLayout, tensor_layout::convolution::GC> ||
+                         is_same_v<DLayout, tensor_layout::convolution::G_C>)

@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
         // vector store for E
-        if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
-                     is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
+        if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
+                     is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
+                     is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
+                     is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
         {
             // vector store C matrix into global memory
             if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
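The `ZTilde`/`YTilde`/`XTilde` factors above come from the backward-data decomposition: each spatial stride is split by its gcd with the matching dilation, and the nested loops launch one sub-GEMM per residue combination (the 3D change simply adds the outer `i_ztilde` loop). A standalone sketch of the per-dimension factor, assuming plain `int` values:

#include <numeric> // std::gcd (C++17)

// Number of residue classes (and hence sub-GEMMs) contributed by one spatial
// dimension, e.g. ZTilde = ConvStrideD / gcd(ConvStrideD, ConvDilationD).
int tilde_factor(int conv_stride, int conv_dilation)
{
    return conv_stride / std::gcd(conv_stride, conv_dilation);
}

int main()
{
    // stride 2, dilation 1 -> 2 sub-problems in that dimension
    return tilde_factor(2, 1) == 2 ? 0 : 1;
}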