"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "12856c3b39287eabe15b6f2f1a6bbea9a6e33f3a"
Commit 63f87662 authored by aska-0096's avatar aska-0096
Browse files

tidy up

parent 13af8cc4
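Summary of the change, for readers skimming the diff below: the single FloatAB element type that the WMMA GEMM path shared between both operands is split into separate FloatA/FloatB (device op, kernel entry point, gridwise and blockwise GEMM, and the WmmaSelector instruction table); the LDS byte count and carve-out are updated for two element types; the experimental CK_EXPERIMENTAL_ARBITRARY_WRITEOUT per-pixel write-out path and leftover debug printing are deleted; and the example's tuning table renames the per-wave repeat columns from MWmma/NWmma to MRepeat/NRepeat while bumping K0PerBlock from 4 to 8 in the active instance.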
@@ -22,20 +22,13 @@ using CElementOp = PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 // clang-format off
-// using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmWmma
-// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NMMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
-// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
-// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
-// clang-format on
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWmma|NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
+// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
......
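For readers decoding the tuning rows above: the K dimension is carried as K0 x K1 packs (the grid descriptors below are named AGridDesc_K0_M_K1 / BGridDesc_K0_N_K1), so the K tile consumed per main-loop iteration is K0PerBlock * K1. A minimal standalone sketch of what the 4 -> 8 change to K0PerBlock means, assuming that interpretation (the helper name is ours, not CK's):

    // kPerBlock(): illustrative helper; each loop iteration consumes
    // K0PerBlock packs of K1 contiguous elements along K.
    constexpr int kPerBlock(int k0_per_block, int k1) { return k0_per_block * k1; }

    static_assert(kPerBlock(4, 8) == 32, "old instance: 32 K-elements per iteration");
    static_assert(kPerBlock(8, 8) == 64, "new instance: 64 K-elements per iteration");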
@@ -201,7 +201,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
         BlockSize,
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
+        BDataType,
         AccDataType,
         CShuffleDataType,
         CDataType,
@@ -353,7 +354,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             const auto kernel = kernel_gemm_wmma<
                 GridwiseGemm,
-                ADataType, // TODO: distiguish A/B datatype
+                ADataType,
+                BDataType,
                 CDataType,
                 remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                 remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
@@ -384,7 +386,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             const auto kernel = kernel_gemm_wmma<
                 GridwiseGemm,
-                ADataType, // TODO: distiguish A/B datatype
+                ADataType,
+                BDataType,
                 CDataType,
                 remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                 remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
......
@@ -18,7 +18,8 @@
 namespace ck {
 template <typename GridwiseGemm,
-          typename FloatAB,
+          typename FloatA,
+          typename FloatB,
           typename FloatC,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
@@ -33,8 +34,8 @@ __global__ void
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_gemm_wmma(
-        const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
+        const FloatA* __restrict__ p_a_grid,
+        const FloatB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
@@ -77,7 +78,8 @@ __global__ void
 }
 template <index_t BlockSize,
-          typename FloatAB,
+          typename FloatA,
+          typename FloatB,
           typename FloatAcc,
           typename FloatCShuffle,
           typename FloatC,
@@ -216,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
             b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
-        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
+        return (a_block_space_size_aligned * sizeof(FloatA) + b_block_space_size_aligned * sizeof(FloatB));
     }
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
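The shared-memory computation above is the concrete payoff of the type split: the aligned element counts for A and B must now each be weighted by their own sizeof. A standalone sketch of the before/after arithmetic (names and the fp16/int8 stand-in types are ours, chosen for illustration only):

    #include <cstddef>
    #include <cstdint>

    // New formula: each operand contributes its own element size.
    template <typename FloatA, typename FloatB>
    constexpr std::size_t lds_bytes(std::size_t a_elems, std::size_t b_elems)
    {
        return a_elems * sizeof(FloatA) + b_elems * sizeof(FloatB);
    }

    // Old formula, (a_elems + b_elems) * sizeof(FloatAB), would report
    // 512 bytes here, overstating the int8 B tile by 128 bytes.
    static_assert(lds_bytes<std::uint16_t, std::int8_t>(128, 128) == 384, "");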
@@ -270,120 +272,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
-    // Vector write
-    __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
-        const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        constexpr auto max_lds_align = K1;
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-        constexpr auto WmmaK = 16;
-        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
-            BlockSize,
-            FloatAB,
-            FloatAcc,
-            decltype(a_block_desc_k0perblock_mperblock_k1),
-            decltype(b_block_desc_k0perblock_nperblock_k1),
-            MPerWmma,
-            NPerWmma,
-            MRepeat,
-            NRepeat,
-            KPack>;
-        return BlockwiseGemm::
-            MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
-                c_grid_desc_m_n);
-    }
-    // Per pixel
-    __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-        const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        constexpr auto max_lds_align = K1;
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-        constexpr auto WmmaK = 16;
-        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
-            BlockSize,
-            FloatAB,
-            FloatAcc,
-            decltype(a_block_desc_k0perblock_mperblock_k1),
-            decltype(b_block_desc_k0perblock_nperblock_k1),
-            MPerWmma,
-            NPerWmma,
-            MRepeat,
-            NRepeat,
-            KPack>;
-        return BlockwiseGemm::
-            MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-                c_grid_desc_m_n);
-    }
     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
     {
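Both helpers deleted above followed the same template: rebuild the A/B LDS block descriptors, instantiate the blockwise GEMM (still keyed on a single FloatAB at that point), and delegate the C grid-descriptor construction to it. They differed only in whether MAccVgprs lands last (vector write) or before the N dimensions (per-pixel write). With the experimental per-pixel write-out removed later in this file, only the MBlock_MPerBlock_NBlock_NPerBlock descriptor below is still needed.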
@@ -410,11 +298,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
             c_grid_desc_m_n);
     }
-    // using
-    // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
-    // = remove_cvref_t<decltype(
-    //     MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-    //     CGridDesc_M_N{}))>;
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
     using DefaultBlock2CTileMap =
@@ -422,17 +306,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
-    Run(const FloatAB* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
+    Run(const FloatA* __restrict__ p_a_grid,
+        const FloatB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         void* __restrict__ p_shared,
        const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        // const
-        // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
-        //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
         const CElementwiseOperation& c_element_op,
@@ -476,8 +357,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
             /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
             /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
-            /* typename SrcData, */ FloatAB,
-            /* typename DstData, */ FloatAB,
+            /* typename SrcData, */ FloatA,
+            /* typename DstData, */ FloatA,
             /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
             /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
             /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
@@ -496,8 +377,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 a_block_desc_k0perblock_mperblock_k1,
                 make_multi_index(0, 0, 0),
                 ck::tensor_operation::element_wise::PassThrough{});
-        // printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1());
-        // printf("a_block_wise_copy: %s\n", std::string(type_name<decltype(a_blockwise_copy)>()).c_str());
         // B matrix blockwise copy
         auto b_blockwise_copy =
@@ -508,8 +387,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 Sequence<K0PerBlock, NPerBlock, K1>,
                 BBlockTransferThreadClusterLengths_K0_N_K1,
                 BBlockTransferThreadClusterArrangeOrder,
-                FloatAB,
-                FloatAB,
+                FloatB,
+                FloatB,
                 decltype(b_grid_desc_k0_n_k1),
                 decltype(b_block_desc_k0perblock_nperblock_k1),
                 BBlockTransferSrcAccessOrder,
@@ -530,18 +409,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 ck::tensor_operation::element_wise::PassThrough{});
         /*******************************************************************************/
-        // GEMM definition
-        // c_mtx += a_mtx * b_mtx
-        // a_mtx[K0PerBlock, MPerBlock] is in LDS
-        // b_mtx[K0PerBlock, NPerBlock] is in LDS
-        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
+        // GEMM
         constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
         auto blockwise_gemm =
             BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
-                                                                       FloatAB,
+                                                                       FloatA,
+                                                                       FloatB,
                                                                        FloatAcc,
                                                                        decltype(a_block_desc_k0perblock_mperblock_k1),
                                                                        decltype(b_block_desc_k0perblock_nperblock_k1),
@@ -557,8 +432,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         /*******************************************************************************/
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
         // LDS allocation for A and B: be careful of alignment
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatAB*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
         // Shift Per SUB_K
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
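One subtlety in the carve-out above, worth noting for anyone extending this to genuinely mixed A/B types: a_block_space_size_aligned counts FloatA elements, but it offsets a FloatB* base, so the B tile starts at byte offset a_block_space_size_aligned * sizeof(FloatB). That agrees with the byte total returned by GetSharedMemoryNumberOfByte only while sizeof(FloatA) == sizeof(FloatB). A byte-exact alternative might look like the following sketch (our formulation, not the committed code):

    // Sketch: carve LDS by bytes so sizeof(FloatA) != sizeof(FloatB) stays
    // consistent with GetSharedMemoryNumberOfByte().
    char* const p_lds = static_cast<char*>(p_shared);
    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        reinterpret_cast<FloatA*>(p_lds),
        a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
    auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        reinterpret_cast<FloatB*>(p_lds + a_block_space_size_aligned * sizeof(FloatA)),
        b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());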
@@ -582,101 +457,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 c_thread_buf,
                 K0BlockMainLoop);
         /*******************************************************************************/
-#ifdef CK_EXPERIMENTAL_ARBITRARY_WRITEOUT
-        // write out C matrix, c shuffle not implemented
-        {
-            static_for<0, 16, 1>{}([&](auto i){
-                char info[4];
-                info[0] = 'C';
-                info[1] = i/10 + '0';
-                info[2] = i%10 + '0';
-                info[3] = '\0';
-                debug_hexprinter(0xffffffff, c_thread_buf[Number<i>{}], info);
-            });
-            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
-            // This API Provide All dimension (size) you need
-            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-                blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
-            constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
-            constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
-            constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
-            constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
-            constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
-            // printf("MWave = %d, MSubGroup = %d, NWave = %d, NThreadPerSubGroup = %d, MAccVgprs = %d\n", MWave, MSubGroup, NWave, NThreadPerSubGroup, MAccVgprs);
-            // Mapping
-            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
-            const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
-            const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
-            // Checked
-            // debug_hexprinter(0xffffffff, m_thread_data_on_grid, "c_m");
-            // debug_hexprinter(0xffffffff, n_thread_data_on_grid, "c_n");
-            const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
-                    make_tuple(Sequence<0, 1, 2, 3>{}),
-                    make_tuple(Sequence<0>{}));
-            const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
-                    make_tuple(Sequence<0, 1, 2>{}),
-                    make_tuple(Sequence<0>{}));
-            const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
-                make_multi_index(m_thread_data_on_grid));
-            debug_hexprinter(0x4, MRepeat, "mblockxrepeat");
-            debug_hexprinter(0x2, MWave, "mwave");
-            debug_hexprinter(0x2, MSubGroup, "msubgroup");
-            debug_hexprinter(0x8, MAccVgprs, "maccvgprs");
-            debug_hexprinter(0x4, NWave, "nwave");
-            const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
-                make_multi_index(n_thread_data_on_grid));
-            // printf("write out dimension access order = (%d, %d, %d, %d, %d, %d, %d)\n", CThreadTransferSrcDstAccessOrder{}[Number<0>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<1>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<2>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<3>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<4>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<5>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<6>{}].value);
-            auto c_thread_copy =
-                ThreadwiseTensorSliceTransfer_v1r3<
-                    /* typename SrcData */ FloatAcc,
-                    /* typename DstData */ FloatC,
-                    /* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
-                    /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup),
-                    /* typename ElementwiseOperation */ CElementwiseOperation,
-                    // Thread register Mapping 0 1 2 4 5 6 3
-                    /* typename SliceLengths */ Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
-                    /* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder,
-                    /* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim,
-                    /* index_t DstScalarPerVector */ CThreadTransferDstScalarPerVector,
-                    /* InMemoryDataOperationEnum DstInMemOp */ CGlobalMemoryDataOperation,
-                    /* index_t DstScalarStrideInVector */ 1,
-                    /* bool DstResetCoordinateAfterRun */ true>
-                {
-                    /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
-                    /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0],
-                                                                m_thread_data_on_grid_idx[I1],
-                                                                m_thread_data_on_grid_idx[I2],
-                                                                m_thread_data_on_grid_idx[I3],
-                                                                n_thread_data_on_grid_idx[I0],
-                                                                n_thread_data_on_grid_idx[I1],
-                                                                n_thread_data_on_grid_idx[I2]),
-                    /* element_op */ c_element_op
-                };
-            c_thread_copy.Run(
-                /* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
-                /* c_register_beginning */ make_tuple(I0, I0, I0, I0, I0, I0, I0),
-                /* c_local(register) */ c_thread_buf,
-                /* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
-                /* c_grid_buf */ c_grid_buf);
-        }
-#endif
         {
+            // write out to C, implement shuffle
             constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                 blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
......
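A small aside on the KPack line kept in the GEMM section above, since it is easy to misread: assuming CK's math::integer_least_multiple(a, b) rounds a up to the nearest multiple of b, KPack is K1 rounded up to the 16-wide WMMA K step. A sketch of that arithmetic (the helper below is ours, written to the assumed semantics):

    // integerLeastMultiple(): assumed semantics of math::integer_least_multiple.
    constexpr int integerLeastMultiple(int a, int b) { return ((a + b - 1) / b) * b; }

    static_assert(integerLeastMultiple(8, 16) == 16, "K1 = 8 packs up to one WmmaK = 16 step");
    static_assert(integerLeastMultiple(16, 16) == 16, "K1 = 16 is already a full step");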
@@ -128,12 +128,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
-        // printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]());
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
         constexpr auto ordered_src_access_lengths =
             container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
-        // printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]());
         // make forward steps
         const auto src_forward_steps = generate_tuple(
@@ -210,7 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
             // apply SrcElementwiseOperation on src_vector_container
-            // debug_hexprinter(0xffffffff, src_coord_.GetOffset());
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 SrcData src_v;
......
@@ -283,51 +283,51 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     }
 };
-template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
+template <typename src_type_a, typename src_type_b, typename dst_type, index_t MPerWmma, index_t NPerWmma>
 struct WmmaSelector
 {
-    template <typename src_type_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
+    template <typename src_type_a_, typename src_type_b_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
     static constexpr auto GetWmma();
     template <>
-    static constexpr auto GetWmma<half_t, float, 16, 16>()
+    static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
     {
         return WmmaInstr::wmma_f32_16x16x16_f16;
     }
     template <>
-    static constexpr auto GetWmma<bhalf_t, float, 16, 16>()
+    static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
     {
         return WmmaInstr::wmma_f32_16x16x16_bf16;
     }
     template <>
-    static constexpr auto GetWmma<half_t, half_t, 16, 16>()
+    static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
     {
         return WmmaInstr::wmma_f16_16x16x16_f16;
     }
     template <>
-    static constexpr auto GetWmma<bhalf_t, bhalf_t, 16, 16>()
+    static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
     {
         return WmmaInstr::wmma_bf16_16x16x16_bf16;
     }
     template <>
-    static constexpr auto GetWmma<int8_t, float, 16, 16>()
+    static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
     {
         return WmmaInstr::wmma_i32_16x16x16_iu8;
     }
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
     template <>
-    static constexpr auto GetWmma<int4_t, float, 16, 16>()
+    static constexpr auto GetWmma<int4_t, int, 16, 16>()
     {
         return WmmaInstr::wmma_i32_16x16x16_iu4;
     }
 #endif
     // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
     static constexpr auto selected_wmma =
-        wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
+        wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
     __host__ __device__ constexpr WmmaSelector()
     {
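With both source types in the key, the selector now distinguishes operand couples that the old single src_type key conflated. A usage sketch grounded in the specializations above, mirroring how WmmaGemm consumes the selector further down this file:

    // Illustrative mappings taken from the GetWmma specializations above:
    //   <half_t,  half_t,  float,  16, 16> -> wmma_f32_16x16x16_f16
    //   <bhalf_t, bhalf_t, float,  16, 16> -> wmma_f32_16x16x16_bf16
    //   <half_t,  half_t,  half_t, 16, 16> -> wmma_f16_16x16x16_f16
    //   <int8_t,  int8_t,  int,    16, 16> -> wmma_i32_16x16x16_iu8
    static constexpr auto wmma = WmmaSelector<half_t, half_t, float, 16, 16>{};
    static constexpr auto wmma_instr = wmma.selected_wmma; // wmma_f32_16x16x16_f16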
@@ -344,7 +344,8 @@ struct WmmaSelector
     }
 };
-template <typename src_type,
+template <typename src_type_a,
+          typename src_type_b,
           typename dst_type,
           index_t MPerWmma,
           index_t NPerWmma,
@@ -412,46 +413,6 @@ struct WmmaGemm
                        Sequence<5>{}));
     }
-    // Per-Pixel write
-    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
-    __host__ __device__ static constexpr auto
-    MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
-    {
-        const auto MBlockxRepeat =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
-        const auto NBlockxRepeat =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
-        const auto MWave =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
-        const auto NWave =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
-        return transform_tensor_descriptor(
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
-            make_tuple(
-                make_pass_through_transform(MBlockxRepeat),
-                make_pass_through_transform(MWave),
-                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
-                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{})),
-                make_pass_through_transform(NBlockxRepeat),
-                make_pass_through_transform(NWave),
-                make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
-            make_tuple(Sequence<0>{},
-                       Sequence<1>{},
-                       Sequence<2>{},
-                       Sequence<3>{},
-                       Sequence<4>{},
-                       Sequence<5>{}),
-            make_tuple(Sequence<0>{},
-                       Sequence<1>{},
-                       Sequence<2, 3>{},
-                       Sequence<4>{},
-                       Sequence<5>{},
-                       Sequence<6>{}));
-    }
     __device__ static constexpr index_t GetRegSizePerWmma()
     {
         return wmma_instr.num_acc_vgprs_per_wave;
@@ -463,13 +424,13 @@ struct WmmaGemm
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(
-            (is_same<src_type, half_t>::value && is_same<dst_type, float>::value) ||
-            (is_same<src_type, bhalf_t>::value && is_same<dst_type, float>::value) ||
-            (is_same<src_type, half_t>::value && is_same<dst_type, half_t>::value) ||
-            (is_same<src_type, bhalf_t>::value && is_same<dst_type, bhalf_t>::value) ||
-            (is_same<src_type, int8_t>::value && is_same<dst_type, int32_t>::value)
+            (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value && is_same<dst_type, float>::value) ||
+            (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value && is_same<dst_type, float>::value) ||
+            (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value && is_same<dst_type, half_t>::value) ||
+            (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value && is_same<dst_type, bhalf_t>::value) ||
+            (is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value && is_same<dst_type, int32_t>::value)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-            || (is_same<src_type, int4_t>::value && is_same<dst_type, int32_t>::value)
+            || (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value && is_same<dst_type, int32_t>::value)
 #endif
             ,
             "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
@@ -518,7 +479,7 @@ struct WmmaGemm
         return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }
-    static constexpr auto wmma = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
+    static constexpr auto wmma = WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
     static constexpr auto wmma_instr = wmma.selected_wmma;
     __host__ __device__ static constexpr auto
......