Unverified Commit 12649254 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

reorganize files to prepare for MIOpen integration (#51)

* change olc cmake

* adding online compile to fwd-v4r5r2

* update scripts

* rename fwd-v4r5r2 to fwd-v6r1

* clean up
parent fbdf4332
...@@ -6,14 +6,14 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") ...@@ -6,14 +6,14 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
include(TargetFlags) include(TargetFlags)
include(AddKernels) include(AddKernels)
#c++ ## C++
enable_language(CXX) enable_language(CXX)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
#OpenMP ## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# workaround issue hipcc in rocm3.5 cannot find openmp # workaround issue hipcc in rocm3.5 cannot find openmp
set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
...@@ -35,56 +35,8 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") ...@@ -35,56 +35,8 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(${OpenMP_gomp_LIBRARY}) link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries(${OpenMP_pthread_LIBRARY}) link_libraries(${OpenMP_pthread_LIBRARY})
#GPU backend ## HIP
if(DEVICE_BACKEND STREQUAL "AMD") find_package(HIP REQUIRED)
find_package(HIP REQUIRED) message(STATUS "Build with HIP ${hip_VERSION}")
endif()
#
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/half/include
${PROJECT_SOURCE_DIR}/driver/include
${PROJECT_BINARY_DIR}/composable_kernel/include/utility
)
if(DEVICE_BACKEND STREQUAL "AMD")
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/external/rocm/include
)
endif()
if(DEVICE_BACKEND STREQUAL "AMD")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
endif()
add_subdirectory(driver)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")
if(DEVICE_BACKEND STREQUAL "AMD")
set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
endif()
add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})
target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)
target_link_libraries(conv_driver_v2 PRIVATE modConv)
target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
add_subdirectory(host)
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r1.hpp"
namespace ck {
// Host-side launcher for the v1r1 dynamic tensor-contraction kernel.
//
// Given grid descriptors for A[GK, GM0, GM1], B[GK, GN0, GN1] and
// C[GM0, GM1, GN0, GN1], this driver:
//   1. instantiates GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1 with
//      the supplied tiling/transfer template parameters,
//   2. validates the problem/tile configuration (throws std::runtime_error on
//      an invalid setting),
//   3. derives the block-tiled descriptors and the blockid -> (GM10, GN10)
//      cluster adaptor,
//   4. selects one of four kernel instantiations depending on whether the GK
//      loop has a main body and/or a double-tail iteration (both are runtime
//      properties of GK folded into compile-time kernel variants),
//   5. launches the kernel `nrepeat` times and returns the average time as
//      reported by launch_and_time_kernel (units determined by that helper —
//      presumably milliseconds; confirm against its definition).
//
// The iterator-hack parameters are tag types consumed only at compile time;
// the corresponding runtime arguments are intentionally unnamed.
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGKGM0GM1GridDesc,
          typename BGKGN0GN1GridDesc,
          typename CGM0GM1GN0GN1GridDesc,
          index_t GM1PerBlockGM11,
          index_t GN1PerBlockGN11,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
          typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_GM11,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
          typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_GN11,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r1(const FloatAB* p_a_grid,
                                const FloatAB* p_b_grid,
                                FloatC* p_c_grid,
                                const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
                                const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
                                const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
                                AGridIteratorHacks,
                                BGridIteratorHacks,
                                CGridIteratorHacks,
                                AGridMoveSliceWindowIteratorHacks,
                                BGridMoveSliceWindowIteratorHacks,
                                index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM: bind all tiling/transfer parameters into the gridwise operation.
    using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
        BlockSize,
        FloatAB,
        FloatAcc,
        FloatC,
        CGlobalMemoryDataOperation,
        AGKGM0GM1GridDesc,
        BGKGN0GN1GridDesc,
        CGM0GM1GN0GN1GridDesc,
        GM1PerBlockGM11,
        GN1PerBlockGN11,
        KPerBlock,
        M1PerThread,
        N1PerThread,
        KPerThread,
        M1N1ThreadClusterM10,
        M1N1ThreadClusterN10,
        M1N1ThreadClusterM11,
        M1N1ThreadClusterN11,
        ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
        ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_GM11,
        AThreadTransferSrcResetCoordinateAfterRun,
        BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
        BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_GN11,
        BThreadTransferSrcResetCoordinateAfterRun,
        CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        AGridIteratorHacks,
        BGridIteratorHacks,
        CGridIteratorHacks,
        AGridMoveSliceWindowIteratorHacks,
        BGridMoveSliceWindowIteratorHacks>;

    // Contraction (K) dimension length; drives the loop-structure selection below.
    const auto K = a_gk_gm0_gm1_grid_desc.GetLength(I0);

    if(!GridwiseContraction::CheckValidity(
           a_gk_gm0_gm1_grid_desc, b_gk_gn0_gn1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
    {
        // NOTE: message fixed to name the gridwise class actually instantiated
        // above (previously referred to a different/stale class name).
        throw std::runtime_error(
            "wrong! GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1 has invalid setting");
    }

    // Block-tiled views of A and B derived from the problem descriptors.
    const auto a_gk_gm0_gm10_gm11_grid_desc =
        GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc);

    const auto b_gk_gn0_gn10_gn11_grid_desc =
        GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc);

    using AGKGM0GM10GM11GridDesc = decltype(a_gk_gm0_gm10_gm11_grid_desc);
    using BGKGN0GN10GN11GridDesc = decltype(b_gk_gn0_gn10_gn11_grid_desc);

    // c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc: block-tiled view of C.
    const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
        GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);

    using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);

    // c_blockid_to_gm10_gn10_block_cluster_adaptor: maps a flat block id to
    // its (GM10, GN10) tile coordinates.
    const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
        GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);

    using CBlockIdToGM10GN10BlockClusterAdaptor =
        decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);

    const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);

    // Runtime K-loop shape, folded into compile-time kernel variants below so
    // the device code contains no dead branches.
    const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(K);

    const bool has_double_tail_k_block_loop =
        GridwiseContraction::CalculateHasDoubleTailKBlockLoop(K);

    // Diagnostic dump of the derived descriptor lengths.
    {
        std::cout << "a_gk_gm0_gm10_gm11_grid_desc{" << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0)
                  << ", " << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I1) << ", "
                  << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I2) << ", "
                  << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_gk_gn0_gn10_gn11_grid_desc{" << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I0)
                  << ", " << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I1) << ", "
                  << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I2) << ", "
                  << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
    }

    float ave_time = 0;

    // Four kernel variants: <HasMainKBlockLoop, HasDoubleTailKBlockLoop>.
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            true,
            true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            true,
            false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            false,
            true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            false,
            false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }

    return ave_time;
}
} // namespace ck
#endif
...@@ -13,19 +13,19 @@ template <index_t BlockSize, ...@@ -13,19 +13,19 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc, typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGKGN0GN1GridDesc, typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGM0GM1GN0GN1GridDesc, typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11, index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11, index_t GN1PerBlockGN11,
index_t KPerBlock, index_t GK0PerBlock,
index_t M1PerThread, index_t BM1PerThreadBM11,
index_t N1PerThread, index_t BN1PerThreadBN11,
index_t KPerThread, index_t BK0PerThread,
index_t M1N1ThreadClusterM10, index_t BM10BN10ThreadClusterBM100,
index_t M1N1ThreadClusterN10, index_t BM10BN10ThreadClusterBN100,
index_t M1N1ThreadClusterM11, index_t BM10BN10ThreadClusterBM101,
index_t M1N1ThreadClusterN11, index_t BM10BN10ThreadClusterBN101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -52,9 +52,9 @@ __host__ float ...@@ -52,9 +52,9 @@ __host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc, const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc, const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc, const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
...@@ -71,79 +71,83 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -71,79 +71,83 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
constexpr auto I5 = Number<5>{}; constexpr auto I5 = Number<5>{};
// GEMM // GEMM
using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2< using GridwiseContraction =
BlockSize, GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
FloatAB, BlockSize,
FloatAcc, FloatAB,
FloatC, FloatAcc,
CGlobalMemoryDataOperation, FloatC,
AGKGM0GM1GridDesc, CGlobalMemoryDataOperation,
BGKGN0GN1GridDesc, AGridDesc_GK0_GM0_GM1_GK1,
CGM0GM1GN0GN1GridDesc, BGridDesc_GK0_GN0_GN1_GK1,
GM1PerBlockGM11, CGridDesc_GM0_GM1_GN0_GN1,
GN1PerBlockGN11, GM1PerBlockGM11,
KPerBlock, GN1PerBlockGN11,
M1PerThread, GK0PerBlock,
N1PerThread, BM1PerThreadBM11,
KPerThread, BN1PerThreadBN11,
M1N1ThreadClusterM10, BK0PerThread,
M1N1ThreadClusterN10, BM10BN10ThreadClusterBM100,
M1N1ThreadClusterM11, BM10BN10ThreadClusterBN100,
M1N1ThreadClusterN11, BM10BN10ThreadClusterBM101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, BM10BN10ThreadClusterBN101,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcAccessOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferSrcAccessOrder,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcAccessOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferSrcAccessOrder,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
CThreadTransferSrcDstAccessOrder, BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstAccessOrder,
CThreadTransferDstScalarPerVector, CThreadTransferSrcDstVectorDim,
AGridIteratorHacks, CThreadTransferDstScalarPerVector,
BGridIteratorHacks, AGridIteratorHacks,
CGridIteratorHacks, BGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, CGridIteratorHacks,
BGridMoveSliceWindowIteratorHacks>; AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0); const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
if(!GridwiseContraction::CheckValidity( if(!GridwiseContraction::CheckValidity(
a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc)) a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! "
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting"); "GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GM0_GM1_GN0_GN1 has invalid setting");
} }
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc); GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc); GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc); using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc); using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc // c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc); GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1);
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc); using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
// c_blockid_to_gm10_gn10_block_cluster_adaptor // c_grid_block_cluster_blockid_to_gm10_gn10
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc); GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1);
using CBlockIdToGM10GN10BlockClusterAdaptor = using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor); decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc); const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0); const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
...@@ -151,41 +155,41 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -151,41 +155,41 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0); GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
{ {
std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{" std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", " << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", " << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", " << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", " << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl; << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{" std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", " << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", " << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", " << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", " << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl; << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ " std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", " << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", " << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", " << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", " << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", " << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl; << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
} }
float ave_time = 0; float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_v1r1< const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>, remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>, remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>, remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>, remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true, true,
true>; true>;
...@@ -198,21 +202,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -198,21 +202,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_blockid_to_gm10_gn10_block_cluster_adaptor); c_grid_block_cluster_blockid_to_gm10_gn10);
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_v1r1< const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>, remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>, remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>, remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>, remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true, true,
false>; false>;
...@@ -225,21 +229,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -225,21 +229,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_blockid_to_gm10_gn10_block_cluster_adaptor); c_grid_block_cluster_blockid_to_gm10_gn10);
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_contraction_v1r1< const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>, remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>, remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>, remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>, remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false, false,
true>; true>;
...@@ -252,21 +256,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -252,21 +256,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_blockid_to_gm10_gn10_block_cluster_adaptor); c_grid_block_cluster_blockid_to_gm10_gn10);
} }
else else
{ {
const auto kernel = kernel_dynamic_contraction_v1r1< const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction, GridwiseContraction,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>, remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>, remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>, remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>, remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false, false,
false>; false>;
...@@ -279,10 +283,10 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, ...@@ -279,10 +283,10 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_blockid_to_gm10_gn10_block_cluster_adaptor); c_grid_block_cluster_blockid_to_gm10_gn10);
} }
return ave_time; return ave_time;
......
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r1.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
#endif
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
...@@ -10,41 +10,44 @@ namespace ck { ...@@ -10,41 +10,44 @@ namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t N0_, template <typename... Wei,
typename... Wei,
typename... In, typename... In,
typename... Out, typename... Out,
typename ConvStrides, typename ConvStrides,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc, const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc, const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc, const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads) const InRightPads& in_right_pads,
Number<GemmK1Value>)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0]; const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1]; const auto ConvStrideW = conv_strides[I1];
...@@ -58,67 +61,68 @@ transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( ...@@ -58,67 +61,68 @@ transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad(
const auto InRightPadH = in_right_pads[I0]; const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor // weight tensor
const auto wei_gk_gm0_gm1_grid_desc = transform_dynamic_tensor_descriptor( const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor // input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_grid_desc, in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
constexpr auto N0 = Number<N0_>{}; const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
const auto N1 = N / N0; in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
const auto in_n0_n1_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gk_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk_gemmn_grid_desc =
in_n0_n1_c_y_ho_x_wo_grid_desc, transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_pass_through_transform(N0), make_merge_transform(make_tuple(N, Ho, Wo))),
make_merge_transform(make_tuple(N1, Ho, Wo))), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<2, 3, 5>{}, Sequence<0>{}, Sequence<1, 4, 6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor // output tensor
const auto out_n_k_howo_grid_desc = const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}),
out_n_k_howo_grid_desc, make_tuple(Sequence<1>{}, Sequence<0>{}));
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return make_tuple( return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
wei_gk_gm0_gm1_grid_desc, in_gk_gn0_gn1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
} }
} // namespace ck } // namespace ck
......
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
...@@ -17,10 +17,10 @@ template <typename... Wei, ...@@ -17,10 +17,10 @@ template <typename... Wei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads, typename InRightPads,
index_t N0Value, typename N0Type,
index_t C0Value> typename C0Type>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc, const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc, const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
...@@ -28,8 +28,8 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( ...@@ -28,8 +28,8 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
Number<N0Value>, const N0Type& N0,
Number<C0Value>) const C0Type& C0)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -61,9 +61,6 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( ...@@ -61,9 +61,6 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const auto InRightPadH = in_right_pads[I0]; const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1]; const auto InRightPadW = in_right_pads[I1];
constexpr auto N0 = Number<N0Value>{};
constexpr auto C0 = Number<C0Value>{};
const auto N1 = N / N0; const auto N1 = N / N0;
const auto C1 = C / C0; const auto C1 = C / C0;
...@@ -109,7 +106,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( ...@@ -109,7 +106,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc, out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)), make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(I1, K)), make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)), make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
...@@ -119,7 +116,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( ...@@ -119,7 +116,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
out_n0_n1_1_k_howo_grid_desc, out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1), make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K), make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}), make_pass_through_transform(N0),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp" #include "threadwise_contraction.hpp"
namespace ck { namespace ck {
......
...@@ -435,21 +435,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 ...@@ -435,21 +435,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize, BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
FloatAB, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAB,
decltype(a_k0_m_k1_block_desc), FloatAcc,
decltype(b_k0_n_k1_block_desc), decltype(a_k0_m_k1_block_desc),
M1PerThreadM111, decltype(b_k0_n_k1_block_desc),
N1PerThreadN111, M1PerThreadM111,
KPerThread, N1PerThreadN111,
M11N11ThreadClusterM1100, KPerThread,
M11N11ThreadClusterN1100, M11N11ThreadClusterM1100,
M11N11ThreadClusterM1101, M11N11ThreadClusterN1100,
M11N11ThreadClusterN1101, M11N11ThreadClusterM1101,
M1PerThreadM111, M11N11ThreadClusterN1101,
N1PerThreadN111>{}; M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
......
#ifndef CK_THREADWISE_GEMM_V2_HPP #ifndef CK_THREADWISE_CONTRACTION_HPP
#define CK_THREADWISE_GEMM_V2_HPP #define CK_THREADWISE_CONTRACTION_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "math.hpp" #include "math.hpp"
namespace ck { namespace ck {
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1] // C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
// Tensor element can be vectorized data // Tensor element can be vectorized data
// Assume: // Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time // 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA, template <typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename ADesc, typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BDesc, typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CDesc, typename CThreadDesc_TM0_TM1_TN0_TN1,
typename KLengths, typename TKLengths,
typename MLengths, typename TMLengths,
typename NLengths, typename TNLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
{ {
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction // TODO remove this restriction
static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2, static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!"); "wrong!");
} }
...@@ -70,28 +74,31 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 ...@@ -70,28 +74,31 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto K = KLengths{}[I0]; constexpr auto TK = TKLengths{}[I0];
constexpr auto M0 = MLengths{}[I0]; constexpr auto TM0 = TMLengths{}[I0];
constexpr auto M1 = MLengths{}[I1]; constexpr auto TM1 = TMLengths{}[I1];
constexpr auto N0 = NLengths{}[I0]; constexpr auto TN0 = TNLengths{}[I0];
constexpr auto N1 = NLengths{}[I1]; constexpr auto TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) { static_for<0, TK, 1>{}([&](auto tk) {
static_for<0, M0, 1>{}([&](auto m0) { static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, M1, 1>{}([&](auto m1) { static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, N0, 1>{}([&](auto n0) { static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, N1, 1>{}([&](auto n1) { static_for<0, TN1, 1>{}([&](auto tn1) {
constexpr index_t a_offset = constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1)); AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk, tm0, tm1));
constexpr index_t b_offset = constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1)); BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
constexpr index_t c_offset = CDesc{}.CalculateOffset( b_origin_idx + make_multi_index(tk, tn0, tn1));
c_origin_idx + make_multi_index(m0, m1, n0, n1)); constexpr index_t c_offset =
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<FloatA, FloatB, FloatC>( amd_inner_product_dlop<FloatA, FloatB, FloatC>(
a_buf[Number<a_offset>{}], a_buf[Number<a_offset>{}],
...@@ -105,35 +112,39 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 ...@@ -105,35 +112,39 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
} }
}; };
// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1] // C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
// Tensor element can be vectorized data // Tensor element can be vectorized data
// Assume: // Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time // 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA, template <typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename ADesc, typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BDesc, typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CDesc, typename CThreadDesc_TM0_TM1_TN0_TN1,
typename KLengths, typename TKLengths,
typename MLengths, typename TMLengths,
typename NLengths, typename TNLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1 struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{ {
__device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1() __device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{ {
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction // TODO remove this restriction
static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2, static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!"); "wrong!");
} }
...@@ -169,43 +180,45 @@ struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1 ...@@ -169,43 +180,45 @@ struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr index_t K0 = KLengths{}[I0]; constexpr index_t TK0 = TKLengths{}[I0];
constexpr index_t K1 = KLengths{}[I1]; constexpr index_t TK1 = TKLengths{}[I1];
constexpr index_t M0 = MLengths{}[I0]; constexpr index_t TM0 = TMLengths{}[I0];
constexpr index_t M1 = MLengths{}[I1]; constexpr index_t TM1 = TMLengths{}[I1];
constexpr index_t N0 = NLengths{}[I0]; constexpr index_t TN0 = TNLengths{}[I0];
constexpr index_t N1 = NLengths{}[I1]; constexpr index_t TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K0, 1>{}([&](auto k0) { static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, M0, 1>{}([&](auto m0) { static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, M1, 1>{}([&](auto m1) { static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, N0, 1>{}([&](auto n0) { static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, N1, 1>{}([&](auto n1) { static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, K1> a_vec; vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, K1> b_vec; vector_type<FloatB, TK1> b_vec;
static_for<0, K1, 1>{}([&](auto k1) { static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t a_offset = ADesc{}.CalculateOffset( constexpr index_t a_offset =
a_origin_idx + make_multi_index(k0, m0, m1, k1)); AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
constexpr index_t b_offset = BDesc{}.CalculateOffset( constexpr index_t b_offset =
b_origin_idx + make_multi_index(k0, n0, n1, k1)); BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
a_vec.template AsType<FloatA>()(k1) = a_buf[Number<a_offset>{}]; a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
b_vec.template AsType<FloatB>()(k1) = b_buf[Number<b_offset>{}];
}); });
using a_vector_t = typename vector_type<FloatA, K1>::type; using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, K1>::type; using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CDesc{}.CalculateOffset( constexpr index_t c_offset =
c_origin_idx + make_multi_index(m0, m1, n0, n1)); CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>( amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0], a_vec.template AsType<a_vector_t>()[I0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment