"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "36fbeee341e283a93b6befa2a4d9085b7a5dd2b1"
Unverified Commit 646fcc26 authored by Chao Liu, committed by GitHub

Merge pull request #47 from ROCmSoftwarePlatform/develop

Merge develop into master
parents 38a90b6e 6014185a
@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
           typename AK0MK1GridDesc,
           typename BK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -37,7 +38,7 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];
-    GridwiseGemm::Run(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
@@ -81,7 +82,7 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];
-    GridwiseGemm::Run(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
@@ -102,7 +103,7 @@ template <index_t BlockSize,
           typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t KPerBlock,
+          index_t K0PerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
@@ -158,13 +159,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -173,13 +174,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -217,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
              K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
             return false;
-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;
         // check M01, N01
@@ -245,6 +246,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return grid_size;
     }
+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    {
+        const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
+
+        return has_main_k0_block_loop;
+    }
+
     __host__ __device__ static constexpr auto
     MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
@@ -255,13 +263,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -270,13 +278,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -334,6 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));
+    template <bool HasMainKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
@@ -371,13 +380,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -386,13 +395,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -400,7 +409,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, MPerBlock, K1>,
+                                            Sequence<K0PerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
@@ -426,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, NPerBlock, K1>,
+                                            Sequence<K0PerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
@@ -450,8 +459,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
-        //   a_mtx[KPerBlock, MPerBlock] is in LDS
-        //   b_mtx[KPerBlock, NPerBlock] is in LDS
+        //   a_mtx[K0PerBlock, MPerBlock] is in LDS
+        //   b_mtx[K0PerBlock, NPerBlock] is in LDS
         //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
         // sanity check
@@ -477,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;
-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
@@ -504,8 +513,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         }
         // main body
-        index_t k_block_data_begin = 0;
+        index_t k0_block_data_begin = 0;
+        if constexpr(HasMainKBlockLoop)
+        {
             do
             {
                 a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
@@ -515,11 +526,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                     b_block_slice_copy_step,
                                                     b_k0_n_k1_grid_move_slice_window_step_hack);
-                a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(
+                    a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
                 block_sync_lds();
-                b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(
+                    b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
                 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
@@ -528,8 +541,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                 a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
                 b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
-                k_block_data_begin += KPerBlock;
-            } while(k_block_data_begin < (K0 - KPerBlock));
+                k0_block_data_begin += K0PerBlock;
+            } while(k0_block_data_begin < (K0 - K0PerBlock));
+        }
         // tail
         {
...
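The new HasMainKBlockLoop template parameter turns "is there more than one K0 tile?" into a compile-time property of the kernel: the host computes it once from the problem size and instantiates the matching specialization, so the single-tile case carries no loop at all. Below is a minimal standalone sketch of the same pattern; the names (run_k_tiles, process_tile, dispatch) are illustrative and not part of CK.

// Minimal sketch of the compile-time main-loop specialization used above.
#include <cstdint>

constexpr int32_t K0PerBlock = 4; // tile size along K0, illustrative value

template <bool HasMainKBlockLoop>
void run_k_tiles(int32_t K0)
{
    int32_t k0 = 0;

    // Main body: only compiled into the <true> instantiation.
    if constexpr(HasMainKBlockLoop)
    {
        do
        {
            // process_tile(k0);   // e.g. copy the next tile to LDS, run the blockwise GEMM
            k0 += K0PerBlock;
        } while(k0 < K0 - K0PerBlock);
    }

    // Tail: the last (or only) tile is always handled here, so the <false>
    // instantiation is straight-line code with no loop overhead.
    // process_tile(k0);
}

// Host side: the flag is a runtime value, so pick the instantiation with a branch,
// mirroring CalculateHasMainK0BlockLoop() and the if/else added to the drivers later in this diff.
void dispatch(int32_t K0)
{
    if((K0 / K0PerBlock) > 1)
        run_k_tiles<true>(K0);
    else
        run_k_tiles<false>(K0);
}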
@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -37,7 +38,7 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];
-    GridwiseGemm::Run(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
@@ -53,7 +54,8 @@ template <typename GridwiseGemm,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -81,7 +83,7 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];
-    GridwiseGemm::Run(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
@@ -102,7 +104,7 @@ template <index_t BlockSize,
           typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t KPerBlock,
+          index_t K0PerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
@@ -158,13 +160,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -173,13 +175,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -220,7 +222,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
              KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
             return false;
-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;
         // check M01, N01
@@ -248,6 +250,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         return grid_size;
     }
+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    {
+        const bool has_main_k0_block_loop = K0 > K0PerBlock;
+
+        return has_main_k0_block_loop;
+    }
+
     __host__ __device__ static constexpr auto
     MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
@@ -258,13 +267,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -273,13 +282,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
        {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -338,6 +347,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
+    template <bool HasMainKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
@@ -376,13 +386,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -390,8 +400,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
-                make_tuple(Number<KPerBlock>{} * Number<MPerBlock + 1>{} * K1,
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                            Number<MPerBlock + 1>{} * K1,
                            K1,
                            I1));
@@ -399,7 +409,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                 max_lds_align);
         }
     }();
@@ -408,13 +418,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         }
     }();
@@ -422,8 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         if constexpr(BBlockLdsExtraN)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
-                make_tuple(Number<KPerBlock>{} * Number<NPerBlock + 1>{} * K1,
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                            Number<NPerBlock + 1>{} * K1,
                            K1,
                            I1));
@@ -431,7 +441,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                 max_lds_align);
         }
     }();
@@ -439,7 +449,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, KPerBlock, MPerBlock, K1>,
+                                            Sequence<1, K0PerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
@@ -466,7 +476,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, KPerBlock, NPerBlock, K1>,
+                                            Sequence<1, K0PerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
@@ -491,8 +501,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
-        //   a_mtx[KPerBlock, MPerBlock] is in LDS
-        //   b_mtx[KPerBlock, NPerBlock] is in LDS
+        //   a_mtx[K0PerBlock, MPerBlock] is in LDS
+        //   b_mtx[K0PerBlock, NPerBlock] is in LDS
         //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
         // sanity check
@@ -518,8 +528,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;
-        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
@@ -546,7 +556,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         // main body
         index_t k_block_data_begin = 0;
+        if constexpr(HasMainKBlockLoop)
+        {
             do
             {
                 a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
@@ -556,11 +567,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
                                                     b_block_slice_copy_step,
                                                     b_k0_n_k1_grid_move_slice_window_step_hack);
-                a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(
+                    a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
                 block_sync_lds();
-                b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(
+                    b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
                 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
@@ -569,8 +582,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
                 a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
                 b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
-                k_block_data_begin += KPerBlock;
-            } while(k_block_data_begin < (K0 - KPerBlock));
+                k_block_data_begin += K0PerBlock;
+            } while(k_block_data_begin < (K0 - K0PerBlock));
+        }
         // tail
         {
...
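Note that the two new helpers are phrased slightly differently: v2r3 tests `(K0 / K0PerBlock) > 1` while v2r4 tests `K0 > K0PerBlock`. For the sizes both kernels accept (CheckValidity requires `K0 % K0PerBlock == 0`), the predicates agree; a couple of illustrative compile-time checks, assuming K0PerBlock = 4:

// Illustrative only: with K0 a positive multiple of K0PerBlock (= 4 here),
// "(K0 / K0PerBlock) > 1" and "K0 > K0PerBlock" select the same specialization.
static_assert((( 4 / 4) > 1) == ( 4 > 4), "single K0 tile -> no main loop");
static_assert((( 8 / 4) > 1) == ( 8 > 4), "two K0 tiles -> main loop");
static_assert(((12 / 4) > 1) == (12 > 4), "three K0 tiles -> main loop");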
@@ -95,12 +95,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
     const auto GemmN = Y * X * C;
     const auto GemmKTotal = N * Ho * Wo;
-    const auto GemmK = GemmKTotal / GemmK1;
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0 = BatchLen * GemmKPerBlock;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
     const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
...
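The rewritten GemmK0 computation drops the detour through `std::ceil` on a `double` and rounds in integer arithmetic: GemmKTotal is rounded up to a whole number of GemmK1 * GemmKPerBlock * GemmKBatch chunks, then expressed in units of K0, so GemmKPad always covers GemmKTotal. A worked example with assumed sizes (the numbers are illustrative, not from this diff), using a stand-in for `math::integer_divide_ceil`:

#include <cstdint>

// Stand-in for the integer ceiling-division helper, shown here only for the arithmetic.
constexpr int64_t divide_ceil(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Assumed sizes: GemmKTotal = N*Ho*Wo = 1000, GemmK1 = 4, GemmKPerBlock = 8, GemmKBatch = 3.
constexpr int64_t GemmKTotal = 1000, GemmK1 = 4, GemmKPerBlock = 8, GemmKBatch = 3;

constexpr int64_t GemmK0   = divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
constexpr int64_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

static_assert(GemmK0 == 11 * 8, "ceil(1000 / 96) = 11 chunks, each GemmKPerBlock wide");
static_assert(GemmKPad == 1056 && GemmKPad >= GemmKTotal, "padded K covers the real K");
static_assert(GemmK0 % GemmKPerBlock == 0, "K0 stays a multiple of GemmKPerBlock");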
@@ -123,12 +123,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
     const auto GemmN = K;
     const auto GemmKTotal = N * Ho * Wo;
-    const auto GemmK = GemmKTotal / GemmK1;
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0 = BatchLen * GemmKPerBlock;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
     const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
...
@@ -107,8 +107,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32
+#elif 1
+    // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 128;
@@ -291,12 +291,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
     const auto GemmN = Y * X * C;
     const auto GemmKTotal = N * Ho * Wo;
-    const auto GemmK = GemmKTotal / GemmK1;
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0 = BatchLen * GemmKPerBlock;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
     const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
...
@@ -160,7 +160,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
-#elif 0
+#elif 1
     // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize = 256;
@@ -188,7 +188,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
-#elif 1
+#elif 0
     // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
     constexpr index_t BlockSize = 256;
...
@@ -148,16 +148,25 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
     const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
+    const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+    float ave_time = 0;
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+    if(has_main_k0_block_loop)
+    {
         const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AK0MK1GridDesc>,
                                                     remove_reference_t<BK0NK1GridDesc>,
                                                     remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
-                                                    remove_reference_t<CBlockClusterAdaptor>>;
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    true>;
-#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-    float ave_time = launch_and_time_kernel(kernel,
+        ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
                                           dim3(grid_size),
                                           dim3(BlockSize),
@@ -169,7 +178,31 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                          b_k0_n_k1_grid_desc,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          c_block_cluster_adaptor);
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<AK0MK1GridDesc>,
+                                                    remove_reference_t<BK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    false>;
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          p_a_grid,
+                                          p_b_grid,
+                                          p_c_grid,
+                                          a_k0_m_k1_grid_desc,
+                                          b_k0_n_k1_grid_desc,
+                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                          c_block_cluster_adaptor);
+    }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
     DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
     DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
@@ -181,7 +214,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
     c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
-    float ave_time = launch_and_time_kernel(
+    if(has_main_k0_block_loop)
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<AK0MK1GridDesc>,
+                                                    remove_reference_t<BK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    true>;
+        ave_time = launch_and_time_kernel(
             kernel,
             nrepeat,
             dim3(grid_size),
@@ -194,7 +238,37 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
             cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
             cast_pointer_to_constant_address_space(
                 c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
-            cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<AK0MK1GridDesc>,
+                                                    remove_reference_t<BK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    false>;
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
 #endif
     return ave_time;
 }
...
@@ -156,16 +156,46 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
         std::cout << "gridSize : " << grid_size << std::endl;
     }
+    const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
+    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+    float ave_time = 0;
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+    if(has_main_k0_block_loop)
+    {
         const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<ABK0MK1GridDesc>,
                                                     remove_reference_t<BBK0NK1GridDesc>,
                                                     remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
-                                                    remove_reference_t<CBlockClusterAdaptor>>;
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    true>;
-#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-    float ave_time = launch_and_time_kernel(kernel,
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          p_a_grid,
+                                          p_b_grid,
+                                          p_c_grid,
+                                          a_b_k0_m_k1_grid_desc,
+                                          b_b_k0_n_k1_grid_desc,
+                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                          c_block_cluster_adaptor);
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<ABK0MK1GridDesc>,
+                                                    remove_reference_t<BBK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    false>;
+        ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
                                           dim3(grid_size),
                                           dim3(BlockSize),
@@ -177,6 +207,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                          b_b_k0_n_k1_grid_desc,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          c_block_cluster_adaptor);
+    }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
     DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
@@ -189,7 +220,43 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
     c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
-    float ave_time = launch_and_time_kernel(
+    if(has_main_k0_block_loop)
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<ABK0MK1GridDesc>,
+                                                    remove_reference_t<BBK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    true>;
+        ave_time = launch_and_time_kernel(
             kernel,
             nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                    FloatAB,
+                                                    FloatC,
+                                                    remove_reference_t<ABK0MK1GridDesc>,
+                                                    remove_reference_t<BBK0NK1GridDesc>,
+                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                    remove_reference_t<CBlockClusterAdaptor>,
+                                                    false>;
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
             dim3(grid_size),
@@ -202,7 +269,9 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
             cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
             cast_pointer_to_constant_address_space(
                 c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
-            cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
 #endif
     return ave_time;
 }
...