Commit b8b2d0a6 authored by Chao Liu, committed by GitHub

DL GEMM fp32/fp16/int8 (#41)

* add threadwise copy that copies a tensor in one copy; add kpack to DL GEMM

* add kpack into fwd v4r5 nchw fp32
parent 11ec07e9
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
if(!GridwiseContraction::CheckValidity(
a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
}
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc =
GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc);
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc =
GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc);
using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc);
using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc);
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
// c_blockid_to_gm10_gn10_block_cluster_adaptor
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
const bool has_double_tail_k_block_loop =
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
{
std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{"
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{"
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else
{
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
return ave_time;
}
} // namespace ck
#endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_v1r3
#define CK_DRIVER_DYNAMIC_GEMM_v1r3
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseGemm =
GridwiseDynamicGemm_km_kn_mn_v1r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r3 has invalid setting");
}
const auto a_k0_m0_m1_k1_grid_desc =
GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
const auto b_k0_n0_n1_k1_grid_desc =
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);
using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);
// c_m0_m10_m11_n0_n10_n11_grid_desc
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
// c_blockid_to_m0_n0_block_cluster_adaptor
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
{
std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
<< a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
<< a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
<< b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
<< b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
}
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
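// In this experimental path the tensor descriptors are not passed to the kernel by value;
// instead each descriptor object is copied once into a small device buffer and the kernel
// receives it as a "__CONSTANT__ void*" (which the kernel side is then expected to cast back
// to the descriptor type — a reading of the code below, not a documented contract).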
DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToM0N0BlockClusterAdaptor));
a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_m0_n0_block_cluster_adaptor);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else
{
const auto kernel =
kernel_dynamic_gemm_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
#endif
@@ -12,7 +12,7 @@ namespace ck {
 // C: out
 // GemmM = N * Ho * Wo
 // GemmN = K
-// GemmK = C * Y * X
+// GemmK = Y * X * C
 template <typename... In,
           typename... Wei,
           typename... Out,
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
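// Worked example (illustrative numbers only): for N = 128, K = 256, C = 64, Ho = Wo = 14,
// Y = X = 3, the contraction sizes are GemmM = 256, GemmN = 128 * 14 * 14 = 25088, and
// GemmK = 64 * 3 * 3 = 576.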
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t N0Value,
index_t C0Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<N0Value>,
Number<C0Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
constexpr auto N0 = Number<N0Value>{};
constexpr auto C0 = Number<C0Value>{};
const auto N1 = N / N0;
const auto C1 = C / C0;
// weight tensor
const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(C0, C1)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
make_pass_through_transform(N0),
make_merge_transform(make_tuple(N1, Ho, Wo)),
make_pass_through_transform(C0)),
make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_n_k_howo_grid_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return make_tuple(
wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
}
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
namespace ck {
// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
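// Background (an interpretation of the points above, not stated elsewhere in this header):
// indexing a plain C array with a runtime loop variable can force the compiler to place the
// thread buffer in scratch (private) memory, whereas a StaticallyIndexedArray is addressed with
// compile-time Number<> indices (e.g. inside static_for<...>), so its elements can stay in
// registers. Likewise, not caching descriptor references and not rebuilding coordinates inside
// Run() keeps the transfer object itself lightweight and register-resident.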
template <index_t BlockSize,
InMemoryDataOperation DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
typename SrcVectorTensorLengths,
typename DstVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
}
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
// SrcMoveSliceWindowIteratorHack controls the index calculation when moving the slice window
template <typename SrcMoveSliceWindowIteratorHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorTensorLengths,
DstVectorTensorLengths,
SrcVectorTensorContiguousDimOrder,
DstVectorTensorContiguousDimOrder,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_V2R3_HPP
#define CK_BLOCKWISE_GEMM_V2R3_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_gemm_v2.hpp"
namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block; C is distributed among the threads
// Assume:
// 1. A:
// 1. AK0MK1BlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BK0NK1BlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
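// Pipeline sketch (read off the Run() body below, for the M0 = N0 = 2 case):
//   prologue:   read A0, read B0, read B1, read A1; C00 += A0^T*B0; C01 += A0^T*B1
//   per K step: read next A0; C10 += A1^T*B0; read next B0; C11 += A1^T*B1;
//               read next B1; read next A1; C00 += A0^T*B0; C01 += A0^T*B1
//   epilogue:   C10 += A1^T*B0; C11 += A1^T*B1
// i.e. the loads of the next sub-tiles are interleaved with the FMAs on the current ones.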
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename std::enable_if<AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t M = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t N = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t M100 = M1N1ThreadClusterM100;
static constexpr index_t N100 = M1N1ThreadClusterN100;
static constexpr index_t M101 = M1N1ThreadClusterM101;
static constexpr index_t N101 = M1N1ThreadClusterN101;
static constexpr index_t M11 = M1PerThreadM11;
static constexpr index_t N11 = N1PerThreadN11;
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
static constexpr index_t M0 = M / M1;
static constexpr index_t N0 = N / N1;
__host__ __device__ static constexpr auto
MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc)
{
const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_block_desc;
}
__host__ __device__ static constexpr auto
MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc)
{
const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_block_desc;
}
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_unmerge_transform(make_tuple(
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
}
__host__ __device__ static constexpr auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<M0>{}),
make_unmerge_transform(
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_pass_through_transform(Number<N0>{}),
make_unmerge_transform(
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
}
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
{
return Sequence<M0, M11, N0, N11>{};
}
static constexpr auto a_k0_m0_m1_k1_block_desc_ =
MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{});
static constexpr auto b_k0_n0_n1_k1_block_desc_ =
MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{});
public:
__device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M101 * M100 * N101 * N100,
"wrong! blocksize and cluster size not consistent");
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2, "wrong");
}
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
make_pass_through_transform(M0),
make_pass_through_transform(M11),
make_pass_through_transform(N0),
make_pass_through_transform(N11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
}
template <typename CM0M1N0N1ThreadDesc,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_k0_m0_m1_k1_thread_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
CM0M1N0N1ThreadDesc,
Sequence<KPerThread, K1>,
Sequence<1, M1PerThreadM11>,
Sequence<1, N1PerThreadN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThread, K0, KPerThread>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
// A[K0, M0, M1, K1]
static constexpr auto a_k0_m0_m1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}, Number<K1>{}));
// B[K0, N0, N1, K1]
static constexpr auto b_k0_n0_n1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}, Number<K1>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_k0_m0_m1_k1_block_desc_),
decltype(a_k0_m0_m1_k1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11, K1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, M1PerThreadM11, K1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_k0_n0_n1_k1_block_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11, K1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, N1PerThreadN11, K1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseContraction,
typename FloatAB,
typename FloatC,
typename AGK0GM0GM10GM11GK1GridDesc,
typename BGK0GN0GN10GN11GK1GridDesc,
typename CGM10BM0BM1GN10BN0BN1GridDesc,
typename CBlockIdToGM10GN10BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_contraction_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGK0GM0GM10GM11GK1GridDesc a_gk0_gm0_gm10_gm11_gk1_grid_desc,
const BGK0GN0GN10GN11GK1GridDesc b_gk0_gn0_gn10_gn11_gk1_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor
c_blockid_to_gm10_gn10_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGK0GM0GM1GK1GridDesc,
typename BGK0GN0GN1GK1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to be known at compile-time
static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
static constexpr auto GK1 = AGK0GM0GM1GK1GridDesc{}.GetLength(I3);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// lds max alignment
// TODO: part of them should be moved into blockwise-gemm
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = GK1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
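// the factor of 2 below is read as accounting for the double-buffered (even/odd) A and B tiles
// used across the main/double-tail K-block loop (an assumption based on that loop structure)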
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
"wrong! GM0 and GN0 need to be known at compile-time");
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
GM0 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I1) &&
GM1 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2) &&
GN0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I1) &&
GN1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2) &&
GK0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0) &&
GK1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I3)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % KPerBlock == 0));
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr index_t GM11 = GM1PerBlockGM11;
constexpr index_t GN11 = GN1PerBlockGN11;
const index_t GM10 = GM1 / GM11;
const index_t GN10 = GN1 / GN11;
const index_t grid_size = GM10 * GN10;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
{
const bool has_main_k_block_loop = (GK0 + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
{
const bool has_double_tail_k_block_loop = (GK0 / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
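// Worked example (illustrative only), with KPerBlock = 8 and integer division:
//   GK0 = 16: (16 + 8) / 16 = 1      -> no main loop; 16 / 8 = 2 (even) -> double-tail
//   GK0 = 24: (24 + 8) / 16 = 2 > 1  -> main loop;    24 / 8 = 3 (odd)  -> single tail
//   GK0 = 32: (32 + 8) / 16 = 2 > 1  -> main loop;    32 / 8 = 4 (even) -> double-tail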
__host__ __device__ static constexpr auto
MakeAGK0GM0GM10GM11GK1GridDescriptor(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc)
{
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
a_gk0_gm0_gm1_gk1_grid_desc,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GK1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_gk0_gm0_gm10_gm11_gk1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeBGK0GN0GN10GN11GK1GridDescriptor(const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc)
{
const auto GK0 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0);
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
b_gk0_gn0_gn1_gk1_grid_desc,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11)),
make_pass_through_transform(GK1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_gk0_gn0_gn10_gn11_gk1_grid_desc;
}
__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
constexpr auto BM = GM0 * GM11;
constexpr auto BN = GN0 * GN11;
constexpr auto BM1 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto BN1 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm1_gn0_gn1_grid_desc,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)),
make_pass_through_transform(GN10),
make_merge_transform(make_tuple(GN0, GN11))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
make_pass_through_transform(GN10),
make_unmerge_transform(make_tuple(BN0, BN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
}
__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_gm10_gn10_block_cluster_adaptor;
}
using AGK0GM0GM10GM11GK1GridDesc =
decltype(MakeAGK0GM0GM10GM11GK1GridDescriptor(AGK0GM0GM1GK1GridDesc{}));
using BGK0GN0GN10GN11GK1GridDesc =
decltype(MakeBGK0GN0GN10GN11GK1GridDescriptor(BGK0GN0GN1GK1GridDesc{}));
using CGM10BM0BM1GN10BN0BN1GridDesc =
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGK0GM0GM10GM11GK1GridDesc& a_gk0_gm0_gm10_gm11_gk1_grid_desc,
const BGK0GN0GN10GN11GK1GridDesc& b_gk0_gn0_gn10_gn11_gk1_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());
const auto GK0 = a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this forces index data into SGPRs
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
// lds max alignment
// TODO: part of this should be moved into the blockwise GEMM
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = GK1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_gk0_bm_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_gk0_bn_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
static_assert(a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize() ==
a_gk0_bm_gk1_block_desc.GetElementSpaceSize() &&
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize() ==
b_gk0_bn_gk1_block_desc.GetElementSpaceSize(),
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc),
decltype(a_gk0_gm0_gm10_gm11_gk1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
make_multi_index(0, 0, igm10, 0, 0),
a_gk0_gm0_gm10_gm11_gk1_block_desc,
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc),
decltype(b_gk0_gn0_gn10_gn11_gk1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
make_multi_index(0, 0, ign10, 0, 0),
b_gk0_gn0_gn10_gn11_gk1_block_desc,
make_multi_index(0, 0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlock, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// registers
const auto blockwise_gemm =
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_gk0_bm_gk1_block_desc),
decltype(b_gk0_bn_gk1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
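// LDS layout implied by the pointer arithmetic above and the even/odd buffers below:
// [ A even | A odd | B even | B odd ], i.e. double-buffered A occupies
// 2 * a_block_aligned_space_size elements, followed by double-buffered B.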
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < GK0 - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
{
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = GM1PerBlockGM11 / M11;
constexpr index_t N10 = GN1PerBlockGN11 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));
const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
Sequence<1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
make_multi_index(igm10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
ign10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0M0M1K1GridDesc,
typename BK0N0N1K1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptors by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
// point to the non-modifiable parameter address space, so the compiler can enable the
// corresponding optimizations
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0M0M1K1GridDesc,
typename BK0N0N1K1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m0_m1_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n0_n1_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast the __CONSTANT__ void pointer to a generic void pointer,
// then cast the void pointer to the descriptor pointer type;
// the copy constructor of the tensor descriptor doesn't accept an address_space(4) pointer
const auto a_k0_m0_m1_k1_grid_desc =
*reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_k0_m0_m1_k1_grid_desc);
const auto b_k0_n0_n1_k1_grid_desc =
*reinterpret_cast<const BK0N0N1K1GridDesc*>((const void*)p_b_k0_n0_n1_k1_grid_desc);
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
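// Illustrative host-side sketch only (the real launch path lives in the driver code, which is
// not shown here): a driver would typically compute the grid size from
// GridwiseGemm::CalculateGridSize(M, N) and launch one workgroup of BlockSize threads per C
// tile, roughly as
//
//   const auto kernel = kernel_dynamic_gemm_v1r3<GridwiseGemm, FloatAB, FloatC,
//                                                AK0M0M1K1GridDesc, BK0N0N1K1GridDesc,
//                                                CM0M10M11N0N10N11GridDesc,
//                                                CBlockIdToM0N0BlockClusterAdaptor,
//                                                true, true>;
//   hipLaunchKernelGGL(kernel, dim3(grid_size), dim3(BlockSize), 0, stream,
//                      p_a, p_b, p_c, a_desc, b_desc, c_desc, block_to_tile_map);
//
// where grid_size, stream, the descriptors, and the HasMainKBlockLoop /
// HasDoubleTailKBlockLoop flags are supplied by the actual driver.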
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlockM1,
index_t NPerBlockN1,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
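// Worked example (hypothetical tuning parameters, for illustration only): with
// KPerBlock = 8, MPerBlockM1 = NPerBlockN1 = 128, K1 = 4 and FloatAB = half, each block
// descriptor covers 8 * 128 * 4 = 4096 elements, so the double-buffered LDS footprint is
// 2 * (4096 + 4096) * 2 bytes = 32 KiB.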
__host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
return grid_size;
}
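// For example (hypothetical sizes): M = N = 1024 with MPerBlockM1 = NPerBlockN1 = 128
// gives (1024 / 128) * (1024 / 128) = 64 workgroups, one per C tile.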
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
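// Loop-structure example (illustrative): with K0 = 6 * KPerBlock, Run() preloads chunk 0,
// the main do-while processes chunks 0..3 in two even/odd passes, and the double-tail branch
// handles chunks 4 and 5. With K0 = 2 * KPerBlock there is no main loop and only the double
// tail runs; with K0 = KPerBlock only the single-tail branch runs.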
__host__ __device__ static constexpr auto
MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
{
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1;
const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_grid_desc,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1;
const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_grid_desc,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto N11 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
// divide block work by [M0, N0]
const auto c_m0_n0_block_cluster_idx =
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this forces index data into SGPRs
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize(),
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m0_m1_k1_grid_desc),
decltype(a_k0_m0_m1_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(a_k0_m0_m1_k1_grid_desc,
make_multi_index(0, im0, 0, 0),
a_k0_m0_m1_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n0_n1_k1_grid_desc),
decltype(b_k0_n0_n1_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(b_k0_n0_n1_k1_grid_desc,
make_multi_index(0, in0, 0, 0),
b_k0_n0_n1_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
// b_mtx[KPerBlock, NPerBlockN1] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// registers
const auto blockwise_gemm =
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_m10_m11_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
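// zero-initialize the accumulation registers before entering the K loop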
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K0 - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
}
}
};
} // namespace ck
#endif
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time
// 4. Use thread buffer
template <typename SliceLengths,
InMemoryDataOperation DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
typename SrcVectorTensorLengths,
typename DstVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename DstVectorTensorContiguousDimOrder,
bool SrcResetCoordinateAfterRun, // controls whether to move the src coordinate back after each
// RunRead(); this can be fused with MoveSrcSliceWindow() to
// save address computation
bool DstResetCoordinateAfterRun> // controls whether to move the dst coordinate back after each
// RunWrite(); this can be fused with MoveDstSliceWindow() to
// save address computation
struct ThreadwiseDynamicTensorSliceTransfer_v3r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
{
// TODO: fix this
static_assert(is_same<SrcData, DstData>::value,
"wrong! current implementation assume SrcData and DstData are same type");
static_for<0, nDim, 1>{}([](auto i) {
static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 &&
SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0,
"wrong!");
});
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
"wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
remove_cv_t<remove_reference_t<SrcData>>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// tensor descriptor for src_vector
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
container_reverse_exclusive_scan(
container_reorder_given_new2old(src_vector_tensor_lengths,
SrcVectorTensorContiguousDimOrder{}),
math::multiplies_v2{},
I1),
SrcVectorTensorContiguousDimOrder{});
constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));
// access order and lengths
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward iterators
const auto src_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, forward_step, src_iterator_hacks[I0][i]);
},
Number<nDim>{});
// make backward iterators
const auto src_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, backward_step, src_iterator_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// decide whether each access dimension is swept forward or backward
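// Descriptive note: the slice is traversed in a snake (zig-zag) order, so that between
// consecutive accesses only one dimension moves by one vector step; forward_sweep[i] is
// true when the combined index of the outer (already ordered) dimensions is even, and the
// sweep direction of dimension i flips each time that parity changes.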
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
auto src_data_idx =
container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_vector_tensor_lengths;
return src_data_idx;
}();
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
using src_vector_t = typename decltype(src_vector)::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_buf to src_vector
src_vector.template AsType<src_vector_t>()(I0) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
// copy data from src_vector to buffer_
static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
constexpr index_t src_vector_offset =
src_vector_desc.CalculateOffset(src_vector_idx);
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx);
buffer_(Number<buffer_offset>{}) =
src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}];
});
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
}
}
template <typename DstBuffer, typename DstIteratorHacks>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks)
{
static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
DstBuffer::GetAddressSpace() == AddressSpace::Lds,
"wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
remove_cv_t<remove_reference_t<DstData>>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
// tensor descriptor for dst_vector
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new(
container_reverse_exclusive_scan(
container_reorder_given_new2old(dst_vector_tensor_lengths,
DstVectorTensorContiguousDimOrder{}),
math::multiplies_v2{},
I1),
DstVectorTensorContiguousDimOrder{});
constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
sequence_to_tuple_of_number(dst_vector_tensor_lengths),
sequence_to_tuple_of_number(dst_vector_tensor_strides));
// dst access order and lengths
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward iterators
const auto dst_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
});
const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
return forward_iterator;
},
Number<nDim>{});
// make backward iterators
const auto dst_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
});
const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
return backward_iterator;
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// decide whether each access dimension is swept forward or backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
auto dst_data_idx =
container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_vector_tensor_lengths;
return dst_data_idx;
}();
vector_type_maker_t<DstData, dst_vector_desc.GetElementSpaceSize()> dst_vector;
// copy data from buffer_ to dst_vector (also cast from SrcData to DstData)
static_ford<DstVectorTensorLengths>{}([&](auto dst_vector_idx_) {
constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_);
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx);
constexpr index_t dst_vector_offset =
dst_vector_desc.CalculateOffset(dst_vector_idx);
dst_vector.template AsType<DstData>()(Number<dst_vector_offset>{}) =
type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
});
using dst_vector_t = typename decltype(dst_vector)::type;
// copy data from dst_vector to dst_buf
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
}
}
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, src_buf, src_iterator_hacks);
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_iterator_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// decide whether each access dimension was swept forward or backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate the src data index after the last iteration in RunRead(), if it has not been reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_vector_tensor_lengths;
return src_data_idx;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
return reset_src_data_step;
}();
return reset_src_data_step;
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// decide whether each access dimension was swept forward or backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate the dst data index after the last iteration in RunWrite(), if it has not been reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_vector_tensor_lengths;
return dst_data_idx;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; });
return reset_dst_data_step;
}();
return reset_dst_data_step;
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
template <typename SrcMoveSliceWindowIteratorHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;
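// buffer_ is the per-thread VGPR staging area: RunRead() gathers the whole SliceLengths
// tile into it and RunWrite() scatters it out, so the slice must be small enough to live
// in registers.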
SrcCoord src_coord_;
DstCoord dst_coord_;
};
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation
// 3. vector access on src
template <
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
typename SrcVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v4r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx)
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_for<0, nDim, 1>{}([](auto i) {
static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!");
});
}
template <typename SrcRefToOriginDisplacement,
typename DstOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcBuffer& src_buf,
const DstDesc&,
const DstOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
remove_cv_t<remove_reference_t<SrcData>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
remove_cv_t<remove_reference_t<DstData>>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(
is_known_at_compile_time<
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
// SrcRefToOriginDisplacement and DstOriginIdx are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
// tensor descriptor for src_vector
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
container_reverse_exclusive_scan(
container_reorder_given_new2old(src_vector_tensor_lengths,
SrcVectorTensorContiguousDimOrder{}),
math::multiplies_v2{},
I1),
SrcVectorTensorContiguousDimOrder{});
constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));
// access order and lengths
constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// position in slice window
constexpr auto data_to_origin_disp_idx =
ordered_access_idx.ReorderGivenOld2New(dim_access_order) *
src_vector_tensor_lengths;
// src coordinate at starting point of src_vector
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_;
move_dynamic_tensor_coordinate(
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
using src_vector_t = typename decltype(src_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(I0) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
// copy data from src_vector into dst_buf (also cast from SrcData to DstData)
static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
constexpr index_t src_vector_offset =
src_vector_desc.CalculateOffset(src_vector_idx);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
dst_buf(Number<dst_offset>{}) = type_convert<DstData>{}(
src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
});
});
}
template <typename SrcSliceMoveStepIdx>
__device__ void MoveSrcSliceWindow(const SrcDesc&,
const SrcSliceMoveStepIdx& src_slice_move_step_idx)
{
constexpr auto src_desc = SrcDesc{};
const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
src_desc, to_multi_index(src_slice_move_step_idx));
move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
}
private:
SrcCoord src_ref_coord_;
};
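// A hedged usage note (not a prescribed API): this transfer is the register-side counterpart of
// the blockwise copy above. It is constructed once with a run-time src_ref_idx, while each
// Run(...) call takes a compile-time SrcRefToOriginDisplacement and DstOriginIdx, loads one
// SrcVectorTensorLengths-shaped vector per access via src_buf.Get<src_vector_t>(), and writes
// the elements into the static dst_buf with a SrcData -> DstData type_convert.
// MoveSrcSliceWindow() only updates the run-time reference coordinate src_ref_coord_.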
} // namespace ck
#endif
@@ -6,8 +6,8 @@
namespace ck {

-// C[M, N] += transpose(A[K, M]) * B[K, N]
-// Element of matrix can be vectorized data
+// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
+// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
@@ -17,11 +17,27 @@ template <typename FloatA,
typename ADesc,
typename BDesc,
typename CDesc,
+typename KLengths,
+typename MLengths,
+typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
-struct ThreadwiseGemm_km_kn_mn_v1r1
+struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
+__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
+{
+static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+CDesc::IsKnownAtCompileTime(),
+"wrong! Desc should be known at compile-time");
+
+// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
+
+// TODO remove this restriction
+static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
+"wrong!");
+}
+
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
@@ -35,10 +51,6 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
CBuffer& c_buf,
COriginIdx)
{
-static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-CDesc::IsKnownAtCompileTime(),
-"wrong! Desc should be known at compile-time");
-
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
@@ -58,89 +70,42 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

-constexpr auto M = CDesc{}.GetLength(I0);
-constexpr auto N = CDesc{}.GetLength(I1);
-constexpr auto K = ADesc{}.GetLength(I0);
+constexpr auto K = KLengths{}[I0];
+constexpr auto M0 = MLengths{}[I0];
+constexpr auto M1 = MLengths{}[I1];
+constexpr auto N0 = NLengths{}[I0];
+constexpr auto N1 = NLengths{}[I1];

constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

static_for<0, K, 1>{}([&](auto k) {
-static_for<0, M, 1>{}([&](auto m) {
-constexpr index_t a_offset =
-ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
-
-#if 0
-if constexpr(N == 2)
-{
-constexpr index_t b_offset_0 =
-BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
-constexpr index_t b_offset_1 =
-BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
-constexpr index_t c_offset_0 =
-CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
-constexpr index_t c_offset_1 =
-CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
-
-amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
-b_buf[Number<b_offset_0>{}],
-b_buf[Number<b_offset_1>{}],
-c_buf(Number<c_offset_0>{}),
-c_buf(Number<c_offset_1>{}));
-}
-else if constexpr(N == 4)
-{
-constexpr index_t b_offset_0 =
-BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
-constexpr index_t b_offset_1 =
-BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
-constexpr index_t b_offset_2 =
-BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
-constexpr index_t b_offset_3 =
-BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
-constexpr index_t c_offset_0 =
-CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
-constexpr index_t c_offset_1 =
-CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
-constexpr index_t c_offset_2 =
-CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
-constexpr index_t c_offset_3 =
-CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
-
-amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
-b_buf[Number<b_offset_0>{}],
-b_buf[Number<b_offset_1>{}],
-b_buf[Number<b_offset_2>{}],
-b_buf[Number<b_offset_3>{}],
-c_buf(Number<c_offset_0>{}),
-c_buf(Number<c_offset_1>{}),
-c_buf(Number<c_offset_2>{}),
-c_buf(Number<c_offset_3>{}));
-}
-else
-#endif
-{
-static_for<0, N, 1>{}([&](auto n) {
-constexpr index_t b_offset =
-BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
-constexpr index_t c_offset =
-CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
-
-amd_assembly_inner_product(a_buf[Number<a_offset>{}],
-b_buf[Number<b_offset>{}],
-c_buf(Number<c_offset>{}));
+static_for<0, M0, 1>{}([&](auto m0) {
+static_for<0, M1, 1>{}([&](auto m1) {
+static_for<0, N0, 1>{}([&](auto n0) {
+static_for<0, N1, 1>{}([&](auto n1) {
+constexpr index_t a_offset =
+ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
+constexpr index_t b_offset =
+BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
+constexpr index_t c_offset = CDesc{}.CalculateOffset(
+c_origin_idx + make_multi_index(m0, m1, n0, n1));
+
+amd_inner_product_dlop<FloatA, FloatB, FloatC>(
+a_buf[Number<a_offset>{}],
+b_buf[Number<b_offset>{}],
+c_buf(Number<c_offset>{}));
+});
});
-}
+});
});
});
}
};

-// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
+// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
@@ -157,9 +122,9 @@ template <typename FloatA,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
-struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
+struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
{
-__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
+__device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1()
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
@@ -168,7 +133,7 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths

// TODO remove this restriction
-static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
+static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2,
"wrong!");
}
@@ -204,32 +169,48 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

-constexpr auto K = KLengths{}[I0];
-constexpr auto M0 = MLengths{}[I0];
-constexpr auto M1 = MLengths{}[I1];
-constexpr auto N0 = NLengths{}[I0];
-constexpr auto N1 = NLengths{}[I1];
+constexpr index_t K0 = KLengths{}[I0];
+constexpr index_t K1 = KLengths{}[I1];
+constexpr index_t M0 = MLengths{}[I0];
+constexpr index_t M1 = MLengths{}[I1];
+constexpr index_t N0 = NLengths{}[I0];
+constexpr index_t N1 = NLengths{}[I1];

constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

-static_for<0, K, 1>{}([&](auto k) {
+static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
-constexpr index_t a_offset =
-ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
-constexpr index_t b_offset =
-BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
+vector_type<FloatA, K1> a_vec;
+vector_type<FloatB, K1> b_vec;
+
+static_for<0, K1, 1>{}([&](auto k1) {
+constexpr index_t a_offset = ADesc{}.CalculateOffset(
+a_origin_idx + make_multi_index(k0, m0, m1, k1));
+constexpr index_t b_offset = BDesc{}.CalculateOffset(
+b_origin_idx + make_multi_index(k0, n0, n1, k1));
+
+a_vec.template AsType<FloatA>()(k1) = a_buf[Number<a_offset>{}];
+b_vec.template AsType<FloatB>()(k1) = b_buf[Number<b_offset>{}];
+});
+
+using a_vector_t = typename vector_type<FloatA, K1>::type;
+using b_vector_t = typename vector_type<FloatB, K1>::type;
+
constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1));

-amd_assembly_inner_product(a_buf[Number<a_offset>{}],
-b_buf[Number<b_offset>{}],
-c_buf(Number<c_offset>{}));
+amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
+a_vec.template AsType<a_vector_t>()[I0],
+b_vec.template AsType<b_vector_t>()[I0],
+c_buf(Number<c_offset>{}));
});
});
});
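// In scalar form (a hedged restatement of the loops above, not an additional code path):
//
//     C[m0, m1, n0, n1] += sum_{k0} sum_{k1} A[k0, m0, m1, k1] * B[k0, n0, n1, k1]
//
// where the inner sum over k1 (length KLengths[I1]) is packed into a_vec/b_vec and handed to
// amd_inner_product_dlop as one vector inner product (e.g. v_fmac_f32 for fp32 K1 = 1,
// v_dot2_f32_f16 for fp16 K1 = 2, v_dot4_i32_i8 for int8 K1 = 4, per amd_dlop.hpp below).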
......
#ifndef CK_AMD_DLOP_HPP
#define CK_AMD_DLOP_HPP
#include "float_type.hpp"
namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c);
template <>
__device__ void
amd_inner_product_dlop<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c += a * b;
#endif
}
#if CK_USE_AMD_DLOP
template <>
__device__ void
amd_inner_product_dlop<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
// fdot2 is the fp16 dot-product builtin (v_dot2_f32_f16); sdot2 is the int16 variant
c = __builtin_amdgcn_fdot2(a, b, c, false);
#endif
}
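// Semantics of the dot2 path above: one v_dot2_f32_f16 accumulates both fp16 products into the
// fp32 accumulator, i.e. c += a[0] * b[0] + a[1] * b[1].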
template <>
__device__ void
amd_inner_product_dlop<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
c);
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
c);
}
template <>
__device__ void
amd_inner_product_dlop<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
c);
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
c);
}
template <>
__device__ void amd_inner_product_dlop<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a,
const int8x4_t& b,
int32_t& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}
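// Scalar reference for the dot4 path above, as a hedged sketch: v_dot4_i32_i8 /
// __builtin_amdgcn_sdot4 accumulate the 4-lane int8 inner product into a 32-bit accumulator.
// The helper name below is illustrative only and is not used by any kernel.
__host__ __device__ inline int32_t
reference_inner_product_dot4_i8(const int8_t* a, const int8_t* b, int32_t c)
{
    // c += a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3], with products widened to int32
    for(int i = 0; i < 4; ++i)
    {
        c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }
    return c;
}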
template <>
__device__ void amd_inner_product_dlop<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a,
const int8x8_t& b,
int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
c);
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
c);
}
template <>
__device__ void amd_inner_product_dlop<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a,
const int8x16_t& b,
int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
c);
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
c);
}
#endif // CK_USE_AMD_DLOP
} // namespace ck
#endif
@@ -5,94 +5,16 @@
namespace ck {

-// c += inner_product(a, b)
-__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
-{
-#if CK_USE_AMD_V_FMAC_F32
-asm volatile("\n \
-v_fmac_f32 %0, %1, %2 \n \
-"
-: "=v"(c)
-: "v"(a), "v"(b), "0"(c));
-#else
-asm volatile("\n \
-v_mac_f32 %0, %1, %2 \n \
-"
-: "=v"(c)
-: "v"(a), "v"(b), "0"(c));
-#endif
-}
-
-__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
-{
-#if 1
-asm volatile("\n \
-v_dot4_i32_i8 %0, %1, %2, %0\n \
-"
-: "=v"(c)
-: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
-#else
-c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
-#endif
-}
-
-__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c)
-{
-constexpr auto I0 = Number<0>{};
-constexpr auto I1 = Number<1>{};
-
-amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
-vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
-c);
-
-amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
-vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
-c);
-}
-
-__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c)
-{
-constexpr auto I0 = Number<0>{};
-constexpr auto I1 = Number<1>{};
-constexpr auto I2 = Number<2>{};
-constexpr auto I3 = Number<3>{};
-
-amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
-vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
-c);
-
-amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
-vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
-c);
-
-amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
-vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
-c);
-
-amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
-vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
-c);
-}
-
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
-#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \
v_fmac_f32 %0, %2, %3 \n \
v_fmac_f32 %1, %2, %4 \n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
-#else
-asm volatile("\n \
-v_mac_f32 %0, %2, %3 \n \
-v_mac_f32 %1, %2, %4 \n \
-"
-: "=v"(c0), "=v"(c1)
-: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
-#endif
}

// c0 += inner_product(a, b0)
@@ -102,7 +24,6 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa
__device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
-#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \
v_fmac_f32 %0, %4, %5 \n \
v_fmac_f32 %1, %4, %6 \n \
@@ -111,16 +32,6 @@ __device__ void amd_assembly_outer_product_1x4(
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
-#else
-asm volatile("\n \
-v_mac_f32 %0, %4, %5 \n \
-v_mac_f32 %1, %4, %6 \n \
-v_mac_f32 %2, %4, %7 \n \
-v_mac_f32 %3, %4, %8 \n \
-"
-: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
-: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
-#endif
}

// c0 += inner_product(a, b0)
......
@@ -28,10 +28,15 @@
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"

+// TODO: remove this
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif

+#if CK_USE_AMD_DLOP
+#include "amd_dlop.hpp"
+#endif
+
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp"
......
@@ -54,8 +54,13 @@
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif

-#ifndef CK_USE_AMD_V_FMAC_F32
-#define CK_USE_AMD_V_FMAC_F32 1
+// AMD DLOPS
+#ifndef CK_USE_AMD_DLOP
+#define CK_USE_AMD_DLOP 1
+#endif
+
+#ifndef CK_USE_AMD_DLOP_INLINE_ASM
+#define CK_USE_AMD_DLOP_INLINE_ASM 1
#endif

// AMD buffer addressing
@@ -116,7 +121,7 @@
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1

// merge transformation use magic number division
-#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
+#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0

// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
......
@@ -94,7 +94,7 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is..
constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{};

-return container_reorder_give_new2old(old_seq, new2old);
+return container_reorder_given_new2old(old_seq, new2old);
}

#if !CK_WORKAROUND_SWDEV_275126
@@ -223,6 +223,13 @@ container_reverse_exclusive_scan(const Array<TData, NSize>& x, Reduce f, TData i
return y;
}

+template <index_t... Is, typename Reduce, index_t Init>
+__host__ __device__ constexpr auto
+container_reverse_exclusive_scan(const Sequence<Is...>& seq, Reduce f, Number<Init>)
+{
+return reverse_exclusive_scan_sequence(seq, f, Number<Init>{});
+}
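// A hedged worked example (not part of this change): a reverse exclusive scan with
// math::multiplies_v2 and init 1 turns packed tensor lengths into packed strides, e.g.
// container_reverse_exclusive_scan(Sequence<2, 3, 4>{}, math::multiplies_v2{}, Number<1>{})
// is expected to yield Sequence<12, 4, 1>. This is the pattern used to derive the src_vector
// strides in ThreadwiseDynamicTensorSliceTransfer_v4r1 above.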
#if !CK_WORKAROUND_SWDEV_275126
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
@@ -366,6 +373,19 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}

+template <typename Container>
+__host__ __device__ constexpr auto to_tuple_of_number(const Container&)
+{
+static_assert(is_known_at_compile_time<Container>::value, "wrong!");
+
+return generate_tuple(
+[&](auto i) {
+constexpr index_t tmp = Container::At(i);
+return Number<tmp>{};
+},
+Container::Size());
+}
+
template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{
......
@@ -100,39 +100,72 @@ struct DynamicBuffer
*reinterpret_cast<X*>(&p_data_[i]) = x;
#else
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
-// ISA, so I try to let compiler emit use IR "store<i32, 4>" which would be lower to
+// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
int8_t>::value)
{
static_assert(
-(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
-is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
+(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
+(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
+(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
+(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
"wrong! not implemented for this combination, please add implementation");

-if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
-is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
+if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
+{
+// HACK: cast pointer of x is bad
+// TODO: remove this after compiler fix
+*reinterpret_cast<int8_t*>(&p_data_[i]) =
+*reinterpret_cast<const int8_t*>(&x);
+}
+else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
+{
+// HACK: cast pointer of x is bad
+// TODO: remove this after compiler fix
+*reinterpret_cast<int16_t*>(&p_data_[i]) =
+*reinterpret_cast<const int16_t*>(&x);
+}
+else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
+{
+// HACK: cast pointer of x is bad
+// TODO: remove this after compiler fix
+*reinterpret_cast<int32_t*>(&p_data_[i]) =
+*reinterpret_cast<const int32_t*>(&x);
+}
+else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
+int8x4_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int32_t*>(&p_data_[i]) =
*reinterpret_cast<const int32_t*>(&x);
}
-if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
-is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
+else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
+int8x8_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int32x2_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x2_t*>(&x);
}
-if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
-is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
+else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
+int8x16_t>::value &&
+is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
......
@@ -14,7 +14,9 @@
#include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
+#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp"
@@ -24,23 +26,27 @@
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0
+#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0
+#define USE_CONV_FWD_V4R5R2_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
-#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1
-#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
+#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
+#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0

enum ConvForwardAlgo
{
V4R4NCHW, // 0
V4R4NHWC, // 1
-V4R5NCHW, // 2
-V5R1NCHW, // 3
-V4R4XDLNCHW, // 4
-V4R4R2XDLNHWC, // 5
-V4R4R3XDLNHWC, // 6
-V4R4R4XDLNHWC // 7
+V4R4R2NHWC, // 2
+V4R5NCHW, // 3
+V4R5R2NCHW, // 4
+V5R1NCHW, // 5
+V4R4XDLNCHW, // 6
+V4R4R2XDLNHWC, // 7
+V4R4R3XDLNHWC, // 8
+V4R4R4XDLNHWC // 9
};

int main(int argc, char* argv[])
@@ -132,21 +138,18 @@ int main(int argc, char* argv[])
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif

-#if 0
-constexpr index_t in_vector_size = 1;
-using in_data_t = float;
-using acc_data_t = float;
-using out_data_t = float;
+#if 1
+using in_data_t = float;
+using acc_data_t = float;
+using out_data_t = float;
#elif 1
-constexpr index_t in_vector_size = 1;
-using in_data_t = half_t;
-using acc_data_t = float;
-using out_data_t = half_t;
+using in_data_t = half_t;
+using acc_data_t = float;
+using out_data_t = half_t;
#elif 1
-constexpr index_t in_vector_size = 16;
-using in_data_t = int8_t;
-using acc_data_t = int32_t;
-using out_data_t = int8_t;
+using in_data_t = int8_t;
+using acc_data_t = int32_t;
+using out_data_t = int8_t;
#endif

std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
@@ -348,6 +351,33 @@ int main(int argc, char* argv[])
}
#endif

+#if USE_CONV_FWD_V4R4R2_NHWC
+if(algo == ConvForwardAlgo::V4R4R2NHWC)
+{
+if(layout != ConvTensorLayout::NHWC)
+{
+throw std::runtime_error("wrong! layout");
+}
+
+const auto tmp = f_make_for_device_nhwc();
+
+device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk<in_data_t,
+acc_data_t,
+out_data_t>(
+tmp[I0],
+tmp[I1],
+tmp[I2],
+tmp[I3],
+tmp[I4],
+tmp[I5],
+tmp[I6],
+in,
+wei,
+out_device,
+nrepeat);
+}
+#endif
+
#if USE_CONV_FWD_V4R5_NCHW
if(algo == ConvForwardAlgo::V4R5NCHW)
{
@@ -374,6 +404,33 @@ int main(int argc, char* argv[])
}
#endif

+#if USE_CONV_FWD_V4R5R2_NCHW
+if(algo == ConvForwardAlgo::V4R5R2NCHW)
+{
+if(layout != ConvTensorLayout::NCHW)
+{
+throw std::runtime_error("wrong! layout");
+}
+
+const auto tmp = f_make_for_device_nchw();
+
+device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw<in_data_t,
+acc_data_t,
+out_data_t>(
+tmp[I0],
+tmp[I1],
+tmp[I2],
+tmp[I3],
+tmp[I4],
+tmp[I5],
+tmp[I6],
+in,
+wei,
+out_device,
+nrepeat);
+}
+#endif
+
#if USE_CONV_FWD_V5R1_NCHW
if(algo == ConvForwardAlgo::V5R1NCHW)
{
@@ -385,7 +442,7 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw();

device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
-in_vector_size,
+16,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
@@ -525,10 +582,10 @@ int main(int argc, char* argv[])
#if 0
if(do_log)
{
-LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
-LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl;
-LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
-LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
+LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
+LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
+LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
+LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
}
#endif
}
......
@@ -55,7 +55,7 @@ float launch_and_time_kernel(F kernel,
{
KernelTimer timer;

-printf("%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d} \n",
+printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
grid_dim.y,
......
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 1;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 2;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 4] for i8
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 4;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#endif
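// A hedged reading of the three parameter sets above (an inference, not stated in the code):
// only one #if/#elif branch is active at a time, and GemmK1 is picked to match the vector
// width that amd_inner_product_dlop consumes per K-step (1 for fp32 v_fmac_f32, 2 for fp16
// v_dot2_f32_f16, 4 for int8 v_dot4_i32_i8), which is why the A/B block-transfer vector
// tensor lengths also end in the same K1.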
const auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10
Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_v1r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11,
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>(
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
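// flop count of the convolution: 2 * N * K * Ho * Wo * C * Y * X multiply-accumulates;
// dividing GFlop by the average time in milliseconds yields TFlop/s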
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}
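// Overall flow of the driver above, as a hedged summary: copy the NHWC input, KYXC weights and
// NHWK output to device memory, re-express the convolution as a GEMM via
// transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad (with GemmK1 splitting the
// reduction dimension), run and time driver_dynamic_gemm_v1r3 five times (nrepeat kernel
// launches each), report TFlop/s, then copy the result back to the host.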
...@@ -275,12 +275,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -275,12 +275,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = driver_dynamic_gemm_xdlops_v2r3< float ave_time = driver_dynamic_gemm_xdlops_v2r3<
......