Commit 03f7892a authored by Chao Liu

replacing array with vector for tensor data

parent e8421cca
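
The change below swaps raw arrays for vector-typed buffers as the storage for thread-local tensor data. As a rough illustration of the pattern only (not this commit's actual API; VectorStorage and ExampleStaticBuffer are hypothetical stand-ins for the vector_type / StaticBuffer utilities further down in the diff):

#include <cstdio>

// Hypothetical stand-in for a ck-style vector_type<T, N>: a trivially copyable
// aggregate that owns N scalars and exposes element access.
template <typename T, int N>
struct VectorStorage
{
    T d[N];
    T& operator[](int i) { return d[i]; }
};

// Before: a StaticBuffer-like type would hold "T data[N]" directly.
// After: it derives from the vector type, so vectorized loads/stores can reuse
// the same storage.
template <typename T, int N>
struct ExampleStaticBuffer : VectorStorage<T, N>
{
    constexpr ExampleStaticBuffer() : VectorStorage<T, N>{} {}
};

int main()
{
    ExampleStaticBuffer<float, 4> buf;
    buf[2] = 3.0f;               // element access through the inherited operator[]
    std::printf("%f\n", buf[2]); // prints 3.000000
    return 0;
}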
@@ -12,6 +12,9 @@ namespace ck {
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
+typename FloatA,
+typename FloatB,
+typename FloatC,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
@@ -104,7 +107,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
-template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
@@ -150,7 +152,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
-constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
+constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
+FloatB,
+FloatC,
+decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
// loop over k
@@ -180,7 +185,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
});
}
-template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
@@ -243,7 +247,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
-constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
+constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
+FloatB,
+FloatC,
+decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
@@ -331,7 +338,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
}
-template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
@@ -540,7 +546,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
-constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
+constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
+FloatB,
+FloatC,
+decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
......
@@ -52,6 +52,7 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
}
#endif
+#if 1
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
@@ -255,9 +256,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
-#if 1 // debug
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
+FloatAB,
+FloatAB,
+FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
decltype(c_m0m1_n0n1_thread_desc),
@@ -270,7 +273,434 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
NLevel1Cluster,
MPerThread,
NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when moving the slice window over the A and B
// matrices for threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
}
if constexpr(HasMainKBlockLoop)
{
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
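// Note (my assumption, not part of this commit): the two flags that select these tail
// paths are derived from the chunk count K / KPerBlock by the caller, roughly as
//   HasMainKBlockLoop       <=>  (K / KPerBlock) > 2        // the do/while main body runs
//   HasDoubleTailKBlockLoop <=>  (K / KPerBlock) % 2 == 0   // two chunks remain at the end
// so an odd chunk count falls through to the single-iteration branch above.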
// output: register to global memory
{
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
Number<MPerThread>{},
Number<NRepeat>{},
Number<NPerThread>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
constexpr auto tmp = make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
p_c_thread,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
#else
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
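// Worked example (hypothetical tile sizes, not from this commit): with KPerBlock = 8,
// MPerBlock = NPerBlock = 128, FloatAB = float and an alignment that already divides the
// tile, each block tile holds 8 * 128 = 1024 elements, so the double-buffered A and B
// tiles reserve 2 * (1024 + 1024) * sizeof(float) = 16384 bytes of LDS.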
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
#if 0
const auto m_block_work_num = M / Number<MPerBlock>{};
const auto n_block_work_num = N / Number<NPerBlock>{};
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#else
// Hack: this forces the result into an SGPR
const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
const index_t m_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#endif
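// Note: __builtin_amdgcn_readfirstlane broadcasts lane 0's value to the whole wavefront,
// which lets the compiler keep these per-block indices in scalar registers (SGPRs)
// instead of recomputing them in per-lane VGPRs.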
const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
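// Worked example (hypothetical configuration, not from this commit): MPerBlock = 128 with
// MPerThread = 4, MLevel0Cluster = 4, MLevel1Cluster = 4 gives MRepeat = 128 / 64 = 2,
// i.e. each thread accumulates two M-subtiles of the block tile.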
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
@@ -289,8 +719,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
MPerThread,
NPerThread>{};
-#endif
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
@@ -514,6 +942,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
+#endif
} // namespace ck
#endif
@@ -1429,6 +1429,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// position in slice window
#if 0 // debug
+// TODO: unable to compile
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
......
@@ -57,7 +57,10 @@ struct ThreadwiseMatrixSliceCopy_v2
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
-template <typename ADesc,
+template <typename FloatA,
+typename FloatB,
+typename FloatC,
+typename ADesc,
typename BDesc,
typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -65,7 +68,6 @@ template <typename ADesc,
bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1
{
-template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -94,7 +96,6 @@ struct ThreadwiseGemm_km_kn_mn_v1
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -157,7 +158,6 @@ struct ThreadwiseGemm_km_kn_mn_v1
}
#endif
-template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
......
@@ -5,10 +5,14 @@
namespace ck {
-template <typename T, index_t N>
-struct StaticBuffer : public vector_type_maker<T, N>::type
+template <
+typename ScalarType,
+index_t N,
+typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
+bool>::type = false>
+struct StaticBuffer : public vector_type<ScalarType, N>
{
-using base = typename vector_type_maker<T, N>::type;
+using base = vector_type<ScalarType, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
};
@@ -16,7 +20,60 @@ struct StaticBuffer : public vector_type_maker<T, N>::type
template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
-return StaticBuffer<T, N>{};
+using scalar_t = typename scalar_type<T>::type;
+constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
+return StaticBuffer<scalar_t, N * scalar_per_vector>{};
}
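// Usage sketch (my assumption about the surrounding vector_type utilities): for a 4-wide
// alias such as float4_t with scalar_type<float4_t>::type == float and vector_size == 4,
// make_static_buffer<float4_t>(Number<8>{}) yields StaticBuffer<float, 32>, i.e. the
// buffer is always declared in scalar elements.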
template <
typename ScalarType,
typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
bool>::type = false>
struct DynamicBuffer
{
template <typename T>
struct PointerWrapper
{
T* p_;
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_[i]; }
};
ScalarType* p_scalar_;
__host__ __device__ constexpr DynamicBuffer(ScalarType* p_scalar) : p_scalar_{p_scalar} {}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
ScalarType>::value,
bool>::type = false>
__host__ __device__ constexpr auto AsType() const
{
return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
ScalarType>::value,
bool>::type = false>
__host__ __device__ constexpr auto AsType()
{
return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
}
};
template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
using scalar_t = typename scalar_type<T>::type;
constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
return DynamicBuffer<scalar_t>{p};
}
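// Usage sketch (my assumption, not code from this commit):
//   float* p = ...;                          // raw global pointer
//   auto buf = make_dynamic_buffer(p);       // DynamicBuffer<float>
//   float x  = buf.AsType<float>()[0];       // read through the element view
//   buf.AsType<float>()(1) = x;              // write through the same view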
} // namespace ck
......
@@ -28,11 +28,11 @@
#endif
// launch bounds
-#define CK_USE_LAUNCH_BOUNDS 0
+#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
-#define CK_MIN_BLOCK_PER_CU 1
+#define CK_MIN_BLOCK_PER_CU 2
#endif
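// Sketch (my assumption): these macros are the usual arguments to the HIP/CUDA
// __launch_bounds__ attribute on kernel entry points, e.g.
//   __global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) kernel(...);
// so raising CK_MIN_BLOCK_PER_CU from 1 to 2 asks the compiler to limit register use
// enough to keep at least two blocks resident per compute unit.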
// buffer resource
......
@@ -728,19 +728,18 @@ int main(int argc, char* argv[])
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size,
acc_data_t,
-out_data_t>
-
-(in_nchw_desc,
+out_data_t>(
+in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
......