Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d92fb7e8
Commit
d92fb7e8
authored
May 09, 2022
by
rocking
Browse files
Merge commit '
a3c910ac
' into gemm_softmax
parents
bfc80764
a3c910ac
Changes
62
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
368 additions
and
190 deletions
+368
-190
CMakeLists.txt
CMakeLists.txt
+3
-1
Dockerfile
Dockerfile
+3
-8
Jenkinsfile
Jenkinsfile
+133
-12
cmake/googletest.cmake
cmake/googletest.cmake
+36
-0
example/06_conv2d_fwd_bias_relu/CMakeLists.txt
example/06_conv2d_fwd_bias_relu/CMakeLists.txt
+1
-0
example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt
example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt
+1
-0
example/09_convnd_fwd/CMakeLists.txt
example/09_convnd_fwd/CMakeLists.txt
+3
-0
example/10_conv2d_bwd_data/CMakeLists.txt
example/10_conv2d_bwd_data/CMakeLists.txt
+1
-0
example/11_conv2d_bwd_weight/CMakeLists.txt
example/11_conv2d_bwd_weight/CMakeLists.txt
+1
-0
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
+7
-2
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp
+28
-21
example/17_convnd_bwd_data_xdl/CMakeLists.txt
example/17_convnd_bwd_data_xdl/CMakeLists.txt
+1
-0
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
...e/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
+26
-21
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
...on/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+18
-32
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+46
-21
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...on/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
...ude/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
+4
-8
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
..._operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
+14
-24
include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp
...r_operation/gpu/element/element_wise_reduce_operation.hpp
+0
-14
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+39
-26
No files found.
CMakeLists.txt
View file @
d92fb7e8
cmake_minimum_required
(
VERSION 3.
5
)
cmake_minimum_required
(
VERSION 3.
14
)
# Check support for CUDA/HIP in Cmake
project
(
composable_kernel
)
...
...
@@ -234,6 +234,8 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/library/include
)
include
(
googletest
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
if
(
BUILD_DEV
)
add_compile_options
(
-Werror
)
...
...
Dockerfile
View file @
d92fb7e8
FROM
ubuntu:18.04
ARG
ROCMVERSION=5.
0
ARG
ROCMVERSION=5.
1
ARG
OSDB_BKC_VERSION
RUN
set
-xe
...
...
@@ -42,7 +42,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libnuma-dev
\
libpthread-stubs0-dev
\
llvm-amdgpu
\
miopengemm
\
pkg-config
\
python
\
python3
\
...
...
@@ -51,19 +50,15 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
python-pip
\
python3-pip
\
software-properties-common
\
sqlite3
\
wget
\
rocm-dev
\
rocm-device-libs
\
rocm-opencl
\
rocm-opencl-dev
\
rocm-cmake
\
rocblas
\
vim
\
zlib1g-dev
\
openssh-server
\
kmod
\
mysql-client
&&
\
clang-format-10
\
kmod
&&
\
apt-get clean
&&
\
rm
-rf
/var/lib/apt/lists/
*
...
...
Jenkinsfile
View file @
d92fb7e8
...
...
@@ -140,6 +140,10 @@ def reboot(){
build
job:
'reboot-slaves'
,
propagate:
false
,
parameters:
[
string
(
name:
'server'
,
value:
"${env.NODE_NAME}"
),]
}
def
buildHipClangJobAndReboot
(
Map
conf
=[:]){
try
{
buildHipClangJob
(
conf
)
...
...
@@ -156,6 +160,93 @@ def buildHipClangJobAndReboot(Map conf=[:]){
}
}
def
runCKProfiler
(
Map
conf
=[:]){
show_node_info
()
env
.
HSA_ENABLE_SDMA
=
0
checkout
scm
def
image
=
"composable_kernels"
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
def
gpu_arch
=
conf
.
get
(
"gpu_arch"
,
"gfx908"
)
// Jenkins is complaining about the render group
// def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def
dockerOpts
=
"--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if
(
conf
.
get
(
"enforce_xnack_on"
,
false
))
{
dockerOpts
=
dockerOpts
+
" --env HSA_XNACK=1"
}
def
dockerArgs
=
"--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' "
def
variant
=
env
.
STAGE_NAME
def
retimage
gitStatusWrapper
(
credentialsId:
'7126e5fe-eb51-4576-b52b-9aaf1de8f0fd'
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
try
{
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
)
{
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
}
}
catch
(
org
.
jenkinsci
.
plugins
.
workflow
.
steps
.
FlowInterruptedException
e
){
echo
"The job was cancelled or aborted"
throw
e
}
catch
(
Exception
ex
)
{
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
"--no-cache ."
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
timeout
(
time:
5
,
unit:
'MINUTES'
)
{
sh
'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
}
}
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
5
,
unit:
'HOURS'
)
{
cmake_build
(
conf
)
dir
(
"script"
){
def
perf_log
=
"perf_gemm_${gpu_arch}.log"
def
artifact
=
"profile_gemm_${gpu_arch}.txt"
sh
"./profile_gemm.sh gemm 0 0 0 1 0 5 | tee ${perf_log} ||true"
sh
"./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log} ||true"
sh
"./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log} ||true"
sh
"./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log} || true"
//results will be parsed, stored, and analyzed within the python script
//the script will return 0 if the performance criteria are met
//or return 1 if the criteria are not met
sh
"python3 parse_perf_data.py ${perf_log} | tee ${artifact}"
}
}
}
}
return
retimage
}
def
runPerfTest
(
Map
conf
=[:]){
try
{
runCKProfiler
(
conf
)
}
catch
(
e
){
echo
"throwing error exception in performance tests"
echo
'Exception occurred: '
+
e
.
toString
()
throw
e
}
finally
{
if
(!
conf
.
get
(
"no_reboot"
,
false
))
{
reboot
()
}
}
}
pipeline
{
agent
none
options
{
...
...
@@ -178,18 +269,19 @@ pipeline {
// buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug')
// }
// }
stage
(
'Build Profiler: Release, gfx908'
)
{
agent
{
label
rocmnode
(
"nogpu"
)}
environment
{
setup_args
=
""" -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
}
steps
{
buildHipClangJobAndReboot
(
setup_args:
setup_args
,
config_targets:
"ckProfiler"
,
no_reboot:
true
,
build_type:
'Release'
)
}
}
// we will build and run ckProfiler release version later, during the performance test stage
//stage('Build Profiler: Release, gfx908')
//{
// agent { label rocmnode("nogpu")}
// environment{
// setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
// }
// steps{
// buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
// }
//}
stage
(
'Build Profiler: Debug, gfx908'
)
{
{
agent
{
label
rocmnode
(
"nogpu"
)}
environment
{
setup_args
=
""" -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
...
...
@@ -204,7 +296,7 @@ pipeline {
stage
(
'Clang Format'
)
{
agent
{
label
rocmnode
(
"nogpu"
)
}
environment
{
execute_cmd
=
"find . -iname \'*.h\' \
execute_cmd
=
"find .
.
-iname \'*.h\' \
-o -iname \'*.hpp\' \
-o -iname \'*.cpp\' \
-o -iname \'*.h.in\' \
...
...
@@ -235,6 +327,35 @@ pipeline {
}
}
stage
(
"Run Tests: gfx90a"
)
{
agent
{
label
rocmnode
(
"gfx90a"
)}
environment
{
setup_args
=
""" -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
}
steps
{
buildHipClangJobAndReboot
(
setup_args:
setup_args
,
config_targets:
"check"
,
no_reboot:
true
,
build_type:
'Release'
)
}
}
}
}
stage
(
"Performance Tests"
)
{
parallel
{
stage
(
"Run ckProfiler: gfx908"
)
{
agent
{
label
rocmnode
(
"gfx908"
)}
environment
{
setup_args
=
""" -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
}
steps
{
runPerfTest
(
setup_args:
setup_args
,
config_targets:
"ckProfiler"
,
no_reboot:
true
,
build_type:
'Release'
)
}
}
}
}
...
...
cmake/googletest.cmake
0 → 100644
View file @
d92fb7e8
include
(
FetchContent
)
set
(
GOOGLETEST_DIR
""
CACHE STRING
"Location of local GoogleTest repo to build against"
)
if
(
GOOGLETEST_DIR
)
set
(
FETCHCONTENT_SOURCE_DIR_GOOGLETEST
${
GOOGLETEST_DIR
}
CACHE STRING
"GoogleTest source directory override"
)
endif
()
message
(
STATUS
"Fetching GoogleTest"
)
list
(
APPEND GTEST_CMAKE_CXX_FLAGS
-Wno-undef
-Wno-reserved-identifier
-Wno-global-constructors
-Wno-missing-noreturn
-Wno-disabled-macro-expansion
-Wno-used-but-marked-unused
-Wno-switch-enum
-Wno-zero-as-null-pointer-constant
-Wno-unused-member-function
)
message
(
STATUS
"Suppressing googltest warnings with flags:
${
GTEST_CMAKE_CXX_FLAGS
}
"
)
FetchContent_Declare
(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
)
# Will be necessary for windows build
# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable
(
googletest
)
target_compile_options
(
gtest PRIVATE
${
GTEST_CMAKE_CXX_FLAGS
}
)
target_compile_options
(
gtest_main PRIVATE
${
GTEST_CMAKE_CXX_FLAGS
}
)
example/06_conv2d_fwd_bias_relu/CMakeLists.txt
View file @
d92fb7e8
add_example_executable
(
example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp
)
target_link_libraries
(
example_conv2d_fwd_xdl_bias_relu PRIVATE conv_fwd_util
)
example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt
View file @
d92fb7e8
add_example_executable
(
example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp
)
target_link_libraries
(
example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_fwd_util
)
example/09_convnd_fwd/CMakeLists.txt
View file @
d92fb7e8
add_example_executable
(
example_convnd_fwd_xdl convnd_fwd_xdl.cpp
)
target_link_libraries
(
example_convnd_fwd_xdl PRIVATE conv_fwd_util
)
add_example_executable
(
example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp
)
target_link_libraries
(
example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util
)
add_example_executable
(
example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp
)
target_link_libraries
(
example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util
)
example/10_conv2d_bwd_data/CMakeLists.txt
View file @
d92fb7e8
add_example_executable
(
example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp
)
target_link_libraries
(
example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util
)
example/11_conv2d_bwd_weight/CMakeLists.txt
View file @
d92fb7e8
add_example_executable
(
example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp
)
target_link_libraries
(
example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util
)
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp
View file @
d92fb7e8
...
...
@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
using
ReferenceConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
using
ReferenceConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp
View file @
d92fb7e8
...
...
@@ -11,9 +11,10 @@
#include "device_tensor.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "
element_wise_
reduc
e
_operat
ion
.hpp"
#include "reduc
tion
_operat
or
.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
@@ -33,22 +34,23 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
D1ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
float
,
float
,
false
>
;
static
constexpr
auto
GemmSpecialization
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1|
D1EleOp|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
D0ReduceOp
,
D1ReduceOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
D0ReduceOp
,
D1ReduceOp
,
D1ElementOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
@@ -159,11 +161,10 @@ int main(int argc, char* argv[])
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
d0_reduce_op
=
D0ReduceOp
{};
auto
d1_reduce_op
=
D1ReduceOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
d1_element_op
=
D1ElementOp
{};
// do GEMM
auto
gemm
=
DeviceGemmReduceInstance
{};
...
...
@@ -182,8 +183,7 @@ int main(int argc, char* argv[])
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
);
d1_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
...
...
@@ -242,19 +242,26 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
auto
d0_reduce_op
=
D0ReduceOp
{};
auto
d1_reduce_op
=
D1ReduceOp
{};
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
d0_acc
=
d0_reduce_op
.
GetReduc
e
ZeroVal
ue
();
float
d1_acc
=
d1_reduce_op
.
GetReduc
e
ZeroVal
ue
();
float
d0_acc
=
d0_reduce_op
.
GetReduc
tion
ZeroVal
();
float
d1_acc
=
d1_reduce_op
.
GetReduc
tion
ZeroVal
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
d0_reduce_op
.
Reduce
(
d0_acc
,
c_m_n_host_result
(
m
,
n
));
d1_reduce_op
.
Reduce
(
d1_acc
,
c_m_n_host_result
(
m
,
n
));
float
d0_val
=
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
m
,
n
));
float
d1_val
;
d1_element_op
(
d1_val
,
d0_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
}
d0_m_host_result
(
m
)
=
d0_acc
;
d1_m_host_result
(
m
)
=
d1_acc
;
d0_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d0_acc
)
;
d1_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d1_acc
)
;
}
check_error
(
c_m_n_host_result
,
c_m_n_device_result
);
...
...
example/17_convnd_bwd_data_xdl/CMakeLists.txt
View file @
d92fb7e8
add_example_executable
(
example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp
)
target_link_libraries
(
example_convnd_bwd_data_xdl PRIVATE conv_fwd_util
)
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
View file @
d92fb7e8
...
...
@@ -11,9 +11,9 @@
#include "device_tensor.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_batched_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
@@ -33,22 +33,23 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
D1ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
float
,
float
,
false
>
;
static
constexpr
auto
GemmSpecialization
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceBatchedGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1|
D1EleOp|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
D0ReduceOp
,
D1ReduceOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
D0ReduceOp
,
D1ReduceOp
,
D1ElementOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
@@ -168,11 +169,12 @@ int main(int argc, char* argv[])
a_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
d0_reduce_op
=
D0ReduceOp
{};
auto
d1_reduce_op
=
D1ReduceOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
d0_reduce_op
=
D0ReduceOp
{};
auto
d1_reduce_op
=
D1ReduceOp
{};
auto
d1_element_op
=
D1ElementOp
{};
// do GEMM
auto
batched_gemm
=
DeviceBatchedGemmReduceInstance
{};
...
...
@@ -192,8 +194,7 @@ int main(int argc, char* argv[])
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
,
d1_element_op
,
BatchCount
);
if
(
!
batched_gemm
.
IsSupportedArgument
(
argument
))
...
...
@@ -258,17 +259,21 @@ int main(int argc, char* argv[])
{
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
d0_acc
=
d0_reduce_op
.
GetReduc
e
ZeroVal
ue
();
float
d1_acc
=
d1_reduce_op
.
GetReduc
e
ZeroVal
ue
();
float
d0_acc
=
d0_reduce_op
.
GetReduc
tion
ZeroVal
();
float
d1_acc
=
d1_reduce_op
.
GetReduc
tion
ZeroVal
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
d0_reduce_op
.
Reduce
(
d0_acc
,
c_g_m_n_host_result
(
batch
,
m
,
n
));
d1_reduce_op
.
Reduce
(
d1_acc
,
c_g_m_n_host_result
(
batch
,
m
,
n
));
float
d0_val
=
ck
::
type_convert
<
float
>
(
c_g_m_n_host_result
(
m
,
n
));
float
d1_val
;
d1_element_op
(
d1_val
,
d0_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
}
d0_g_m_host_result
(
batch
,
m
)
=
d0_acc
;
d1_g_m_host_result
(
batch
,
m
)
=
d1_acc
;
d0_g_m_host_result
(
batch
,
m
)
=
ck
::
type_convert
<
DDataType
>
(
d0_acc
)
;
d1_g_m_host_result
(
batch
,
m
)
=
ck
::
type_convert
<
DDataType
>
(
d1_acc
)
;
}
}
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
d92fb7e8
...
...
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
,
typename
D1ElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -44,8 +43,7 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
D0ReduceOperation
d0_reduce_op
,
const
D1ReduceOperation
d1_reduce_op
,
const
D1ElementwiseOperation
d1_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -82,8 +80,7 @@ __global__ void
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
,
d1_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -99,8 +96,7 @@ __global__ void
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
d0_reduce_op
;
ignore
=
d1_reduce_op
;
ignore
=
d1_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -125,6 +121,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
,
typename
D1ElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -161,8 +158,7 @@ template <typename ALayout,
struct
DeviceBatchedGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
>
D1ElementwiseOperation
>
{
using
DeviceOp
=
DeviceBatchedGemmReduce_Xdl_CShuffle
;
...
...
@@ -564,6 +560,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_AK0_M_AK1
,
...
...
@@ -624,8 +621,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
D0ReduceOperation
d0_reduce_op
,
D1ReduceOperation
d1_reduce_op
,
D1ElementwiseOperation
d1_element_op
,
index_t
BatchCount
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
...
...
@@ -648,8 +644,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
d0_reduce_op_
{
d0_reduce_op
},
d1_reduce_op_
{
d1_reduce_op
}
d1_element_op_
{
d1_element_op
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
))
...
...
@@ -684,8 +679,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
D0ReduceOperation
d0_reduce_op_
;
D1ReduceOperation
d1_reduce_op_
;
D1ElementwiseOperation
d1_element_op_
;
};
// Invoker
...
...
@@ -740,8 +734,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -763,8 +756,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
d0_reduce_op_
,
arg
.
d1_reduce_op_
,
arg
.
d1_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
...
@@ -782,8 +774,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -805,8 +796,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
d0_reduce_op_
,
arg
.
d1_reduce_op_
,
arg
.
d1_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
...
@@ -865,8 +855,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
D0ReduceOperation
d0_reduce_op
,
D1ReduceOperation
d1_reduce_op
,
D1ElementwiseOperation
d1_element_op
,
index_t
BatchCount
)
{
return
Argument
{
p_a
,
...
...
@@ -883,8 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
,
d1_element_op
,
BatchCount
};
}
...
...
@@ -905,8 +893,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
D0ReduceOperation
d0_reduce_op
,
D1ReduceOperation
d1_reduce_op
,
D1ElementwiseOperation
d1_element_op
,
index_t
BatchCount
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
@@ -923,8 +910,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
,
d1_element_op
,
BatchCount
);
}
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
d92fb7e8
...
...
@@ -16,6 +16,31 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Compute
BasePr
tOfBatch
,
typename
Compute
PtrOffse
tOfBatch
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
...
...
@@ -43,7 +68,7 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Compute
BasePr
tOfBatch
compute_
base_ptr
_of_batch
_
,
const
Compute
PtrOffse
tOfBatch
compute_
ptr_offset
_of_batch
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -52,11 +77,11 @@ __global__ void
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_
base_ptr
_of_batch
_
.
GetA
BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_
ptr_offset
_of_batch
.
GetA
PtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_
base_ptr
_of_batch
_
.
GetB
BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_
ptr_offset
_of_batch
.
GetB
PtrOffset
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_
base_ptr
_of_batch
_
.
GetC
BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_
ptr_offset
_of_batch
.
GetC
PtrOffset
(
g_idx
)));
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -82,7 +107,7 @@ __global__ void
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
compute_
base_ptr
_of_batch
_
;
ignore
=
compute_
ptr_offset
_of_batch
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
...
@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl
return
globalblockid_to_m0_n0_block_cluster_adaptor
;
}
struct
Compute
BasePtr
OfStridedBatch
struct
Compute
PtrOffset
OfStridedBatch
{
Compute
BasePtr
OfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
)
Compute
PtrOffset
OfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideC_
(
BatchStrideC
)
{
}
__host__
__device__
constexpr
long_index_t
GetA
BasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetA
PtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetB
BasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetB
PtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetC
BasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetC
PtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
...
...
@@ -359,9 +384,9 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceBatchedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
compute_
base_ptr
_of_batch_
{
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
(),
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
(),
c_grid_desc_m_n_
.
GetElementSpaceSize
()},
compute_
ptr_offset
_of_batch_
{
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
(),
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
(),
c_grid_desc_m_n_
.
GetElementSpaceSize
()},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
Compute
BasePtr
OfStridedBatch
compute_
base_ptr
_of_batch_
;
Compute
PtrOffset
OfStridedBatch
compute_
ptr_offset
_of_batch_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -448,7 +473,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Compute
BasePtr
OfStridedBatch
,
Compute
PtrOffset
OfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
true
>
;
...
...
@@ -467,7 +492,7 @@ struct DeviceBatchedGemmXdl
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
compute_
base_ptr
_of_batch_
,
arg
.
compute_
ptr_offset
_of_batch_
,
arg
.
block_2_ctile_map_
);
}
else
...
...
@@ -482,7 +507,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Compute
BasePtr
OfStridedBatch
,
Compute
PtrOffset
OfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
false
>
;
...
...
@@ -501,7 +526,7 @@ struct DeviceBatchedGemmXdl
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
compute_
base_ptr
_of_batch_
,
arg
.
compute_
ptr_offset
_of_batch_
,
arg
.
block_2_ctile_map_
);
}
...
...
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
d92fb7e8
...
...
@@ -18,6 +18,9 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
/*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
*/
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
View file @
d92fb7e8
...
...
@@ -9,8 +9,7 @@ namespace device {
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
>
typename
D1ElementwiseOperation
>
struct
DeviceGemmReduce
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
...
...
@@ -27,8 +26,7 @@ struct DeviceGemmReduce : public BaseOperator
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
D0ReduceOperation
d0_reduce_op
,
D1ReduceOperation
d1_reduce_op
,
D1ElementwiseOperation
d1_element_op
,
ck
::
index_t
BatchCount
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
...
...
@@ -37,13 +35,11 @@ struct DeviceGemmReduce : public BaseOperator
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
>
typename
D1ElementwiseOperation
>
using
DeviceGemmReducePtr
=
std
::
unique_ptr
<
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
>>
;
D1ElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
d92fb7e8
...
...
@@ -29,6 +29,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
,
typename
D1ElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -65,8 +66,7 @@ template <typename ALayout,
struct
DeviceGemmReduce_Xdl_CShuffle
:
public
DeviceGemmReduce
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
>
D1ElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmReduce_Xdl_CShuffle
;
...
...
@@ -382,6 +382,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_AK0_M_AK1
,
...
...
@@ -440,8 +441,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
D0ReduceOperation
d0_reduce_op
,
D1ReduceOperation
d1_reduce_op
)
D1ElementwiseOperation
d1_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
...
...
@@ -457,8 +457,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
d0_reduce_op_
{
d0_reduce_op
},
d1_reduce_op_
{
d1_reduce_op
}
d1_element_op_
{
d1_element_op
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
))
...
...
@@ -491,8 +490,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
D0ReduceOperation
d0_reduce_op_
;
D1ReduceOperation
d1_reduce_op_
;
D1ElementwiseOperation
d1_element_op_
;
};
// Invoker
...
...
@@ -544,8 +542,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -565,8 +562,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
d0_reduce_op_
,
arg
.
d1_reduce_op_
,
arg
.
d1_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
...
@@ -583,8 +579,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -604,8 +599,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
d0_reduce_op_
,
arg
.
d1_reduce_op_
,
arg
.
d1_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
...
...
@@ -655,8 +649,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
D0ReduceOperation
d0_reduce_op
,
D1ReduceOperation
d1_reduce_op
)
D1ElementwiseOperation
d1_element_op
)
{
return
Argument
{
p_a
,
p_b
,
...
...
@@ -672,8 +665,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
};
d1_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -693,8 +685,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
D0ReduceOperation
d0_reduce_op
,
D1ReduceOperation
d1_reduce_op
,
D1ElementwiseOperation
d1_element_op
,
index_t
/* KBatch */
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
@@ -711,8 +702,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
);
d1_element_op
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp
View file @
d92fb7e8
...
...
@@ -5,20 +5,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
element_wise
{
struct
ReduceSum
{
__host__
__device__
static
constexpr
float
GetReduceZeroValue
()
{
return
float
(
0
);
}
__host__
__device__
void
Reduce
(
float
&
acc
,
float
v
)
const
{
acc
+=
v
;
}
};
struct
ReduceSquareSum
{
__host__
__device__
static
constexpr
float
GetReduceZeroValue
()
{
return
float
(
0
);
}
__host__
__device__
void
Reduce
(
float
&
acc
,
float
v
)
const
{
acc
+=
v
*
v
;
}
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
d92fb7e8
...
...
@@ -8,6 +8,7 @@
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"
namespace
ck
{
...
...
@@ -18,8 +19,7 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
,
typename
D1ElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -39,8 +39,7 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
D0ReduceOperation
d0_reduce_op
,
const
D1ReduceOperation
d1_reduce_op
,
const
D1ElementwiseOperation
d1_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -60,8 +59,7 @@ __global__ void
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
,
d1_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -76,8 +74,7 @@ __global__ void
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
d0_reduce_op
;
ignore
=
d1_reduce_op
;
ignore
=
d1_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -97,6 +94,7 @@ template <typename FloatAB,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
,
typename
D1ElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
DGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
...
...
@@ -372,8 +370,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
D0ReduceOperation
&
d0_reduce_op
,
const
D1ReduceOperation
&
d1_reduce_op
,
const
D1ElementwiseOperation
&
d1_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
...
...
@@ -741,13 +738,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
mreduce_per_thread
>
{}));
// TODO: this should be implemented as a blockwise reduction
auto
c_reduce_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
CShuffle
>
(
auto
c_reduce_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
ReduceAcc
>
(
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
CShuffle
>
(
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
ReduceAcc
>
(
d_reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
auto
d1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
CShuffle
>
(
auto
d1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
Float
ReduceAcc
>
(
d_reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
// reduce: threadwise copy from LDS to VGPR
...
...
@@ -763,7 +760,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto
c_reduce_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatCShuffle
,
Float
CShuffle
,
Float
ReduceAcc
,
decltype
(
c_reduce_block_desc_mperblock_nperblock
),
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
c_reduce_thread_lengths_mperblock_nperblock
),
...
...
@@ -775,7 +772,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// reduce: copy from VGPR to global
auto
d0_reduce_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
Float
CShuffle
,
Float
ReduceAcc
,
FloatD
,
decltype
(
d_reduce_thread_desc_mblock_mperblock
),
decltype
(
d_grid_desc_mblock_mperblock
),
...
...
@@ -840,6 +837,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
using
ThreadwiseReduce_D0
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
D0ReduceOperation
,
false
>
;
using
ThreadwiseReduce_D1
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
D1ReduceOperation
,
false
>
;
const
auto
d0_zeroVal
=
D0ReduceOperation
::
GetReductionZeroVal
();
const
auto
d1_zeroVal
=
D0ReduceOperation
::
GetReductionZeroVal
();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d0_thread_buf
(
I
)
=
d0_zeroVal
;
});
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d1_thread_buf
(
I
)
=
d1_zeroVal
;
});
// reduce
{
// copy from LDS to VGPR
...
...
@@ -850,26 +869,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_reduce_thread_buf
);
// reduce in VGPR
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
im
)
{
FloatReduceAcc
d0_acc
=
d0_reduce_op
.
GetReduceZeroValue
();
FloatReduceAcc
d1_acc
=
d1_reduce_op
.
GetReduceZeroValue
();
ThreadwiseReduce_D0
::
Reduce
(
c_reduce_thread_buf
,
d0_thread_buf
);
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
im
)
{
static_for
<
0
,
nreduce_per_thread
,
1
>
{}([
&
](
auto
in
)
{
constexpr
auto
offset
=
Number
<
c_reduce_thread_desc_mperblock_nperblock
.
CalculateOffset
(
make_tuple
(
im
,
in
))
>
{};
d0_reduce_op
.
Reduce
(
d0_acc
,
c_reduce_thread_buf
[
offset
]);
d1_reduce_op
.
Reduce
(
d1_acc
,
c_reduce_thread_buf
[
offset
]);
d1_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
});
constexpr
index_t
out_offset
=
d_reduce_thread_desc_mperblock
.
CalculateOffset
(
make_tuple
(
im
));
d0_thread_buf
(
Number
<
out_offset
>
{})
=
d0_acc
;
d1_thread_buf
(
Number
<
out_offset
>
{})
=
d1_acc
;
});
ThreadwiseReduce_D1
::
Reduce
(
c_reduce_thread_buf
,
d1_thread_buf
);
// copy from VGPR to Global
d0_reduce_thread_copy_vgpr_to_global
.
Run
(
d_reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment