Commit d92fb7e8 authored by rocking's avatar rocking
Browse files

Merge commit 'a3c910ac' into gemm_softmax

parents bfc80764 a3c910ac
cmake_minimum_required(VERSION 3.5) cmake_minimum_required(VERSION 3.14)
# Check support for CUDA/HIP in Cmake # Check support for CUDA/HIP in Cmake
project(composable_kernel) project(composable_kernel)
...@@ -234,6 +234,8 @@ include_directories(BEFORE ...@@ -234,6 +234,8 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include ${PROJECT_SOURCE_DIR}/library/include
) )
include(googletest)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
add_compile_options(-Werror) add_compile_options(-Werror)
......
FROM ubuntu:18.04 FROM ubuntu:18.04
ARG ROCMVERSION=5.0 ARG ROCMVERSION=5.1
ARG OSDB_BKC_VERSION ARG OSDB_BKC_VERSION
RUN set -xe RUN set -xe
...@@ -42,7 +42,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -42,7 +42,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libnuma-dev \ libnuma-dev \
libpthread-stubs0-dev \ libpthread-stubs0-dev \
llvm-amdgpu \ llvm-amdgpu \
miopengemm \
pkg-config \ pkg-config \
python \ python \
python3 \ python3 \
...@@ -51,19 +50,15 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -51,19 +50,15 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
python-pip \ python-pip \
python3-pip \ python3-pip \
software-properties-common \ software-properties-common \
sqlite3 \
wget \ wget \
rocm-dev \ rocm-dev \
rocm-device-libs \ rocm-device-libs \
rocm-opencl \
rocm-opencl-dev \
rocm-cmake \ rocm-cmake \
rocblas \
vim \ vim \
zlib1g-dev \ zlib1g-dev \
openssh-server \ openssh-server \
kmod \ clang-format-10 \
mysql-client && \ kmod && \
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
......
...@@ -140,6 +140,10 @@ def reboot(){ ...@@ -140,6 +140,10 @@ def reboot(){
build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),] build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),]
} }
def buildHipClangJobAndReboot(Map conf=[:]){ def buildHipClangJobAndReboot(Map conf=[:]){
try{ try{
buildHipClangJob(conf) buildHipClangJob(conf)
...@@ -156,6 +160,93 @@ def buildHipClangJobAndReboot(Map conf=[:]){ ...@@ -156,6 +160,93 @@ def buildHipClangJobAndReboot(Map conf=[:]){
} }
} }
def runCKProfiler(Map conf=[:]){
show_node_info()
env.HSA_ENABLE_SDMA=0
checkout scm
def image = "composable_kernels"
def prefixpath = conf.get("prefixpath", "/opt/rocm")
def gpu_arch = conf.get("gpu_arch", "gfx908")
// Jenkins is complaining about the render group
// def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1"
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' "
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
try {
retimage = docker.build("${image}", dockerArgs + '.')
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 5, unit: 'MINUTES')
{
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
}
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
echo "The job was cancelled or aborted"
throw e
}
catch(Exception ex) {
retimage = docker.build("${image}", dockerArgs + "--no-cache .")
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 5, unit: 'MINUTES')
{
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
}
}
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 5, unit: 'HOURS')
{
cmake_build(conf)
dir("script"){
def perf_log = "perf_gemm_${gpu_arch}.log"
def artifact = "profile_gemm_${gpu_arch}.txt"
sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee ${perf_log} ||true"
sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log} ||true"
sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log} ||true"
sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log} || true"
//results will be parsed, stored, and analyzed within the python script
//the script will return 0 if the performance criteria are met
//or return 1 if the criteria are not met
sh "python3 parse_perf_data.py ${perf_log} | tee ${artifact}"
}
}
}
}
return retimage
}
def runPerfTest(Map conf=[:]){
try{
runCKProfiler(conf)
}
catch(e){
echo "throwing error exception in performance tests"
echo 'Exception occurred: ' + e.toString()
throw e
}
finally{
if (!conf.get("no_reboot", false)) {
reboot()
}
}
}
pipeline { pipeline {
agent none agent none
options { options {
...@@ -178,16 +269,17 @@ pipeline { ...@@ -178,16 +269,17 @@ pipeline {
// buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug') // buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug')
// } // }
// } // }
stage('Build Profiler: Release, gfx908') // we will build and run ckProfiler release version later, during the performance test stage
{ //stage('Build Profiler: Release, gfx908')
agent { label rocmnode("nogpu")} //{
environment{ // agent { label rocmnode("nogpu")}
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ // environment{
} // setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
steps{ // }
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') // steps{
} // buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
} // }
//}
stage('Build Profiler: Debug, gfx908') stage('Build Profiler: Debug, gfx908')
{ {
agent { label rocmnode("nogpu")} agent { label rocmnode("nogpu")}
...@@ -204,7 +296,7 @@ pipeline { ...@@ -204,7 +296,7 @@ pipeline {
stage('Clang Format') { stage('Clang Format') {
agent{ label rocmnode("nogpu") } agent{ label rocmnode("nogpu") }
environment{ environment{
execute_cmd = "find . -iname \'*.h\' \ execute_cmd = "find .. -iname \'*.h\' \
-o -iname \'*.hpp\' \ -o -iname \'*.hpp\' \
-o -iname \'*.cpp\' \ -o -iname \'*.cpp\' \
-o -iname \'*.h.in\' \ -o -iname \'*.h.in\' \
...@@ -235,6 +327,35 @@ pipeline { ...@@ -235,6 +327,35 @@ pipeline {
} }
} }
stage("Run Tests: gfx90a")
{
agent{ label rocmnode("gfx90a")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release')
}
}
}
}
stage("Performance Tests")
{
parallel
{
stage("Run ckProfiler: gfx908")
{
agent{ label rocmnode("gfx908")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
}
steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
}
}
} }
} }
......
include(FetchContent)
set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
if(GOOGLETEST_DIR)
set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
endif()
message(STATUS "Fetching GoogleTest")
list(APPEND GTEST_CMAKE_CXX_FLAGS
-Wno-undef
-Wno-reserved-identifier
-Wno-global-constructors
-Wno-missing-noreturn
-Wno-disabled-macro-expansion
-Wno-used-but-marked-unused
-Wno-switch-enum
-Wno-zero-as-null-pointer-constant
-Wno-unused-member-function
)
message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
)
# Will be necessary for windows build
# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp) add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp)
target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_fwd_util)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp) add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp) add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp)
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util)
add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp) add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp)
target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util)
add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp) add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp)
target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util)
...@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device:: ...@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
using ReferenceConvBwdWeightInstance = ck::tensor_operation::host:: using ReferenceConvBwdWeightInstance =
ReferenceConvBwdWeight<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>; ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
......
...@@ -11,9 +11,10 @@ ...@@ -11,9 +11,10 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp" #include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp" #include "reduction_operator.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -36,19 +37,20 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; ...@@ -36,19 +37,20 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum; using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; using D1ReduceOp = ck::reduce::Add<float>;
using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
static constexpr auto GemmSpecialization = static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default; ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -162,8 +164,7 @@ int main(int argc, char* argv[]) ...@@ -162,8 +164,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
auto d0_reduce_op = D0ReduceOp{}; auto d1_element_op = D1ElementOp{};
auto d1_reduce_op = D1ReduceOp{};
// do GEMM // do GEMM
auto gemm = DeviceGemmReduceInstance{}; auto gemm = DeviceGemmReduceInstance{};
...@@ -182,8 +183,7 @@ int main(int argc, char* argv[]) ...@@ -182,8 +183,7 @@ int main(int argc, char* argv[])
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op);
d1_reduce_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -242,19 +242,26 @@ int main(int argc, char* argv[]) ...@@ -242,19 +242,26 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReduceZeroValue(); float d0_acc = d0_reduce_op.GetReductionZeroVal();
float d1_acc = d1_reduce_op.GetReduceZeroValue(); float d1_acc = d1_reduce_op.GetReductionZeroVal();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
d0_reduce_op.Reduce(d0_acc, c_m_n_host_result(m, n)); float d0_val = ck::type_convert<float>(c_m_n_host_result(m, n));
d1_reduce_op.Reduce(d1_acc, c_m_n_host_result(m, n)); float d1_val;
d1_element_op(d1_val, d0_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
} }
d0_m_host_result(m) = d0_acc; d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = d1_acc; d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
} }
check_error(c_m_n_host_result, c_m_n_device_result); check_error(c_m_n_host_result, c_m_n_device_result);
......
add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp) add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util)
...@@ -11,9 +11,9 @@ ...@@ -11,9 +11,9 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" #include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_batched_gemm.hpp" #include "reference_batched_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -36,19 +36,20 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; ...@@ -36,19 +36,20 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum; using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; using D1ReduceOp = ck::reduce::Add<float>;
using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
static constexpr auto GemmSpecialization = static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default; ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
...@@ -173,6 +174,7 @@ int main(int argc, char* argv[]) ...@@ -173,6 +174,7 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
auto d0_reduce_op = D0ReduceOp{}; auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{}; auto d1_reduce_op = D1ReduceOp{};
auto d1_element_op = D1ElementOp{};
// do GEMM // do GEMM
auto batched_gemm = DeviceBatchedGemmReduceInstance{}; auto batched_gemm = DeviceBatchedGemmReduceInstance{};
...@@ -192,8 +194,7 @@ int main(int argc, char* argv[]) ...@@ -192,8 +194,7 @@ int main(int argc, char* argv[])
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op,
d1_reduce_op,
BatchCount); BatchCount);
if(!batched_gemm.IsSupportedArgument(argument)) if(!batched_gemm.IsSupportedArgument(argument))
...@@ -258,17 +259,21 @@ int main(int argc, char* argv[]) ...@@ -258,17 +259,21 @@ int main(int argc, char* argv[])
{ {
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReduceZeroValue(); float d0_acc = d0_reduce_op.GetReductionZeroVal();
float d1_acc = d1_reduce_op.GetReduceZeroValue(); float d1_acc = d1_reduce_op.GetReductionZeroVal();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
d0_reduce_op.Reduce(d0_acc, c_g_m_n_host_result(batch, m, n)); float d0_val = ck::type_convert<float>(c_g_m_n_host_result(m, n));
d1_reduce_op.Reduce(d1_acc, c_g_m_n_host_result(batch, m, n)); float d1_val;
d1_element_op(d1_val, d0_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
} }
d0_g_m_host_result(batch, m) = d0_acc; d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc);
d1_g_m_host_result(batch, m) = d1_acc; d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc);
} }
} }
......
...@@ -21,8 +21,7 @@ template <typename GridwiseGemm, ...@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D1ElementwiseOperation,
typename D1ReduceOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -44,8 +43,7 @@ __global__ void ...@@ -44,8 +43,7 @@ __global__ void
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op, const D1ElementwiseOperation d1_element_op,
const D1ReduceOperation d1_reduce_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -82,8 +80,7 @@ __global__ void ...@@ -82,8 +80,7 @@ __global__ void
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op,
d1_reduce_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -99,8 +96,7 @@ __global__ void ...@@ -99,8 +96,7 @@ __global__ void
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = d0_reduce_op; ignore = d1_element_op;
ignore = d1_reduce_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
...@@ -125,6 +121,7 @@ template <typename ALayout, ...@@ -125,6 +121,7 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D0ReduceOperation,
typename D1ReduceOperation, typename D1ReduceOperation,
typename D1ElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -161,8 +158,7 @@ template <typename ALayout, ...@@ -161,8 +158,7 @@ template <typename ALayout,
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation, struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D1ElementwiseOperation>
D1ReduceOperation>
{ {
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
...@@ -564,6 +560,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -564,6 +560,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D0ReduceOperation,
D1ReduceOperation, D1ReduceOperation,
D1ElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
...@@ -624,8 +621,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -624,8 +621,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op, D1ElementwiseOperation d1_element_op,
D1ReduceOperation d1_reduce_op,
index_t BatchCount) index_t BatchCount)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
...@@ -648,8 +644,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -648,8 +644,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
d0_reduce_op_{d0_reduce_op}, d1_element_op_{d1_element_op}
d1_reduce_op_{d1_reduce_op}
{ {
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
...@@ -684,8 +679,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -684,8 +679,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
D0ReduceOperation d0_reduce_op_; D1ElementwiseOperation d1_element_op_;
D1ReduceOperation d1_reduce_op_;
}; };
// Invoker // Invoker
...@@ -740,8 +734,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -740,8 +734,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D1ElementwiseOperation,
D1ReduceOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -763,8 +756,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -763,8 +756,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.d0_reduce_op_, arg.d1_element_op_,
arg.d1_reduce_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
...@@ -782,8 +774,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -782,8 +774,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D1ElementwiseOperation,
D1ReduceOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -805,8 +796,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -805,8 +796,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.d0_reduce_op_, arg.d1_element_op_,
arg.d1_reduce_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
...@@ -865,8 +855,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -865,8 +855,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op, D1ElementwiseOperation d1_element_op,
D1ReduceOperation d1_reduce_op,
index_t BatchCount) index_t BatchCount)
{ {
return Argument{p_a, return Argument{p_a,
...@@ -883,8 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -883,8 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op,
d1_reduce_op,
BatchCount}; BatchCount};
} }
...@@ -905,8 +893,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -905,8 +893,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op, D1ElementwiseOperation d1_element_op,
D1ReduceOperation d1_reduce_op,
index_t BatchCount) override index_t BatchCount) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
...@@ -923,8 +910,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -923,8 +910,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op,
d1_reduce_op,
BatchCount); BatchCount);
} }
......
...@@ -16,6 +16,31 @@ namespace ck { ...@@ -16,6 +16,31 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -25,7 +50,7 @@ template <typename GridwiseGemm, ...@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename ComputeBasePrtOfBatch, typename ComputePtrOffsetOfBatch,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
...@@ -43,7 +68,7 @@ __global__ void ...@@ -43,7 +68,7 @@ __global__ void
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...@@ -52,11 +77,11 @@ __global__ void ...@@ -52,11 +77,11 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -82,7 +107,7 @@ __global__ void ...@@ -82,7 +107,7 @@ __global__ void
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = compute_base_ptr_of_batch_; ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl ...@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl
return globalblockid_to_m0_n0_block_cluster_adaptor; return globalblockid_to_m0_n0_block_cluster_adaptor;
} }
struct ComputeBasePtrOfStridedBatch struct ComputePtrOffsetOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB, index_t BatchStrideB,
index_t BatchStrideC) index_t BatchStrideC)
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
{ {
} }
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); return g_idx * static_cast<long_index_t>(BatchStrideA_);
} }
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); return g_idx * static_cast<long_index_t>(BatchStrideB_);
} }
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideC_); return g_idx * static_cast<long_index_t>(BatchStrideC_);
} }
...@@ -359,7 +384,7 @@ struct DeviceBatchedGemmXdl ...@@ -359,7 +384,7 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_base_ptr_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(), compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
b_grid_desc_k0_n_k1_.GetElementSpaceSize(), b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
c_grid_desc_m_n_.GetElementSpaceSize()}, c_grid_desc_m_n_.GetElementSpaceSize()},
block_2_ctile_map_{}, block_2_ctile_map_{},
...@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl ...@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -448,7 +473,7 @@ struct DeviceBatchedGemmXdl ...@@ -448,7 +473,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
ComputeBasePtrOfStridedBatch, ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
...@@ -467,7 +492,7 @@ struct DeviceBatchedGemmXdl ...@@ -467,7 +492,7 @@ struct DeviceBatchedGemmXdl
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.compute_base_ptr_of_batch_, arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
else else
...@@ -482,7 +507,7 @@ struct DeviceBatchedGemmXdl ...@@ -482,7 +507,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
ComputeBasePtrOfStridedBatch, ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
...@@ -501,7 +526,7 @@ struct DeviceBatchedGemmXdl ...@@ -501,7 +526,7 @@ struct DeviceBatchedGemmXdl
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.compute_base_ptr_of_batch_, arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
......
...@@ -18,6 +18,9 @@ namespace ck { ...@@ -18,6 +18,9 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
/*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
*/
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
......
...@@ -9,8 +9,7 @@ namespace device { ...@@ -9,8 +9,7 @@ namespace device {
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D1ElementwiseOperation>
typename D1ReduceOperation>
struct DeviceGemmReduce : public BaseOperator struct DeviceGemmReduce : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -27,8 +26,7 @@ struct DeviceGemmReduce : public BaseOperator ...@@ -27,8 +26,7 @@ struct DeviceGemmReduce : public BaseOperator
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op, D1ElementwiseOperation d1_element_op,
D1ReduceOperation d1_reduce_op,
ck::index_t BatchCount = 1) = 0; ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
...@@ -37,13 +35,11 @@ struct DeviceGemmReduce : public BaseOperator ...@@ -37,13 +35,11 @@ struct DeviceGemmReduce : public BaseOperator
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D1ElementwiseOperation>
typename D1ReduceOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation, using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D1ElementwiseOperation>>;
D1ReduceOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -29,6 +29,7 @@ template <typename ALayout, ...@@ -29,6 +29,7 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D0ReduceOperation,
typename D1ReduceOperation, typename D1ReduceOperation,
typename D1ElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -65,8 +66,7 @@ template <typename ALayout, ...@@ -65,8 +66,7 @@ template <typename ALayout,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation, struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D1ElementwiseOperation>
D1ReduceOperation>
{ {
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
...@@ -382,6 +382,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -382,6 +382,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D0ReduceOperation,
D1ReduceOperation, D1ReduceOperation,
D1ElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
...@@ -440,8 +441,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -440,8 +441,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op, D1ElementwiseOperation d1_element_op)
D1ReduceOperation d1_reduce_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -457,8 +457,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -457,8 +457,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
d0_reduce_op_{d0_reduce_op}, d1_element_op_{d1_element_op}
d1_reduce_op_{d1_reduce_op}
{ {
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
...@@ -491,8 +490,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -491,8 +490,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
D0ReduceOperation d0_reduce_op_; D1ElementwiseOperation d1_element_op_;
D1ReduceOperation d1_reduce_op_;
}; };
// Invoker // Invoker
...@@ -544,8 +542,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -544,8 +542,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D1ElementwiseOperation,
D1ReduceOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -565,8 +562,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -565,8 +562,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.d0_reduce_op_, arg.d1_element_op_,
arg.d1_reduce_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
...@@ -583,8 +579,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -583,8 +579,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
D0ReduceOperation, D1ElementwiseOperation,
D1ReduceOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -604,8 +599,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -604,8 +599,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.d0_reduce_op_, arg.d1_element_op_,
arg.d1_reduce_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
...@@ -655,8 +649,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -655,8 +649,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op, D1ElementwiseOperation d1_element_op)
D1ReduceOperation d1_reduce_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -672,8 +665,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -672,8 +665,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op};
d1_reduce_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -693,8 +685,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -693,8 +685,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op, D1ElementwiseOperation d1_element_op,
D1ReduceOperation d1_reduce_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
...@@ -711,8 +702,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera ...@@ -711,8 +702,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op);
d1_reduce_op);
} }
// polymorphic // polymorphic
......
...@@ -5,20 +5,6 @@ namespace ck { ...@@ -5,20 +5,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
struct ReduceSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
};
struct ReduceSquareSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "blockwise_tensor_slice_transfer_v6r1.hpp" #include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp" #include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"
namespace ck { namespace ck {
...@@ -18,8 +19,7 @@ template <typename GridwiseGemm, ...@@ -18,8 +19,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D1ElementwiseOperation,
typename D1ReduceOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -39,8 +39,7 @@ __global__ void ...@@ -39,8 +39,7 @@ __global__ void
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op, const D1ElementwiseOperation d1_element_op,
const D1ReduceOperation d1_reduce_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -60,8 +59,7 @@ __global__ void ...@@ -60,8 +59,7 @@ __global__ void
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
d0_reduce_op, d1_element_op,
d1_reduce_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -76,8 +74,7 @@ __global__ void ...@@ -76,8 +74,7 @@ __global__ void
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = d0_reduce_op; ignore = d1_element_op;
ignore = d1_reduce_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
...@@ -97,6 +94,7 @@ template <typename FloatAB, ...@@ -97,6 +94,7 @@ template <typename FloatAB,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D0ReduceOperation,
typename D1ReduceOperation, typename D1ReduceOperation,
typename D1ElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum DGlobalMemoryDataOperation, InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
...@@ -372,8 +370,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -372,8 +370,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const D0ReduceOperation& d0_reduce_op, const D1ElementwiseOperation& d1_element_op,
const D1ReduceOperation& d1_reduce_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
...@@ -741,13 +738,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -741,13 +738,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction // TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>( auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize()); d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>( auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize()); d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR // reduce: threadwise copy from LDS to VGPR
...@@ -763,7 +760,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -763,7 +760,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle, FloatCShuffle,
FloatCShuffle, FloatReduceAcc,
decltype(c_reduce_block_desc_mperblock_nperblock), decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock), decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock), decltype(c_reduce_thread_lengths_mperblock_nperblock),
...@@ -775,7 +772,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -775,7 +772,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// reduce: copy from VGPR to global // reduce: copy from VGPR to global
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatCShuffle, FloatReduceAcc,
FloatD, FloatD,
decltype(d_reduce_thread_desc_mblock_mperblock), decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock), decltype(d_grid_desc_mblock_mperblock),
...@@ -840,6 +837,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -840,6 +837,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf); c_grid_buf);
using ThreadwiseReduce_D0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D0ReduceOperation,
false>;
using ThreadwiseReduce_D1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D1ReduceOperation,
false>;
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
// reduce // reduce
{ {
// copy from LDS to VGPR // copy from LDS to VGPR
...@@ -850,26 +869,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -850,26 +869,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_reduce_thread_buf); c_reduce_thread_buf);
// reduce in VGPR // reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) { ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) { static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset = constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset( Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{}; make_tuple(im, in))>{};
d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]); d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
}); });
constexpr index_t out_offset =
d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));
d0_thread_buf(Number<out_offset>{}) = d0_acc;
d1_thread_buf(Number<out_offset>{}) = d1_acc;
}); });
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// copy from VGPR to Global // copy from VGPR to Global
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock, d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0), make_tuple(I0, I0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment