Commit 4ef5865b authored by root

clean code: derive the H/W thread cluster inside BlockwiseGemm_km_kn_m0m1n0n1_v3 instead of passing it as template parameters, take per-thread ids from GetBeginOfThreadMatrixC in the gridwise kernel, and remove dead debug code

parent caa91db0
@@ -19,8 +19,6 @@ template <index_t BlockSize,
           index_t HPerThread,
           index_t WPerThread,
           index_t CYXPerThreadLoop,
-          index_t HThreadCluster,
-          index_t WThreadCluster,
           index_t ThreadGemmADataPerRead_K,
           index_t ThreadGemmBDataPerRead_W>
 struct BlockwiseGemm_km_kn_m0m1n0n1_v3
@@ -46,11 +44,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
 
-        // constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
-        //     MLevel1ThreadCluster * NLevel1ThreadCluster;
-
-        static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
-
         static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
                       "wrong! K dimension not consistent\n");
@@ -59,10 +52,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr index_t H = BlockMatrixB{}.GetLength(I2);
         constexpr index_t W = BlockMatrixB{}.GetLength(I3);
 
-        static_assert(
-            K % (KPerThread) == 0 &&
-                (N * H * W) % (HPerThread * WPerThread * HThreadCluster * WThreadCluster) == 0,
-            "wrong! Cannot evenly divide work among\n");
+        static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
+                      "wrong! Cannot evenly divide work among threads\n");
+
+        constexpr auto HThreadCluster = H / HPerThread;
+        constexpr auto WThreadCluster = W / WPerThread;
+
+        static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
 
         auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
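The reworked asserts encode one invariant: once K, H, and W divide evenly into per-thread tiles, the thread cluster implied by H/HPerThread and W/WPerThread must exactly fill the block. Below is a minimal compile-time sketch of that invariant with made-up example sizes; the names mirror the diff, but none of this is the library's code:

```cpp
// Sketch only: mirrors the divisibility checks added in this hunk.
template <int BlockSize, int K, int H, int W,
          int KPerThread, int HPerThread, int WPerThread>
constexpr bool check_work_division()
{
    static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
                  "wrong! Cannot evenly divide work among threads");

    // the cluster shape is now derived rather than passed in as the
    // HThreadCluster/WThreadCluster template parameters this commit removes
    constexpr int HThreadCluster = H / HPerThread;
    constexpr int WThreadCluster = W / WPerThread;

    // one thread per (HPerThread x WPerThread) tile must exactly fill the block
    static_assert(BlockSize == HThreadCluster * WThreadCluster,
                  "wrong! wrong blocksize");
    return true;
}

// example: a 256-thread block covering a 16x16 cluster of 4x4 tiles
static_assert(check_work_division<256, 8, 64, 64, 8, 4, 4>(), "");
```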
@@ -117,23 +117,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
         const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
 
-        constexpr auto h_num_threads = HPerBlock / HPerThread;
-        constexpr auto w_num_threads = WPerBlock / WPerThread;
-
         static_assert(KPerBlock == KPerThread, "");
 
-        const auto k_thread_id = 0;
-        const auto h_thread_id = get_thread_local_1d_id() / w_num_threads;
-        const auto w_thread_id = get_thread_local_1d_id() % w_num_threads;
-
-        const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
-        const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
-        const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
-
-        const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread;
-        const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
-        const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
-
         // lds max alignment
         constexpr auto max_lds_align =
             math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerThread>{});
@@ -149,6 +134,39 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<CYXPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
 
+        // c_thread_mtx definition: this is a mess
+        // TODO: more elegant way of defining c_thread_mtx
+        constexpr auto c_k_n_h_w_thread_desc =
+            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+                Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+
+        const auto blockwise_gemm =
+            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
+                                            decltype(a_cyx_k_block_desc),
+                                            decltype(b_cyx_n_h_w_block_desc),
+                                            decltype(c_k_n_h_w_thread_desc),
+                                            KPerThread,   // KPerThreadSubC
+                                            HPerThread,   // HPerThreadSubC
+                                            WPerThread,   // WPerThreadSubC
+                                            CYXPerThread, // CYXPerThreadLoop
+                                            1,            // ThreadGemmADataPerRead_K
+                                            1             // ThreadGemmBDataPerRead_W
+                                            >{};
+
+        auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        const auto k_thread_id = c_thread_mtx_index.k;
+        const auto h_thread_id = c_thread_mtx_index.h;
+        const auto w_thread_id = c_thread_mtx_index.w;
+
+        const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
+        const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
+        const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
+
+        const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread;
+        const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
+        const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
+
         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
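With h_num_threads/w_num_threads gone, the thread decomposition now comes from blockwise_gemm.GetBeginOfThreadMatrixC. The diff does not show its body; the following is a hedged reconstruction inferred from the code it replaces (k_thread_id was hard-coded to 0, and h/w came from div/mod by the W-direction cluster width). MatrixIndex and the free-function form are stand-ins, not confirmed by this commit:

```cpp
// Assumed shape of the returned index; the real struct exposes .k/.h/.w.
struct MatrixIndex
{
    int k, h, w;
};

template <int H, int W, int HPerThread, int WPerThread>
constexpr MatrixIndex get_begin_of_thread_matrix_c(int thread_id)
{
    constexpr int WThreadCluster = W / WPerThread; // threads along W

    return MatrixIndex{0,                          // KPerBlock == KPerThread, so k is always 0
                       thread_id / WThreadCluster, // row in the H/W thread cluster
                       thread_id % WThreadCluster};
}
```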
@@ -182,7 +200,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<CYXPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
 
-        using ThreadwiseTensorSliceTransferB = ThreadwiseDynamicTensorSliceTransfer_v2<
+        auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
             Float,
             Float,
             decltype(b_cyx_n_h_w_global_desc),
@@ -195,34 +213,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
-            true>;
-
-        ThreadwiseTensorSliceTransferB b_threadwise_transfer(
+            true>(
             b_cyx_n_h_w_global_desc,
-            make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
-
-        // c_thread_mtx definition: this is a mess
-        // TODO:: more elegent way of defining c_thread_mtx
-        constexpr auto c_k_n_h_w_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
-#if 1
-        const auto blockwise_gemm =
-            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
-                                            decltype(a_cyx_k_block_desc),
-                                            decltype(b_cyx_n_h_w_block_desc),
-                                            decltype(c_k_n_h_w_thread_desc),
-                                            KPerThread,    // KPerThreadSubC
-                                            HPerThread,    // HPerThreadSubC
-                                            WPerThread,    // WPerThreadSubC
-                                            CYXPerThread,  // CYXPerThreadLoop
-                                            h_num_threads, // HThreadCluster
-                                            w_num_threads, // WThreadCluster
-                                            1,             // ThreadGemmADataPerRead_K
-                                            1              // ThreadGemmBDataPerRead_W
-                                            >{};
-#endif
+            make_multi_index(
+                k_thread_data_on_global, 0, h_thread_data_on_global, w_thread_data_on_global));
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
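Besides constructing the transfer object in one expression, the only functional change in the b_threadwise_transfer hunk is the first coordinate of its start index: the old code hard-coded 0 where the new code passes k_thread_data_on_global. A small self-contained sketch of how that origin is composed, with made-up block/thread ids and tile sizes (not values from any real launch):

```cpp
#include <cstdio>

int main()
{
    // per-block and per-thread tile sizes, as in the kernel's template params
    constexpr int KPerBlock = 8, HPerBlock = 64, WPerBlock = 64;
    constexpr int KPerThread = 8, HPerThread = 4, WPerThread = 4;

    // example block/thread coordinates (illustrative only)
    const int k_block_work_id = 1, h_block_work_id = 2, w_block_work_id = 3;
    const int k_thread_id = 0, h_thread_id = 5, w_thread_id = 7;

    // block offset plus thread offset, exactly as in the added gridwise code
    const int k = k_block_work_id * KPerBlock + k_thread_id * KPerThread;
    const int h = h_block_work_id * HPerBlock + h_thread_id * HPerThread;
    const int w = w_block_work_id * WPerBlock + w_thread_id * WPerThread;

    // this (k, 0, h, w) tuple is what the new make_multi_index call seeds
    // b_threadwise_transfer with; the old code passed 0 for k
    std::printf("origin = (%d, 0, %d, %d)\n", k, h, w);
}
```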
@@ -267,14 +261,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                       b_cyx_n_h_w_global_iterator_hacks);
 
             a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);
-
-#if 0
-            __syncthreads();
-
-            p_c_thread[0] += p_b_thread_double[0] + p_b_thread_double[1] + p_b_thread_double[2];
-            p_c_thread[0] += p_b_thread_double[3] + p_b_thread_double[4] + p_b_thread_double[5];
-            p_c_thread[0] += p_b_thread_double[6] + p_b_thread_double[7] + p_b_thread_double[8];
-#endif
         }

#if 1