Commit c23de07d authored by root

add kthread

parent b9b089f3
@@ -55,14 +55,17 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
                       "wrong! Cannot evenly divide work among\n");
 
+        constexpr auto KThreadCluster = K / KPerThread;
         constexpr auto HThreadCluster = H / HPerThread;
         constexpr auto WThreadCluster = W / WPerThread;
 
-        static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
+        static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
+                      "wrong! wrong blocksize\n");
 
         auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
 
-        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k));
+        mMyThreadOffsetA =
+            BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread));
     }
 
     __device__ static constexpr auto GetThreadMatrixCLengths()
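Note on the constructor hunk above: the block's work is now tiled across threads in K as well as in H and W, so BlockSize must equal KThreadCluster * HThreadCluster * WThreadCluster, and mMyThreadOffsetA has to step in units of KPerThread. A minimal host-side sketch of that bookkeeping in plain C++ (tile sizes are illustrative, taken from the driver hunk further down; WPerBlock = 8 is an assumption since it is not visible in this diff):

    // Standalone sketch of the new thread-cluster bookkeeping (not the real kernel code).
    // K/H/W here are the block tile sizes; W = 8 is an assumption.
    #include <cassert>
    #include <cstdio>

    int main()
    {
        constexpr int K = 16, H = 8, W = 8;                           // per-block tile
        constexpr int KPerThread = 8, HPerThread = 1, WPerThread = 1; // per-thread tile

        constexpr int KThreadCluster = K / KPerThread; // 2
        constexpr int HThreadCluster = H / HPerThread; // 8
        constexpr int WThreadCluster = W / WPerThread; // 8

        // New invariant: the block is tiled in K as well as in H and W.
        constexpr int BlockSize = KThreadCluster * HThreadCluster * WThreadCluster; // 128

        // A thread at k-cluster index k owns rows [k * KPerThread, (k + 1) * KPerThread),
        // which is why mMyThreadOffsetA now scales c_thread_mtx_index.k by KPerThread.
        for(int k = 0; k < KThreadCluster; ++k)
            printf("k cluster %d -> first K row %d\n", k, k * KPerThread);

        assert(BlockSize == 128);
        return 0;
    }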
@@ -75,13 +78,15 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
         constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
 
         constexpr auto num_w_threads = W / WPerThread;
         constexpr auto num_h_threads = H / HPerThread;
+        constexpr auto num_hw_threads = num_w_threads * num_h_threads;
 
-        index_t k_thread_id = thread_id / (num_w_threads * num_h_threads);
+        index_t k_thread_id  = thread_id / num_hw_threads;
+        index_t hw_thread_id = thread_id % num_hw_threads;
 
-        index_t h_thread_id = thread_id / num_w_threads;
-        index_t w_thread_id = thread_id % num_w_threads;
+        index_t h_thread_id = hw_thread_id / num_w_threads;
+        index_t w_thread_id = hw_thread_id % num_w_threads;
 
         return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
     }
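Note on GetBeginOfThreadMatrixC: the linear thread id is now split three ways; the old code derived h/w directly from thread_id, which is only correct while there is a single K cluster. A standalone sketch of the new decomposition (plain C++, cluster sizes illustrative):

    // Sketch of the new 3-way thread-id split (2 x 8 x 8 = 128 threads as an example).
    #include <cstdio>

    struct MatrixIndex
    {
        int k, h, w;
    };

    MatrixIndex decompose(int thread_id, int num_h_threads, int num_w_threads)
    {
        const int num_hw_threads = num_h_threads * num_w_threads;
        const int k_thread_id    = thread_id / num_hw_threads;
        const int hw_thread_id   = thread_id % num_hw_threads;
        return {k_thread_id, hw_thread_id / num_w_threads, hw_thread_id % num_w_threads};
    }

    int main()
    {
        const int tids[] = {0, 63, 64, 127};
        for(int tid : tids)
        {
            const MatrixIndex idx = decompose(tid, 8, 8);
            printf("tid %3d -> k %d h %d w %d\n", tid, idx.k, idx.h, idx.w);
        }
        return 0;
    }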
@@ -127,8 +132,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr auto CYXPerBlock = a_block_mtx.GetLength(I0);
 
-        static_assert(CYXPerBlock == CYXPerThreadLoop, "");
-
         // thread A, B for GEMM
         constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
             make_tuple(Number<CYXPerThreadLoop>{}, Number<KPerThread>{}));
 
@@ -153,23 +156,21 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
 
         // loop over k
         for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
         {
-#if 1
             a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
                                   mMyThreadOffsetA,
-                              p_a_thread +
-                                  b_thread_mtx.CalculateOffset(make_tuple(cyx_begin, 0, 0, 0)));
-#else
-            for(index_t i = 0; i < a_thread_mtx.GetElementSpaceSize(); i++)
-                p_a_thread[i] = 1;
-#endif
-            threadwise_gemm.Run(p_a_thread, p_b_thread + cyx_begin, p_c_thread);
+                              p_a_thread);
+
+            threadwise_gemm.Run(p_a_thread,
+                                p_b_thread +
+                                    b_thread_mtx.CalculateOffset(make_tuple(cyx_begin, 0, 0, 0)),
+                                p_c_thread);
         }
     }
 
     template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
     {
-        Run_naive(p_a_block, p_b_block, p_c_thread);
+        Run_naive(p_a_block, p_b_thread, p_c_thread);
     }
 };
...
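Note on the Run_naive loop above: the A chunk is now staged at the start of the thread buffer (the old destination was offset with the B descriptor), and the per-chunk offset into the thread-local B tile is applied at the GEMM call through b_thread_mtx instead of a raw "+ cyx_begin". A plain-C++ analogue of what one thread computes per iteration (flat layouts and sizes are illustrative, not the real descriptors):

    // One thread's loop: stage an A chunk, then accumulate C[k][hw] += A[cyx][k] * B[cyx][hw]
    // with B offset per cyx chunk. Sizes are illustrative.
    #include <cstdio>
    #include <vector>

    int main()
    {
        constexpr int CYXPerBlock = 4, CYXPerThreadLoop = 1;
        constexpr int KPerThread = 2, HWPerThread = 2;

        std::vector<float> a_block(CYXPerBlock * KPerThread, 1.0f);   // A[cyx][k], this thread's K slice
        std::vector<float> b_thread(CYXPerBlock * HWPerThread, 2.0f); // B[cyx][hw], thread-local
        std::vector<float> c_thread(KPerThread * HWPerThread, 0.0f);  // C[k][hw] accumulator
        std::vector<float> a_thread(CYXPerThreadLoop * KPerThread);   // staging buffer

        for(int cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
        {
            // a_thread_copy.Run analogue: the chunk always lands at the start of a_thread
            for(int i = 0; i < CYXPerThreadLoop * KPerThread; ++i)
                a_thread[i] = a_block[cyx_begin * KPerThread + i];

            // threadwise_gemm.Run analogue: B is offset for this cyx chunk at the call site
            const float* b = b_thread.data() + cyx_begin * HWPerThread;
            for(int cyx = 0; cyx < CYXPerThreadLoop; ++cyx)
                for(int k = 0; k < KPerThread; ++k)
                    for(int hw = 0; hw < HWPerThread; ++hw)
                        c_thread[k * HWPerThread + hw] +=
                            a_thread[cyx * KPerThread + k] * b[cyx * HWPerThread + hw];
        }

        printf("c_thread[0] = %.1f (expected 8.0)\n", c_thread[0]);
        return 0;
    }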
@@ -53,7 +53,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr auto max_lds_align =
-            math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerThread>{});
+            math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerBlock>{});
 
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
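Note on the hunk above: the LDS alignment is now derived from KPerBlock rather than KPerThread. A standalone sketch of the arithmetic, with std::lcm standing in for math::lcm and an explicit round-up standing in for the aligned-descriptor space computation (values illustrative):

    // Sketch of the LDS alignment arithmetic (C++17 for constexpr std::lcm).
    #include <cstdio>
    #include <numeric>

    int main()
    {
        constexpr int ABlockTransferDstScalarPerVector_M = 4; // illustrative
        constexpr int KPerBlock                          = 16;

        // was lcm(vector size, KPerThread); now aligned to the block-level K tile
        constexpr int max_lds_align = std::lcm(ABlockTransferDstScalarPerVector_M, KPerBlock);

        // round an A-block element count up to a multiple of max_lds_align,
        // mirroring what the aligned LDS descriptor does for the space size
        constexpr int a_block_elems = 16 * 70; // illustrative
        constexpr int a_block_space =
            (a_block_elems + max_lds_align - 1) / max_lds_align * max_lds_align;

        printf("max_lds_align = %d, a_block_space = %d\n", max_lds_align, a_block_space);
        return 0;
    }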
@@ -92,7 +92,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // divide block work by [M, N]
 #if 1
-        const auto m_block_work_num = K / Number<KPerBlock>{};
+        const auto k_block_work_num = K / Number<KPerBlock>{};
         const auto h_block_work_num = H / Number<HPerBlock>{};
         const auto w_block_work_num = W / Number<WPerBlock>{};
         const auto hw_block_work_num = h_block_work_num * w_block_work_num;
 
@@ -102,7 +102,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
 #else
         // Hack: this force result into SGPR
-        const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
+        const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
         const index_t h_block_work_num = __builtin_amdgcn_readfirstlane(H / HPerBlock);
         const index_t w_block_work_num = __builtin_amdgcn_readfirstlane(W / WPerBlock);
         const index_t hw_block_work_num = h_block_work_num * w_block_work_num;
 
@@ -115,11 +115,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
         const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
 
-        static_assert(KPerBlock == KPerThread, "");
-
         // lds max alignment
         constexpr auto max_lds_align =
-            math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerThread>{});
+            math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerBlock>{});
 
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
@@ -161,7 +159,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
         const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
 
-        const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread;
         const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
         const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
 
@@ -211,10 +208,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                                AddressSpace::Vgpr,
                                                InMemoryDataOperation::Set,
                                                1,
-                                               true>(
-                b_cyx_n_h_w_global_desc,
-                make_multi_index(
-                    k_thread_data_on_global, 0, h_thread_data_on_global, w_thread_data_on_global));
+                                               true>(b_cyx_n_h_w_global_desc,
+                                                     make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
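Note on the hunk above: the origin of the threadwise copy into b_cyx_n_h_w_global_desc no longer carries a K coordinate; only the h/w thread position enters it. A small sketch of the surviving origin arithmetic (plain C++, block/thread ids made up):

    // Per-thread origin into the (cyx, n, h, w) global B tensor after the change.
    #include <cstdio>

    int main()
    {
        constexpr int HPerBlock = 8, WPerBlock = 8, HPerThread = 1, WPerThread = 1;

        const int h_block_work_id = 3, w_block_work_id = 5; // from the block-work split
        const int h_thread_id = 2, w_thread_id = 7;         // from GetBeginOfThreadMatrixC

        const int h_thread_data_on_global = h_block_work_id * HPerBlock + h_thread_id * HPerThread;
        const int w_thread_data_on_global = w_block_work_id * WPerBlock + w_thread_id * WPerThread;

        // was make_multi_index(k_thread_data_on_global, 0, h, w); now (0, 0, h, w)
        const int origin[4] = {0, 0, h_thread_data_on_global, w_thread_data_on_global};
        printf("b thread origin = (%d, %d, %d, %d)\n", origin[0], origin[1], origin[2], origin[3]);
        return 0;
    }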
@@ -380,15 +375,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
 #if 1
         // output: register to global memory
         {
-            // define input tensor descriptor for threadwise copy
-            // thread input tensor, src of threadwise copy
-            constexpr auto c_k_n_h_w_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                    Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
             // hack to control index calculation when iterating over c_k_n_h_w_global tensor
             constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
 
+            const index_t k_thread_data_on_global =
+                k_block_data_on_global + k_thread_id * KPerThread;
+
             ThreadwiseDynamicTensorSliceTransfer_v1r3<
                 AccFloat,
                 Float,
...
@@ -68,7 +68,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
 #endif
 
     // cdata = 16, BlockSize = 64, 16x64x4
-    constexpr index_t BlockSize = 64;
+    constexpr index_t BlockSize = 128;
 
     constexpr index_t KPerBlock = 16;
     constexpr index_t HPerBlock = 8;
@@ -78,7 +78,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t KPerThread = 8;
     constexpr index_t HPerThread = 1;
    constexpr index_t WPerThread = 1;
-    constexpr index_t CYXPerThread = 4;
+    constexpr index_t CYXPerThread = 1;
 
     using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
...
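Cross-check of the new driver parameters against the updated BlockSize assertion in the blockwise GEMM (compile-time only; WPerBlock is not shown in this diff, so 8 is an assumption):

    // Compile-time consistency check of the new tile parameters.
    constexpr int BlockSize  = 128;
    constexpr int KPerBlock  = 16, HPerBlock = 8, WPerBlock = 8 /* assumed */;
    constexpr int KPerThread = 8, HPerThread = 1, WPerThread = 1;

    static_assert(KPerBlock % KPerThread == 0 && HPerBlock % HPerThread == 0 &&
                      WPerBlock % WPerThread == 0,
                  "tiles must divide evenly");
    static_assert(BlockSize == (KPerBlock / KPerThread) * (HPerBlock / HPerThread) *
                                   (WPerBlock / WPerThread),
                  "BlockSize must match the k * h * w thread cluster (2 * 8 * 8 = 128)");

    int main() { return 0; }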