Unverified Commit 12649254 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

reorganize files to prepare for MIOpen integration (#51)

* change olc cmake

* adding online compile to fwd-v4r5r2

* update scripts

* remane fwd-v4r5r2 to fwd-v6r1

* clean up
parent fbdf4332
......@@ -6,14 +6,14 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
include(TargetFlags)
include(AddKernels)
#c++
## C++
enable_language(CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
#OpenMP
## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# workaround issue hipcc in rocm3.5 cannot find openmp
set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
......@@ -35,56 +35,8 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries(${OpenMP_pthread_LIBRARY})
#GPU backend
if(DEVICE_BACKEND STREQUAL "AMD")
find_package(HIP REQUIRED)
endif()
#
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/half/include
${PROJECT_SOURCE_DIR}/driver/include
${PROJECT_BINARY_DIR}/composable_kernel/include/utility
)
if(DEVICE_BACKEND STREQUAL "AMD")
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/external/rocm/include
)
endif()
if(DEVICE_BACKEND STREQUAL "AMD")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
endif()
add_subdirectory(driver)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")
if(DEVICE_BACKEND STREQUAL "AMD")
set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
endif()
add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})
target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)
target_link_libraries(conv_driver_v2 PRIVATE modConv)
target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
## HIP
find_package(HIP REQUIRED)
message(STATUS "Build with HIP ${hip_VERSION}")
add_subdirectory(host)
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r1.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_GM11,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_GN11,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r1(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_GM11,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_GN11,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto K = a_gk_gm0_gm1_grid_desc.GetLength(I0);
if(!GridwiseContraction::CheckValidity(
a_gk_gm0_gm1_grid_desc, b_gk_gn0_gn1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
}
const auto a_gk_gm0_gm10_gm11_grid_desc =
GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc);
const auto b_gk_gn0_gn10_gn11_grid_desc =
GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc);
using AGKGM0GM10GM11GridDesc = decltype(a_gk_gm0_gm10_gm11_grid_desc);
using BGKGN0GN10GN11GridDesc = decltype(b_gk_gn0_gn10_gn11_grid_desc);
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
// c_blockid_to_gm10_gn10_block_cluster_adaptor
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(K);
const bool has_double_tail_k_block_loop =
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(K);
{
std::cout << "a_gk_gm0_gm10_gm11_grid_desc{" << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0)
<< ", " << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I1) << ", "
<< a_gk_gm0_gm10_gm11_grid_desc.GetLength(I2) << ", "
<< a_gk_gm0_gm10_gm11_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "b_gk_gn0_gn10_gn11_grid_desc{" << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I0)
<< ", " << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I1) << ", "
<< b_gk_gn0_gn10_gn11_grid_desc.GetLength(I2) << ", "
<< b_gk_gn0_gn10_gn11_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
return ave_time;
}
} // namespace ck
#endif
......@@ -13,19 +13,19 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
index_t GK0PerBlock,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
......@@ -52,9 +52,9 @@ __host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
......@@ -71,25 +71,26 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2<
using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
......@@ -113,37 +114,40 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
if(!GridwiseContraction::CheckValidity(
a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
{
throw std::runtime_error(
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
throw std::runtime_error("wrong! "
"GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GM0_GM1_GN0_GN1 has invalid setting");
}
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc =
GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc);
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc =
GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc);
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc);
using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc);
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
// c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1);
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
// c_blockid_to_gm10_gn10_block_cluster_adaptor
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
// c_grid_block_cluster_blockid_to_gm10_gn10
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1);
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
......@@ -151,41 +155,41 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
{
std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{"
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{"
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true,
true>;
......@@ -198,21 +202,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true,
false>;
......@@ -225,21 +229,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false,
true>;
......@@ -252,21 +256,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false,
false>;
......@@ -279,10 +283,10 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
return ave_time;
......
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r1.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
#endif
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
......@@ -10,41 +10,44 @@ namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t N0_,
typename... Wei,
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
......@@ -58,67 +61,68 @@ transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad(
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gk_gm0_gm1_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(C * Y * X)),
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}));
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
constexpr auto N0 = Number<N0_>{};
const auto N1 = N / N0;
const auto in_n0_n1_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(C),
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
const auto in_gk_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
in_n0_n1_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N0),
make_merge_transform(make_tuple(N1, Ho, Wo))),
make_tuple(Sequence<2, 3, 5>{}, Sequence<0>{}, Sequence<1, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_n_k_howo_grid_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(
wei_gk_gm0_gm1_grid_desc, in_gk_gn0_gn1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
......
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
......@@ -17,10 +17,10 @@ template <typename... Wei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t N0Value,
index_t C0Value>
typename N0Type,
typename C0Type>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
......@@ -28,8 +28,8 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<N0Value>,
Number<C0Value>)
const N0Type& N0,
const C0Type& C0)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -61,9 +61,6 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
constexpr auto N0 = Number<N0Value>{};
constexpr auto C0 = Number<C0Value>{};
const auto N1 = N / N0;
const auto C1 = C / C0;
......@@ -109,7 +106,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
......@@ -119,7 +116,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_pass_through_transform(N0),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
......
......@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"
#include "threadwise_contraction.hpp"
namespace ck {
......
......@@ -4,43 +4,43 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_gemm_v2.hpp"
#include "threadwise_contraction.hpp"
namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. AK0MK1BlockDesc is known at compile-time
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BK0NK1BlockDesc is known at compile-time
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename std::enable_if<AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
typename ABlockDesc_BK0_BM_BK1,
typename BBlockDesc_BK0_BN_BK1,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
......@@ -51,138 +51,144 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t M = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t N = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t M100 = M1N1ThreadClusterM100;
static constexpr index_t N100 = M1N1ThreadClusterN100;
static constexpr index_t BM100 = BM10BN10ThreadClusterBM100;
static constexpr index_t BN100 = BM10BN10ThreadClusterBN100;
static constexpr index_t M101 = M1N1ThreadClusterM101;
static constexpr index_t N101 = M1N1ThreadClusterN101;
static constexpr index_t BM101 = BM10BN10ThreadClusterBM101;
static constexpr index_t BN101 = BM10BN10ThreadClusterBN101;
static constexpr index_t M11 = M1PerThreadM11;
static constexpr index_t N11 = N1PerThreadN11;
static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
static constexpr index_t BM1 =
BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11;
static constexpr index_t BN1 =
BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11;
static constexpr index_t M0 = M / M1;
static constexpr index_t N0 = N / N1;
static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1;
__host__ __device__ static constexpr auto
MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc)
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{})),
make_pass_through_transform(Number<K1>{})),
const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_block_desc;
return a_block_bk0_bm0_bm1_bk1;
}
__host__ __device__ static constexpr auto
MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc)
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{})),
make_pass_through_transform(Number<K1>{})),
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_block_desc;
return b_block_desc_bk0_bn0_bn1_bk1;
}
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_unmerge_transform(make_tuple(
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
}
__host__ __device__ static constexpr auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<M0>{}),
make_tuple(make_pass_through_transform(Number<BM0>{}),
make_unmerge_transform(
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_pass_through_transform(Number<N0>{}),
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_pass_through_transform(Number<BN0>{}),
make_unmerge_transform(
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
}
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<M0, M11, N0, N11>{};
return Sequence<BM0, BM11, BN0, BN11>{};
}
static constexpr auto a_k0_m0_m1_k1_block_desc_ =
MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{});
static constexpr auto b_k0_n0_n1_k1_block_desc_ =
MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{});
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
__device__ BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M101 * M100 * N101 * N100,
static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
"wrong! blocksize and cluster size not consistent");
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2, "wrong");
static_assert(BM0 == 2 && BN0 == 2, "wrong");
}
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr auto adaptor0 =
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
make_pass_through_transform(M0),
make_pass_through_transform(M11),
make_pass_through_transform(N0),
make_pass_through_transform(N11)),
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
make_pass_through_transform(BM0),
make_pass_through_transform(BM11),
make_pass_through_transform(BN0),
make_pass_through_transform(BN11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
......@@ -192,73 +198,75 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
}
template <typename CM0M1N0N1ThreadDesc,
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11& c_m0_m1_n0_n1_thread_desc,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
static_assert(BM0 == 2 && BN0 == 2 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize());
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize());
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1<FloatA,
constexpr auto threadwise_contraction =
ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
decltype(a_k0_m0_m1_k1_thread_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
CM0M1N0N1ThreadDesc,
Sequence<KPerThread, K1>,
Sequence<1, M1PerThreadM11>,
Sequence<1, N1PerThreadN11>>{};
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
CThreadDesc_BM0_BM11_BN0_BN11,
Sequence<BK0PerThread, BK1>,
Sequence<1, BM1PerThreadBM11>,
Sequence<1, BN1PerThreadBN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
......@@ -266,25 +274,25 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThread, K0, KPerThread>{}([&](auto k) {
// loop over rest of bk0
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
......@@ -292,15 +300,15 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
......@@ -308,23 +316,23 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
......@@ -332,7 +340,7 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
......@@ -341,7 +349,7 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
......@@ -349,7 +357,7 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
......@@ -358,34 +366,34 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
}
private:
// A[K0, M0, M1, K1]
static constexpr auto a_k0_m0_m1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}, Number<K1>{}));
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
// B[K0, N0, N1, K1]
static constexpr auto b_k0_n0_n1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}, Number<K1>{}));
// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_k0_m0_m1_k1_block_desc_),
decltype(a_k0_m0_m1_k1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11, K1>, // SliceLengths
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, M1PerThreadM11, K1>, // SrcVectorTensorLengths
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_k0_n0_n1_k1_block_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11, K1>, // SliceLengths
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, N1PerThreadN11, K1>, // SrcVectorTensorLengths
Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
......
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseContraction,
typename FloatAB,
typename FloatC,
typename AGKGM0GM10GM11GridDesc,
typename BGKGN0GN10GN11GridDesc,
typename CGM10BM0BM1GN10BN0BN1GridDesc,
typename CBlockIdToGM10GN10BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_contraction_v1r1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGKGM0GM10GM11GridDesc a_gk_gm0_gm10_gm11_grid_desc,
const BGKGN0GN10GN11GridDesc b_gk_gn0_gn10_gn11_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor
c_blockid_to_gm10_gn10_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_GM11,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_GN11,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to known at compile-time
static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
Number<BBlockTransferDstScalarPerVector_GN11>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk_gm0_gm10_gm11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk_gn0_gn10_gn11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
"wrong! GM0 and GN0 need to be known at compile-time");
const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
GM0 == a_gk_gm0_gm1_grid_desc.GetLength(I1) &&
GM1 == a_gk_gm0_gm1_grid_desc.GetLength(I2) &&
GN0 == b_gk_gn0_gn1_grid_desc.GetLength(I1) &&
GN1 == b_gk_gn0_gn1_grid_desc.GetLength(I2) &&
GK == b_gk_gn0_gn1_grid_desc.GetLength(I0)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK % KPerBlock == 0));
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr index_t GM11 = GM1PerBlockGM11;
constexpr index_t GN11 = GN1PerBlockGN11;
const index_t GM10 = GM1 / GM11;
const index_t GN10 = GN1 / GN11;
const index_t grid_size = GM10 * GN10;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK)
{
const bool has_main_k_block_loop = (GK + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK)
{
const bool has_double_tail_k_block_loop = (GK / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAGKGM0GM10GM11GridDescriptor(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc)
{
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;
const auto a_gk_gm0_gm10_gm11_grid_desc = transform_dynamic_tensor_descriptor(
a_gk_gm0_gm1_grid_desc,
make_tuple(make_pass_through_transform(GK),
make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
return a_gk_gm0_gm10_gm11_grid_desc;
}
__host__ __device__ static constexpr auto
MakeBGKGN0GN10GN11GridDescriptor(const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc)
{
const auto GK = b_gk_gn0_gn1_grid_desc.GetLength(I0);
const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;
const auto b_gk_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
b_gk_gn0_gn1_grid_desc,
make_tuple(make_pass_through_transform(GK),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
return b_gk_gn0_gn10_gn11_grid_desc;
}
__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
constexpr auto BM = GM0 * GM11;
constexpr auto BN = GN0 * GN11;
constexpr auto BM1 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto BN1 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm1_gn0_gn1_grid_desc,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)),
make_pass_through_transform(GN10),
make_merge_transform(make_tuple(GN0, GN11))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
make_pass_through_transform(GN10),
make_unmerge_transform(make_tuple(BN0, BN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
}
__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_gm10_gn10_block_cluster_adaptor;
}
using AGKGM0GM10GM11GridDesc = decltype(MakeAGKGM0GM10GM11GridDescriptor(AGKGM0GM1GridDesc{}));
using BGKGN0GN10GN11GridDesc = decltype(MakeBGKGN0GN10GN11GridDescriptor(BGKGN0GN1GridDesc{}));
using CGM10BM0BM1GN10BN0BN1GridDesc =
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGKGM0GM10GM11GridDesc& a_gk_gm0_gm10_gm11_grid_desc,
const BGKGN0GN10GN11GridDesc& b_gk_gn0_gn10_gn11_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_gk_gm0_gm10_gm11_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_gk_gn0_gn10_gn11_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());
const auto GK = a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
// lds max alignment
// part of them should be moved into blockwise-gemm
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
Number<BBlockTransferDstScalarPerVector_GN11>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk_gm0_gm10_gm11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk_gn0_gn10_gn11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_gk_bm_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_gk_bn_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11>,
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_gk_gm0_gm10_gm11_grid_desc),
decltype(a_gk_gm0_gm10_gm11_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_GM11,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_gk_gm0_gm10_gm11_grid_desc,
make_multi_index(0, 0, igm10, 0),
a_gk_gm0_gm10_gm11_block_desc,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11>,
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_gk_gn0_gn10_gn11_grid_desc),
decltype(b_gk_gn0_gn10_gn11_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_GN11,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_gk_gn0_gn10_gn11_grid_desc,
make_multi_index(0, 0, ign10, 0),
b_gk_gn0_gn10_gn11_block_desc,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_gk_bm_block_desc),
decltype(b_gk_bn_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < GK - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = GM1PerBlockGM11 / M11;
constexpr index_t N10 = GN1PerBlockGN11 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));
const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
Sequence<1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
make_multi_index(igm10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
ign10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
}
}
};
} // namespace ck
#endif
......@@ -5,8 +5,8 @@
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
......@@ -15,10 +15,10 @@ namespace ck {
template <typename GridwiseContraction,
typename FloatAB,
typename FloatC,
typename AGK0GM0GM10GM11GK1GridDesc,
typename BGK0GN0GN10GN11GK1GridDesc,
typename CGM10BM0BM1GN10BN0BN1GridDesc,
typename CBlockIdToGM10GN10BlockClusterAdaptor,
typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
typename CGridBlockCluster_BlockId_To_GM10_GN10,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
......@@ -29,11 +29,10 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGK0GM0GM10GM11GK1GridDesc a_gk0_gm0_gm10_gm11_gk1_grid_desc,
const BGK0GN0GN10GN11GK1GridDesc b_gk0_gn0_gn10_gn11_gk1_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor
c_blockid_to_gm10_gn10_block_cluster_adaptor)
const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
{
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -44,10 +43,10 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared_block,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
......@@ -57,19 +56,19 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGK0GM0GM1GK1GridDesc,
typename BGK0GN0GN1GK1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
index_t GK0PerBlock,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
......@@ -92,7 +91,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -100,9 +99,9 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to known at compile-time
static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
static constexpr auto GK1 = AGK0GM0GM1GK1GridDesc{}.GetLength(I3);
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
......@@ -113,61 +112,62 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
"wrong! GM0 and GN0 need to be known at compile-time");
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
GM0 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I1) &&
GM1 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2) &&
GN0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I1) &&
GN1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2) &&
GK0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0) &&
GK1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I3)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % KPerBlock == 0));
return (
(GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr index_t GM11 = GM1PerBlockGM11;
constexpr index_t GN11 = GN1PerBlockGN11;
......@@ -182,29 +182,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
{
const bool has_main_k_block_loop = (GK0 + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
{
const bool has_double_tail_k_block_loop = (GK0 / KPerBlock) % 2 == 0;
const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAGK0GM0GM10GM11GK1GridDescriptor(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc)
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
{
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
a_gk0_gm0_gm1_gk1_grid_desc,
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor(
a_grid_desc_gk0_gm0_gm1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
......@@ -212,20 +212,20 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_gk0_gm0_gm10_gm11_gk1_grid_desc;
return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
}
__host__ __device__ static constexpr auto
MakeBGK0GN0GN10GN11GK1GridDescriptor(const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc)
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
{
const auto GK0 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0);
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
b_gk0_gn0_gn1_gk1_grid_desc,
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor(
b_grid_desc_gk0_gn0_gn1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11)),
......@@ -233,14 +233,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_gk0_gn0_gn10_gn11_gk1_grid_desc;
return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
}
__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
......@@ -252,15 +252,15 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
constexpr auto BN = GN0 * GN11;
constexpr auto BM1 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
Number<BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11>{};
constexpr auto BN1 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
Number<BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11>{};
constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm1_gn0_gn1_grid_desc,
c_grid_desc_gm0_gm1_gn0_gn1,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
......@@ -277,7 +277,7 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
......@@ -286,14 +286,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
}
__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
......@@ -301,22 +301,22 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_gm10_gn10_block_cluster_adaptor;
return c_grid_block_cluster_blockid_to_gm10_gn10;
}
using AGK0GM0GM10GM11GK1GridDesc =
decltype(MakeAGK0GM0GM10GM11GK1GridDescriptor(AGK0GM0GM1GK1GridDesc{}));
using BGK0GN0GN10GN11GK1GridDesc =
decltype(MakeBGK0GN0GN10GN11GK1GridDescriptor(BGK0GN0GN1GK1GridDesc{}));
using CGM10BM0BM1GN10BN0BN1GridDesc =
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
......@@ -324,25 +324,25 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGK0GM0GM10GM11GK1GridDesc& a_gk0_gm0_gm10_gm11_gk1_grid_desc,
const BGK0GN0GN10GN11GK1GridDesc& b_gk0_gn0_gn10_gn11_gk1_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetElementSpaceSize());
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetElementSpaceSize());
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
const auto GK0 = a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0);
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
......@@ -356,46 +356,46 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_gk0_bm_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_gk0_bn_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
static_assert(a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize() ==
a_gk0_bm_gk1_block_desc.GetElementSpaceSize() &&
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize() ==
b_gk0_bn_gk1_block_desc.GetElementSpaceSize(),
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc),
decltype(a_gk0_gm0_gm10_gm11_gk1_block_desc),
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
......@@ -403,23 +403,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, igm10, 0, 0),
a_gk0_gm0_gm10_gm11_gk1_block_desc,
a_block_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc),
decltype(b_gk0_gn0_gn10_gn11_gk1_block_desc),
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
......@@ -427,102 +427,103 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, ign10, 0, 0),
b_gk0_gn0_gn10_gn11_gk1_block_desc,
b_block_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_gk0_bm_gk1_block_desc),
decltype(b_gk0_bn_gk1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
decltype(a_block_desc_gk0_bm_gk1),
decltype(b_block_desc_gk0_bn_gk1),
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
BM1PerThreadBM11,
BN1PerThreadBN11>{};
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
index_t gk0_block_on_grid = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
......@@ -530,25 +531,25 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
......@@ -556,29 +557,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < GK0 - 2 * KPerBlock);
gk0_block_on_grid += 2 * GK0PerBlock;
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
......@@ -586,23 +587,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
......@@ -610,61 +611,51 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = GM1PerBlockGM11 / M11;
constexpr index_t N10 = GN1PerBlockGN11 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
Sequence<1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_multi_index(igm10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
ign10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf,
CGridIteratorHacks{});
}
......
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast void __CONSTANT__ void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
const auto b_k_n_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
const auto c_m0_m1_n0_n1_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<M1PerThread>{},
Number<N1PerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_global, a_k_m_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_global, b_k_n_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_global into SGPR
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<M1PerThread>{},
Number<N1PerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_idx_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_idx_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(
MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
"wrong!");
constexpr index_t M0PerThread =
MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
constexpr index_t N0PerThread =
NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<M0PerThread>{},
Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<N0PerThread>{},
Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<M0PerThread>{},
Number<M1PerThread>{},
Number<N0PerThread>{},
Number<N1PerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc),
decltype(c_m0_m1_n0_n1_thread_desc),
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
M1PerThread,
N1PerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<
FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const auto c_thread_data_idx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
c_thread_data_idx_on_block[I1],
n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
c_thread_data_idx_on_block[I3])}
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_m0_m1_n0_n1_global_desc,
c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif
......@@ -435,7 +435,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
......
#ifndef CK_THREADWISE_GEMM_V2_HPP
#define CK_THREADWISE_GEMM_V2_HPP
#ifndef CK_THREADWISE_CONTRACTION_HPP
#define CK_THREADWISE_CONTRACTION_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction
static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
......@@ -70,28 +74,31 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto K = KLengths{}[I0];
constexpr auto M0 = MLengths{}[I0];
constexpr auto M1 = MLengths{}[I1];
constexpr auto N0 = NLengths{}[I0];
constexpr auto N1 = NLengths{}[I1];
constexpr auto TK = TKLengths{}[I0];
constexpr auto TM0 = TMLengths{}[I0];
constexpr auto TM1 = TMLengths{}[I1];
constexpr auto TN0 = TNLengths{}[I0];
constexpr auto TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
static_for<0, TK, 1>{}([&](auto tk) {
static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, TN1, 1>{}([&](auto tn1) {
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk, tm0, tm1));
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1));
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk, tn0, tn1));
constexpr index_t c_offset =
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<FloatA, FloatB, FloatC>(
a_buf[Number<a_offset>{}],
......@@ -105,35 +112,39 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
}
};
// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1]
// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
__device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1()
__device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction
static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2,
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
......@@ -169,43 +180,45 @@ struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t K0 = KLengths{}[I0];
constexpr index_t K1 = KLengths{}[I1];
constexpr index_t M0 = MLengths{}[I0];
constexpr index_t M1 = MLengths{}[I1];
constexpr index_t N0 = NLengths{}[I0];
constexpr index_t N1 = NLengths{}[I1];
constexpr index_t TK0 = TKLengths{}[I0];
constexpr index_t TK1 = TKLengths{}[I1];
constexpr index_t TM0 = TMLengths{}[I0];
constexpr index_t TM1 = TMLengths{}[I1];
constexpr index_t TN0 = TNLengths{}[I0];
constexpr index_t TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, K1> a_vec;
vector_type<FloatB, K1> b_vec;
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, K1, 1>{}([&](auto k1) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(
a_origin_idx + make_multi_index(k0, m0, m1, k1));
constexpr index_t b_offset = BDesc{}.CalculateOffset(
b_origin_idx + make_multi_index(k0, n0, n1, k1));
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t a_offset =
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
a_vec.template AsType<FloatA>()(k1) = a_buf[Number<a_offset>{}];
constexpr index_t b_offset =
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
b_vec.template AsType<FloatB>()(k1) = b_buf[Number<b_offset>{}];
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, K1>::type;
using b_vector_t = typename vector_type<FloatB, K1>::type;
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1));
constexpr index_t c_offset =
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment