Unverified Commit 5c7cec11 authored by Chao Liu, committed by GitHub

Code clean up (#20)



* tuning para,

* testing on v100

* add fp16

* remove deprecated tensor descriptor

* sync with miopen

* update build script
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 7d09790a
@@ -18,26 +18,32 @@ include_directories(BEFORE ${Boost_INCLUDE_DIRS})
 link_directories(${Boost_LIBRARY_DIRS})
 
 #OpenMP
-if( NOT( ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") )
+if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+    # workaround issue hipcc in rocm3.5 cannot find openmp
+    set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
+    set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument")
+    set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5")
+    set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+    set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+    set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+else()
     find_package(OpenMP REQUIRED)
+endif()
 
 message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
 message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
 message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
 message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
 
 set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
 link_libraries(${OpenMP_gomp_LIBRARY})
 link_libraries(${OpenMP_pthread_LIBRARY})
-endif( NOT( ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") )
 
 #GPU backend
 if(DEVICE_BACKEND STREQUAL "AMD")
-    set(CMAKE_MODULE_PATH "/opt/rocm/hip/cmake" ${CMAKE_MODULE_PATH})
     find_package(HIP REQUIRED)
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
     enable_language(CUDA)
-    include_directories(BEFORE ${CUDA_COMMON_INCLUDE_DIR})
 endif()
 
 #
@@ -47,19 +53,27 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
-    ${PROJECT_SOURCE_DIR}/external/include
+    ${PROJECT_SOURCE_DIR}/external/half/include
     ${PROJECT_SOURCE_DIR}/driver/include
     ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
 )
 
+if(DEVICE_BACKEND STREQUAL "AMD")
+    include_directories(BEFORE
+        ${PROJECT_SOURCE_DIR}/external/rocm/include
+    )
+endif()
+
 if(DEVICE_BACKEND STREQUAL "AMD")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
 endif()
 
 add_subdirectory(driver)
#ifndef CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
#define CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
template <class GridwiseConvolution, class T>
__global__ void run_gridwise_convolution_kernel(const T* const __restrict__ p_in_global,
const T* const __restrict__ p_wei_global,
T* const __restrict__ p_out_global)
{
GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global);
}
#endif
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
 
 template <typename GridwiseOp, typename... Xs>
-__global__ void run_gridwise_operation(GridwiseOp, Xs... xs)
+__global__ void run_gridwise_operation(Xs... xs)
 {
     GridwiseOp{}.Run(xs...);
 }
...
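With this change, the gridwise operation type is supplied only as an explicit template argument at the launch site instead of as a dummy by-value first argument. A minimal launch sketch (hypothetical variable names, not part of this commit):

    // assumption: GridwiseOp is some fully-specialized gridwise operation type from this repo
    using GridwiseOp = SomeGridwiseConvolution;
    run_gridwise_operation<GridwiseOp>
        <<<grid_size, block_size>>>(p_in_global, p_wei_global, p_out_global);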
#ifndef CK_CONVOLUTION_COMMON_HPP
#define CK_CONVOLUTION_COMMON_HPP
namespace ck {
enum ConvolutionDirection
{
Forward,
BackwardData,
BackwardWeight
};
} // namespace ck
#endif
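A hypothetical sketch (not part of this commit) of how the new ConvolutionDirection enum can drive compile-time dispatch:

    // assumption: illustrative helper only, not defined anywhere in the repo
    template <ck::ConvolutionDirection Dir>
    struct ConvolutionDirectionTag
    {
        static constexpr bool is_forward = (Dir == ck::ConvolutionDirection::Forward);
    };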
#ifndef CK_GRIDWISE_COL2IM_EB_NCHW_HPP
#define CK_GRIDWISE_COL2IM_EB_NCHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename ColGlobalDesc,
typename ImgGlobalDesc,
typename FilterSizes,
typename OutputSizes,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t EPerBlock,
index_t BPerBlock,
typename BlockCopySubLengths_E_B,
typename BlockCopyClusterLengths_E_B,
typename BlockCopyThreadClusterArrangeOrder,
typename BlockCopySrcAccessOrder,
typename BlockCopyDstAccessOrder,
index_t BlockCopyDataPerAccess_B>
struct GridwiseCol2Im_eb_nchw
{
__device__ void Run(const Float* const __restrict__ p_col_global,
Float* const __restrict__ p_img_global) const
{
constexpr auto col_e_b_global_desc = ColGlobalDesc{};
constexpr auto img_n_c_hi_wi_global_desc = ImgGlobalDesc{};
constexpr index_t N = img_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = img_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = img_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = img_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t Ho = OutputSizes{}[0];
constexpr index_t Wo = OutputSizes{}[1];
constexpr index_t Y = FilterSizes{}[0];
constexpr index_t X = FilterSizes{}[1];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || BlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % BlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [E, B]
static_assert(E % EPerBlock == 0 && B % BPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t EBlockWork = E / EPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
// construct img_eb_global_desc
constexpr auto img_n_c_hip_wip_global_desc = transform_tensor_descriptor(
img_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto img_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
img_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto img_e_b_global_desc = transform_tensor_descriptor(
img_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
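// The chain of transforms above builds an E x B matrix view of the image without materializing
// it: Pad widens (Hi, Wi) to the padded input, Embed unfolds each padded spatial axis into a
// (filter-tap, output-position) pair using (dilation, stride) as embedding coefficients, and
// Merge folds (C, Y, X) into E and (N, Ho, Wo) into B. Copying col[E, B] into this view with
// AtomicAdd (below) is exactly col2im, since several (y, x) filter taps can map to the same
// image element.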
// blockwise atomic accumulation
auto blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(col_e_b_global_desc),
decltype(img_e_b_global_desc),
Sequence<EPerBlock, BPerBlock>,
BlockCopySubLengths_E_B,
BlockCopyClusterLengths_E_B,
BlockCopyThreadClusterArrangeOrder,
BlockCopySrcAccessOrder,
BlockCopyDstAccessOrder,
1,
1,
BlockCopyDataPerAccess_B,
BlockCopyDataPerAccess_B,
AddressSpace::Vgpr,
AddressSpace::Vgpr,
AddressSpace::Global,
InMemoryDataOperation::AtomicAdd>(
{e_block_data_on_global, b_block_data_on_global},
{e_block_data_on_global, b_block_data_on_global});
// blockwise copy
blockwise_copy.Run(p_col_global, p_img_global);
}
};
} // namespace ck
#endif
@@ -36,8 +36,8 @@ template <index_t GridSize,
           index_t ThreadGemmDataPerRead_GemmN,
           typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
           typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
-          index_t GemmABlockCopySrcDataPerRead_GemmN,
-          index_t GemmABlockCopyDstDataPerWrite_GemmN,
+          index_t GemmABlockCopySrcDataPerRead_GemmM,
+          index_t GemmABlockCopyDstDataPerWrite_GemmM,
           typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
           typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
           index_t GemmBBlockCopySrcDataPerRead_GemmN,
@@ -82,13 +82,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
         constexpr auto wei_gemmk_gemmm_global_desc =
             unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
 
-        // output tensor
-        constexpr auto out_gemmk_gemmn_global_desc =
-            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
-                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
-                                        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
-                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         // input tensor
         constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
             in_n_c_hi_wi_global_desc,
@@ -98,16 +91,15 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
 
+        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
+        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
+
         constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
             in_n_c_hip_wip_global_desc,
             make_tuple(PassThrough<N>{},
                        PassThrough<C>{},
-                       Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
-                             Sequence<Y, Ho>,
-                             Sequence<ConvDilationH, ConvStrideH, 0>>{},
-                       Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
-                             Sequence<X, Wo>,
-                             Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
+                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -117,6 +109,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
             make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
 
+        // output tensor
+        constexpr auto out_gemmk_gemmn_global_desc =
+            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
+                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
+                                        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         // GEMM
         // \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
         // atomic, find out all of them
@@ -152,8 +151,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
             Sequence<0, 1>,
             Sequence<0, 1>,
             1,
-            GemmABlockCopySrcDataPerRead_GemmN,
-            GemmABlockCopyDstDataPerWrite_GemmN,
+            GemmABlockCopySrcDataPerRead_GemmM,
+            GemmABlockCopyDstDataPerWrite_GemmM,
             GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
             GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
             Sequence<0, 1>,
...
@@ -25,13 +25,13 @@ template <index_t GridSize,
           index_t EPerBlock,
           index_t BPerBlock,
           index_t KPerBlock,
-          index_t GemmMPerThreadSubC,
-          index_t GemmNPerThreadSubC,
+          index_t GemmMPerThread,
+          index_t GemmNPerThread,
+          index_t GemmKPerThread,
           index_t GemmMLevel0Cluster,
           index_t GemmNLevel0Cluster,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
-          index_t GemmKPerThreadLoop,
           index_t GemmDataPerReadA,
           index_t GemmDataPerReadB,
           typename OutBlockCopySubLengths_K_B_N0,
@@ -78,8 +78,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
         constexpr index_t ConvDilationH = ConvDilations{}[0];
         constexpr index_t ConvDilationW = ConvDilations{}[1];
 
-        constexpr index_t C0 = GemmMPerThreadSubC;
-        constexpr index_t N0 = GemmNPerThreadSubC;
+        constexpr index_t C0 = GemmMPerThread;
+        constexpr index_t N0 = GemmNPerThread;
 
         static_assert(C % C0 == 0 && N % N0 == 0, "wrong!");
@@ -225,20 +225,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
         constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
+            Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
 
         const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
             BlockSize,
             decltype(a_k_ec0_block_mtx_desc),
             decltype(b_k_bn0_block_mtx_desc),
             decltype(c_e0e1c0_b0b1n0_thread_mtx_desc),
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
+            GemmMPerThread,
+            GemmNPerThread,
+            GemmKPerThread,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
             GemmDataPerReadA,
             GemmDataPerReadB>{};
@@ -371,7 +371,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
         // define input tensor descriptor for threadwise copy
         // thread input tensor, src of threadwise copy
         constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed(
-            Sequence<GemmMRepeat, 1, GemmMPerThreadSubC, GemmNRepeat, 1, GemmNPerThreadSubC>{});
+            Sequence<GemmMRepeat, 1, GemmMPerThread, GemmNRepeat, 1, GemmNPerThread>{});
 
         // global input tensor, dst of threadwise copy
         constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
@@ -419,10 +419,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
             blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
 
         const index_t e_thread_data_on_global =
-            e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThreadSubC;
+            e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThread;
 
         const index_t b_thread_data_on_global =
-            b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThreadSubC;
+            b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThread;
 
         ThreadwiseGenericTensorSliceCopy_v4r2<
             decltype(in_e0_e1_c0_b0_b1_n0_thread_desc),
...
@@ -419,7 +419,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
     template <index_t GemmId>
     __device__ static void Run(Float* __restrict__ p_in_global,
                                const Float* __restrict__ p_wei_global,
-                               const Float* __restrict__ p_out_global)
+                               const Float* __restrict__ p_out_global,
+                               Number<GemmId>)
     {
         constexpr index_t ConvStrideH = ConvStrides{}[0];
         constexpr index_t ConvStrideW = ConvStrides{}[1];
...
#ifndef CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
#define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_direct_convolution.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead>
struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_global_desc = InGlobalDesc{};
constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
constexpr index_t N = in_nchw_global_desc.GetLength(I0);
constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);
constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor_packed(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{},
Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<KPerBlock, CPerBlock * Y * X>{},
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
constexpr auto wei_kcyx_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});
// shared mem
constexpr index_t in_block_element_size =
in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr index_t wei_block_element_size =
wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float
p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
__shared__ Float
p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
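// The expression max_align * ((size + max_align - 1) / max_align) above rounds each LDS buffer
// up to a multiple of the widest vectorized access. Example (assumed value, for illustration
// only): with max_align = 4 and in_block_element_size = 1001, the buffer gets
// 4 * ((1001 + 3) / 4) = 1004 elements, so width-4 vector accesses never run past its end.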
// threadwise tensors
constexpr index_t HiPerThread = HoPerThread + Y - 1;
constexpr index_t WiPerThread = WoPerThread + X - 1;
constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
in_nchw_block_desc.GetStrides());
constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());
constexpr auto out_nkhw_thread_desc =
get_convolution_output_default_4d_tensor_descriptor_deprecated(
in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
// register
Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
// divide block work
constexpr index_t NBlockWork =
(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork =
(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork =
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
const index_t block_id = blockIdx.x;
index_t itmp = block_id;
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const index_t h_block_work_id = itmp / WBlockWork;
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
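// Worked example of the block-work decomposition above (assumed values, for illustration only):
// with KBlockWork = 4, HBlockWork = 8, WBlockWork = 8 and block_id = 301,
//   n_block_work_id = 301 / 256 = 1 (remainder 45)
//   k_block_work_id = 45 / 64 = 0 (remainder 45)
//   h_block_work_id = 45 / 8 = 5, w_block_work_id = 45 - 40 = 5
// so this workgroup owns the output tile starting at
// (n, k, ho, wo) = (1 * NPerBlock, 0, 5 * HoPerBlock, 5 * WoPerBlock).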
// divide thread work
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
const index_t thread_id = get_thread_local_1d_id();
itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
const index_t h_thread_work_id = itmp / WThreadWork;
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
const index_t hi_thread_data_begin = ho_thread_data_begin;
const index_t wi_thread_data_begin = wo_thread_data_begin;
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_nchw_global_desc),
decltype(in_nchw_block_desc),
decltype(in_nchw_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#if 0
constexpr auto blockwise_wei_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(wei_kcyx_global_desc),
decltype(wei_kcyx_block_desc),
decltype(wei_kcyx_block_desc.GetLengths()),
1>{};
#elif 1
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ke_global_desc),
decltype(wei_ke_block_desc),
decltype(wei_ke_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>({0, 0}, {0, 0});
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
blockwise_in_copy.Run(
p_in_global +
in_nchw_global_desc.GetOffsetFromMultiIndex(n_block_data_begin,
c_block_data_begin,
hi_block_data_begin,
wi_block_data_begin),
p_in_block);
// copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_global +
wei_kcyx_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_block);
__syncthreads();
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise convolution
#if 1
threadwise_direct_convolution_2(
in_nchw_thread_block_desc,
p_in_block +
in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
c_thread_data,
hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_thread_block_desc,
p_wei_block +
wei_kcyx_block_desc.GetOffsetFromMultiIndex(
k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc,
p_out_thread);
#elif 0
threadwise_direct_convolution_3(
in_nchw_thread_block_desc,
p_in_block +
in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
c_thread_data,
hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_thread_block_desc,
p_wei_block +
wei_kcyx_block_desc.GetOffsetFromMultiIndex(
k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc,
p_out_thread);
#endif
}
}
// copy output tensor from register to global mem
threadwise_tensor_slice_copy(out_nkhw_thread_desc,
p_out_thread,
out_nkhw_global_desc,
p_out_global +
out_nkhw_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_nkhw_thread_desc.GetLengths(),
Number<1>{});
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerBlock % NPerThread == 0 &&
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// flattened (2d) tensor view of gridwise weight
constexpr auto wei_cyx_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead_N>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not meet");
constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy =
#if 0
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyDataPerRead_N>{};
#else
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#endif
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_cyx_k_global_desc),
decltype(wei_cyx_k_block_desc),
decltype(wei_cyx_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_c_k_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<KPerBlock>{},
Number<wei_c_y_x_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space =
wei_c_y_x_k_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register
// C++ lambda doesn't capture array, use pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
const Float* p_in_global_block_offset =
p_in_global +
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0),
__syncthreads())
{
#if 1
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#else
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
#endif
__syncthreads();
#pragma unroll
for(index_t y = 0; y < Y; ++y)
{
#pragma unroll
for(index_t x = 0; x < X; ++x)
{
#if 1
blockwise_batch_gemm.Run
#else
blockwise_batch_gemm.Run_amd_asm
#endif
(p_wei_block + wei_c_y_x_k_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_out_thread);
}
}
}
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
// f_dummy does nothing but perfect forwarding. This trick makes the lambda a generic
// lambda, so it won't be compiled until it is instantiated
static_assert(
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
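// Worked example (assumed tuning values, for illustration only): with GemmNPerThreadSubC = 4,
// NPerBlock = 16, GemmNLevel0Cluster * GemmNLevel1Cluster = 16, WoPerBlock = 16,
// GemmMPerThreadSubC = 4, KPerBlock = 64 and KPerThread = 8, this gives
//   N2 = 4, N1 = 16 / 4 = 4, W2 = 16 / (16 / 4) = 4, W1 = 16 / 4 = 4, K2 = 4, K1 = 64 / 8 = 8,
// so the global output below is viewed as the 10d tensor
//   [K / 32, 8, 4, Ho, Wo / 16, 4, 4, N / 16, 4, 4].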
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / f_dummy(N1 * N2),
N1,
N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
}).Else([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
// output is a 10d tensor
constexpr index_t N1 = NPerBlock;
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
for(index_t i = 0; i < 64; ++i)
{
printf("out %f, ", p_out_thread[i]);
}
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
});
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
// define B = flatten(N, Hi, Wi)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t BPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t BPerThread,
index_t KPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t B = N * Hi * Wi;
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
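// Illustration of BGhostRead (assumed sizes, for illustration only): with a 3x3 filter and
// Wi = 32, BGhostRead = (3 - 1) * 32 + (3 - 1) = 66. Each block therefore stages
// BPerBlock + BGhostRead elements of the flattened input in LDS, so that every (y, x)-shifted
// sub-matrix consumed by the Y * X blockwise GEMMs below still reads valid data.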
// divide block work by 2d: [K, B]
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
// tensor view of blockwise input and weight
// be careful of alignment
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
// tensor view of threadwise output in register
constexpr auto out_kb_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
// blockwise in copy
// format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_in_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// c_mtx[K,B] is out_block[K,B]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});
constexpr auto c_kxb_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
const auto blockwise_gemm =
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t max_align =
math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space =
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
const Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global +
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
// register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
// set threadwise output to 0
threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
{
// load data
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
__syncthreads();
// compute on current data
// a series of GEMM
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
#if 1
blockwise_gemm.Run
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer
#elif 1
blockwise_gemm.Run_amd_asm
#endif
(p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block + y * Wi + x,
p_out_thread);
}
}
}
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
index_t h_data = b_data / (Wi * N);
index_t itmp = b_data - h_data * (Wi * N);
index_t w_data = itmp / N;
index_t n_data = itmp - w_data * N;
if(n_data < N && h_data < Ho && w_data < Wo)
{
p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
k_data, h_data, w_data, n_data)] =
p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
}
}
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
// define B = flatten(N, Hi, Wi)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t BPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t BPerThread,
index_t KPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t B = N * Hi * Wi;
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// assert for LDS double buffer
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided by 2 * CPerBlock");
// divide block work by 2d: [K, B]
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
// tensor view of blockwise input and weight
// be careful of alignment
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
// tensor view of threadwise output in register
constexpr auto out_kb_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
// blockwise in copy
// format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_in_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// c_mtx[K,B] is out_block[K,B]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});
constexpr auto c_kxb_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
const auto blockwise_gemm =
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t max_align =
math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
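        // Note (descriptive): max_align is the LCM of the vector widths used for the LDS
        // copies, with index_t(4) included so each buffer's element space is padded to a
        // multiple of 4 Floats, keeping the vectorized reads/writes below aligned.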
constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space =
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
// LDS double buffer
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
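        // Note (descriptive): double buffering. While the blockwise GEMM consumes the
        // "current" half of each LDS buffer, the next CPerBlock chunk is prefetched from
        // global memory into registers and then written to the other half. The main loop
        // below advances by 2 * CPerBlock per iteration; the tail section computes the
        // last two chunks, prefetching the final one during the first of them.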
const Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global +
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
// preload data into LDS
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_double);
}
// register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
// set threadwise output to 0
threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
c_block_data_begin += 2 * CPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
// load next data
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
__syncthreads();
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
// compute on current data
// a series of GEMM
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
#if 1
blockwise_gemm.Run
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
blockwise_gemm.Run_amd_asm
#endif
(p_wei_block_now +
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block_now + y * Wi + x,
p_out_thread);
}
}
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// tail
{
// even
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
__syncthreads();
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
p_wei_register_buffer);
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
#if 1
blockwise_gemm.Run
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
blockwise_gemm.Run_amd_asm
#endif
(p_wei_block_double +
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block_double + y * Wi + x,
p_out_thread);
}
}
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd
__syncthreads();
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
#if 1
blockwise_gemm.Run
#elif 0
blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
blockwise_gemm.Run_amd_asm
#endif
(p_wei_block_double + wei_block_space +
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block_double + in_block_space + y * Wi + x,
p_out_thread);
}
}
}
        // output: copy from register to global memory
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
if(Y == 1 && X == 1)
        { // pure 1x1 conv (no padding, 1x1 stride)
constexpr index_t K2_ = GemmMPerThreadSubC;
constexpr index_t K1_ = KPerBlock / KPerThread;
constexpr index_t B2_ = GemmNPerThreadSubC;
constexpr index_t B1_ = BPerBlock / BPerThread;
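            // Note (descriptive, inferred from the descriptors below): K and B are viewed
            // as (K / (K1_ * K2_), K1_, K2_) and (B / (B1_ * B2_), B1_, B2_). Each thread
            // holds contiguous K2_ x B2_ sub-tiles in registers, repeated
            // KPerBlock / (K1_ * K2_) = KPerThread / GemmMPerThreadSubC times along K and
            // BPerBlock / (B1_ * B2_) = BPerThread / GemmNPerThreadSubC times along B, so a
            // single threadwise_6d_tensor_copy call can write them all back.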
constexpr auto out_6d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1_ * K2_), K1_, K2_, B / (B1_ * B2_), B1_, B2_>{});
constexpr auto out_6d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerBlock / (K1_ * K2_), 1, K2_, BPerBlock / (B1_ * B2_), 1, B2_>{});
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
threadwise_6d_tensor_copy(out_6d_thread_desc,
p_out_thread,
out_6d_global_desc,
p_out_global +
out_kb_global_desc.GetOffsetFromMultiIndex(
k_thread_data_begin, b_thread_data_begin),
out_6d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
}
else
{
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
index_t h_data = b_data / (Wi * N);
index_t itmp = b_data - h_data * (Wi * N);
index_t w_data = itmp / N;
index_t n_data = itmp - w_data * N;
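                    // Note (illustration only, hypothetical sizes): b_data is unflattened
                    // as b = h * (Wi * N) + w * N + n. With Wi = 16 and N = 4, b_data = 275
                    // gives h = 4, w = 4, n = 3. Only positions inside the Ho x Wo output
                    // window are written.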
if(n_data < N && h_data < Ho && w_data < Wo)
{
p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
k_data, h_data, w_data, n_data)] =
p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
}
}
}
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t BPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t N1,
index_t N2,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_C_N1_B_N2,
class InBlockCopyClusterLengths_C_N1_B_N2,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_C_K,
class WeiBlockCopyClusterLengths_C_K,
index_t WeiBlockCopyDataPerAccess_K>
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
        // TODO: find a more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
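        // Note (descriptive): N is split as N = N0 * N1 * N2. The GEMM's "B" dimension
        // merges (N0, Ho, Wo), while N1 and N2 remain separate dimensions of the blockwise
        // copy and GEMM, giving each block an effective GEMM N-extent of N1 * BPerBlock * N2.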
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % CPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// memory layout descriptor in device memory [N0, N1, N2, C, H, W]
constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});
// merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy
constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number<Ho>{}).Slice(I5, Number<Wo>{}),
Sequence<3>{},
Sequence<1>{},
Sequence<0, 4, 5>{},
Sequence<2>{});
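        // Note (descriptive): dimension mapping of the merged view, where the source dims
        // are those of the folded, sliced [N0, N1, N2, C, Ho, Wo] descriptor:
        // C <- dim 3, N1 <- dim 1, B <- merge(dims 0, 4, 5) = merge(N0, Ho, Wo), N2 <- dim 2.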
// memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize,
Float,
decltype(in_c_n1_b_n2_global_merged_desc),
decltype(in_c_n1_b_n2_block_mem_desc),
decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()),
InBlockCopySubLengths_C_N1_B_N2,
InBlockCopyClusterLengths_C_N1_B_N2,
Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B]
Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B]
Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2]
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
        // this copy operator already has a blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_C_K,
WeiBlockCopyClusterLengths_C_K,
Sequence<0, 1>, // thread_arrange_order [C, K]
Sequence<0, 1>, // src_access_order [C, K]
Sequence<0, 1>, // dst_access_order [C, K]
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[CPerBlock, KPerBlock] is in LDS
        // b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<N1 * BPerBlock * N2>{},
Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
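        // Note (illustration only, hypothetical sizes): each thread accumulates GemmMRepeat
        // row-groups of GemmMPerThreadSubC rows. With KPerBlock = 128, GemmMPerThreadSubC = 4,
        // GemmMLevel0Cluster = 4 and GemmMLevel1Cluster = 4: GemmMRepeat = 128 / 64 = 2.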
// c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// choose GEMM implementation here
const auto run_blockwise_gemm = [&](auto... Xs) {
#if 1
return blockwise_gemm.Run(Xs...);
#else
return blockwise_gemm.Run_amd_asm(Xs...);
#endif
};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDataPerAccess_K,
GemmDataPerReadA,
GemmDataPerReadB);
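        // Note (illustration only, hypothetical sizes): max_align is the LCM of every
        // per-access vector width touching these LDS buffers, so both buffers start on a
        // boundary that keeps all vectorized accesses aligned; e.g. lcm(2, 4, 4, 4) = 4
        // Float elements.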
constexpr index_t in_block_space =
in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
#if 0
// do work
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
// calculate origin of block input and weight tensor on global memory
const Float* p_in_block_on_global =
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
const Float* p_wei_block_on_global =
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
for(index_t
c_block_data_on_global = 0;
c_block_data_on_global < C;
c_block_data_on_global += CPerBlock,
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);
__syncthreads();
run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
}
}
}
#else
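        // Note (descriptive): for each filter tap (y, x), the C dimension is streamed
        // through LDS in CPerBlock chunks, running the blockwise GEMM on each chunk; the
        // source slicing windows are then rewound by C before moving on to the next tap.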
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
// calculate origin of block input and weight tensor on global memory
const Float* p_in_block_on_global =
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
const Float* p_wei_block_on_global =
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
for(index_t c_block_data_on_global = 0; c_block_data_on_global < C;
c_block_data_on_global += CPerBlock)
{
blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);
__syncthreads();
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(
I0, Number<CPerBlock>{}, True);
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(
I0, Number<CPerBlock>{}, True);
}
// reset C
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
}
}
#endif
// copy output: register to global memory
{
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t K0 = K / (K1 * K2);
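            // Note (descriptive, inferred from the GEMM parameters): K is folded as
            // (K0, K1, K2), with K2 = GemmMPerThreadSubC rows held contiguously per thread
            // and K1 = GemmMLevel0Cluster * GemmMLevel1Cluster rows covered by the thread
            // clusters, matching how the blockwise GEMM distributes C-matrix rows.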
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<0, 4, 5>{},
Sequence<2>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
threadwise_generic_tensor_slice_copy_v1(
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{});
}
}
};
} // namespace ck
#endif