Commit d92fb7e8 authored by rocking's avatar rocking
Browse files

Merge commit 'a3c910ac' into gemm_softmax

parents bfc80764 a3c910ac
cmake_minimum_required(VERSION 3.5)
cmake_minimum_required(VERSION 3.14)
# Check support for CUDA/HIP in Cmake
project(composable_kernel)
......@@ -234,6 +234,8 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include
)
include(googletest)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
......
FROM ubuntu:18.04
ARG ROCMVERSION=5.0
ARG ROCMVERSION=5.1
ARG OSDB_BKC_VERSION
RUN set -xe
......@@ -42,7 +42,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libnuma-dev \
libpthread-stubs0-dev \
llvm-amdgpu \
miopengemm \
pkg-config \
python \
python3 \
......@@ -51,19 +50,15 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
python-pip \
python3-pip \
software-properties-common \
sqlite3 \
wget \
rocm-dev \
rocm-device-libs \
rocm-opencl \
rocm-opencl-dev \
rocm-cmake \
rocblas \
vim \
zlib1g-dev \
openssh-server \
kmod \
mysql-client && \
clang-format-10 \
kmod && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
......
......@@ -140,6 +140,10 @@ def reboot(){
build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),]
}
def buildHipClangJobAndReboot(Map conf=[:]){
try{
buildHipClangJob(conf)
......@@ -156,6 +160,93 @@ def buildHipClangJobAndReboot(Map conf=[:]){
}
}
def runCKProfiler(Map conf=[:]){
show_node_info()
env.HSA_ENABLE_SDMA=0
checkout scm
def image = "composable_kernels"
def prefixpath = conf.get("prefixpath", "/opt/rocm")
def gpu_arch = conf.get("gpu_arch", "gfx908")
// Jenkins is complaining about the render group
// def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1"
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' "
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
try {
retimage = docker.build("${image}", dockerArgs + '.')
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 5, unit: 'MINUTES')
{
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
}
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
echo "The job was cancelled or aborted"
throw e
}
catch(Exception ex) {
retimage = docker.build("${image}", dockerArgs + "--no-cache .")
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 5, unit: 'MINUTES')
{
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
}
}
}
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 5, unit: 'HOURS')
{
cmake_build(conf)
dir("script"){
def perf_log = "perf_gemm_${gpu_arch}.log"
def artifact = "profile_gemm_${gpu_arch}.txt"
sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee ${perf_log} ||true"
sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${perf_log} ||true"
sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${perf_log} ||true"
sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${perf_log} || true"
//results will be parsed, stored, and analyzed within the python script
//the script will return 0 if the performance criteria are met
//or return 1 if the criteria are not met
sh "python3 parse_perf_data.py ${perf_log} | tee ${artifact}"
}
}
}
}
return retimage
}
def runPerfTest(Map conf=[:]){
try{
runCKProfiler(conf)
}
catch(e){
echo "throwing error exception in performance tests"
echo 'Exception occurred: ' + e.toString()
throw e
}
finally{
if (!conf.get("no_reboot", false)) {
reboot()
}
}
}
pipeline {
agent none
options {
......@@ -178,18 +269,19 @@ pipeline {
// buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug')
// }
// }
stage('Build Profiler: Release, gfx908')
{
agent { label rocmnode("nogpu")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
}
}
// we will build and run ckProfiler release version later, during the performance test stage
//stage('Build Profiler: Release, gfx908')
//{
// agent { label rocmnode("nogpu")}
// environment{
// setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
// }
// steps{
// buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
// }
//}
stage('Build Profiler: Debug, gfx908')
{
{
agent { label rocmnode("nogpu")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
......@@ -204,7 +296,7 @@ pipeline {
stage('Clang Format') {
agent{ label rocmnode("nogpu") }
environment{
execute_cmd = "find . -iname \'*.h\' \
execute_cmd = "find .. -iname \'*.h\' \
-o -iname \'*.hpp\' \
-o -iname \'*.cpp\' \
-o -iname \'*.h.in\' \
......@@ -235,6 +327,35 @@ pipeline {
}
}
stage("Run Tests: gfx90a")
{
agent{ label rocmnode("gfx90a")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release')
}
}
}
}
stage("Performance Tests")
{
parallel
{
stage("Run ckProfiler: gfx908")
{
agent{ label rocmnode("gfx908")}
environment{
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
}
steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
}
}
}
}
......
include(FetchContent)
set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
if(GOOGLETEST_DIR)
set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
endif()
message(STATUS "Fetching GoogleTest")
list(APPEND GTEST_CMAKE_CXX_FLAGS
-Wno-undef
-Wno-reserved-identifier
-Wno-global-constructors
-Wno-missing-noreturn
-Wno-disabled-macro-expansion
-Wno-used-but-marked-unused
-Wno-switch-enum
-Wno-zero-as-null-pointer-constant
-Wno-unused-member-function
)
message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
)
# Will be necessary for windows build
# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp)
target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_fwd_util)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp)
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util)
add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp)
target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util)
add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp)
target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util)
......@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
using ReferenceConvBwdWeightInstance = ck::tensor_operation::host::
ReferenceConvBwdWeight<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;
using ReferenceConvBwdWeightInstance =
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
int main(int argc, char* argv[])
{
......
......@@ -11,9 +11,10 @@
#include "device_tensor.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
#include "reduction_operator.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -33,22 +34,23 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum;
using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......@@ -159,11 +161,10 @@ int main(int argc, char* argv[])
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d1_element_op = D1ElementOp{};
// do GEMM
auto gemm = DeviceGemmReduceInstance{};
......@@ -182,8 +183,7 @@ int main(int argc, char* argv[])
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op);
d1_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -242,19 +242,26 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetReduceZeroValue();
float d1_acc = d1_reduce_op.GetReduceZeroValue();
float d0_acc = d0_reduce_op.GetReductionZeroVal();
float d1_acc = d1_reduce_op.GetReductionZeroVal();
for(int n = 0; n < N; ++n)
{
d0_reduce_op.Reduce(d0_acc, c_m_n_host_result(m, n));
d1_reduce_op.Reduce(d1_acc, c_m_n_host_result(m, n));
float d0_val = ck::type_convert<float>(c_m_n_host_result(m, n));
float d1_val;
d1_element_op(d1_val, d0_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
}
d0_m_host_result(m) = d0_acc;
d1_m_host_result(m) = d1_acc;
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
}
check_error(c_m_n_host_result, c_m_n_device_result);
......
add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util)
......@@ -11,9 +11,9 @@
#include "device_tensor.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_batched_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -33,22 +33,23 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum;
using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
......@@ -168,11 +169,12 @@ int main(int argc, char* argv[])
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
auto d1_element_op = D1ElementOp{};
// do GEMM
auto batched_gemm = DeviceBatchedGemmReduceInstance{};
......@@ -192,8 +194,7 @@ int main(int argc, char* argv[])
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
BatchCount);
if(!batched_gemm.IsSupportedArgument(argument))
......@@ -258,17 +259,21 @@ int main(int argc, char* argv[])
{
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetReduceZeroValue();
float d1_acc = d1_reduce_op.GetReduceZeroValue();
float d0_acc = d0_reduce_op.GetReductionZeroVal();
float d1_acc = d1_reduce_op.GetReductionZeroVal();
for(int n = 0; n < N; ++n)
{
d0_reduce_op.Reduce(d0_acc, c_g_m_n_host_result(batch, m, n));
d1_reduce_op.Reduce(d1_acc, c_g_m_n_host_result(batch, m, n));
float d0_val = ck::type_convert<float>(c_g_m_n_host_result(m, n));
float d1_val;
d1_element_op(d1_val, d0_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
}
d0_g_m_host_result(batch, m) = d0_acc;
d1_g_m_host_result(batch, m) = d1_acc;
d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc);
}
}
......
......@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -44,8 +43,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op,
const D1ReduceOperation d1_reduce_op,
const D1ElementwiseOperation d1_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -82,8 +80,7 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -99,8 +96,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d0_reduce_op;
ignore = d1_reduce_op;
ignore = d1_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -125,6 +121,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -161,8 +158,7 @@ template <typename ALayout,
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation>
D1ElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
......@@ -564,6 +560,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
......@@ -624,8 +621,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
......@@ -648,8 +644,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d0_reduce_op_{d0_reduce_op},
d1_reduce_op_{d1_reduce_op}
d1_element_op_{d1_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
......@@ -684,8 +679,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D0ReduceOperation d0_reduce_op_;
D1ReduceOperation d1_reduce_op_;
D1ElementwiseOperation d1_element_op_;
};
// Invoker
......@@ -740,8 +734,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -763,8 +756,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -782,8 +774,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -805,8 +796,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -865,8 +855,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount)
{
return Argument{p_a,
......@@ -883,8 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
BatchCount};
}
......@@ -905,8 +893,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......@@ -923,8 +910,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
BatchCount);
}
......
......@@ -16,6 +16,31 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeBasePrtOfBatch,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
......@@ -43,7 +68,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
......@@ -52,11 +77,11 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -82,7 +107,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_base_ptr_of_batch_;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
struct ComputeBasePtrOfStridedBatch
struct ComputePtrOffsetOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC)
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
......@@ -359,9 +384,9 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_base_ptr_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
c_grid_desc_m_n_.GetElementSpaceSize()},
compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
c_grid_desc_m_n_.GetElementSpaceSize()},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
......@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......@@ -448,7 +473,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputeBasePtrOfStridedBatch,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true>;
......@@ -467,7 +492,7 @@ struct DeviceBatchedGemmXdl
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_base_ptr_of_batch_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
else
......@@ -482,7 +507,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputeBasePtrOfStridedBatch,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
false>;
......@@ -501,7 +526,7 @@ struct DeviceBatchedGemmXdl
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_base_ptr_of_batch_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
......
......@@ -18,6 +18,9 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......
......@@ -9,8 +9,7 @@ namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation>
typename D1ElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -27,8 +26,7 @@ struct DeviceGemmReduce : public BaseOperator
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
......@@ -37,13 +35,11 @@ struct DeviceGemmReduce : public BaseOperator
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation>
typename D1ElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation>>;
D1ElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -29,6 +29,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -65,8 +66,7 @@ template <typename ALayout,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation>
D1ElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
......@@ -382,6 +382,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
......@@ -440,8 +441,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op)
D1ElementwiseOperation d1_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -457,8 +457,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d0_reduce_op_{d0_reduce_op},
d1_reduce_op_{d1_reduce_op}
d1_element_op_{d1_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
......@@ -491,8 +490,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D0ReduceOperation d0_reduce_op_;
D1ReduceOperation d1_reduce_op_;
D1ElementwiseOperation d1_element_op_;
};
// Invoker
......@@ -544,8 +542,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -565,8 +562,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -583,8 +579,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -604,8 +599,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -655,8 +649,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op)
D1ElementwiseOperation d1_element_op)
{
return Argument{p_a,
p_b,
......@@ -672,8 +665,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op};
d1_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -693,8 +685,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......@@ -711,8 +702,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op);
d1_element_op);
}
// polymorphic
......
......@@ -5,20 +5,6 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
struct ReduceSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
};
struct ReduceSquareSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -8,6 +8,7 @@
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"
namespace ck {
......@@ -18,8 +19,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -39,8 +39,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op,
const D1ReduceOperation d1_reduce_op,
const D1ElementwiseOperation d1_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -60,8 +59,7 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -76,8 +74,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d0_reduce_op;
ignore = d1_reduce_op;
ignore = d1_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -97,6 +94,7 @@ template <typename FloatAB,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
......@@ -372,8 +370,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const D0ReduceOperation& d0_reduce_op,
const D1ReduceOperation& d1_reduce_op,
const D1ElementwiseOperation& d1_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
......@@ -741,13 +738,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
......@@ -763,7 +760,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatCShuffle,
FloatReduceAcc,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
......@@ -775,7 +772,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// reduce: copy from VGPR to global
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatCShuffle,
FloatReduceAcc,
FloatD,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
......@@ -840,6 +837,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
using ThreadwiseReduce_D0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D0ReduceOperation,
false>;
using ThreadwiseReduce_D1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D1ReduceOperation,
false>;
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
// reduce
{
// copy from LDS to VGPR
......@@ -850,26 +869,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_reduce_thread_buf);
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();
ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]);
d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
});
constexpr index_t out_offset =
d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));
d0_thread_buf(Number<out_offset>{}) = d0_acc;
d1_thread_buf(Number<out_offset>{}) = d1_acc;
});
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// copy from VGPR to Global
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment