"docs/source/en/vscode:/vscode.git/clone" did not exist on "30c977d1f5300bffdf14b92d06c90293e789b587"
Commit 4ef5865b authored by root

clean code

parent caa91db0
@@ -19,8 +19,6 @@ template <index_t BlockSize,
index_t HPerThread,
index_t WPerThread,
index_t CYXPerThreadLoop,
- index_t HThreadCluster,
- index_t WThreadCluster,
index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemm_km_kn_m0m1n0n1_v3
@@ -46,11 +44,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
- // constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
- //     MLevel1ThreadCluster * NLevel1ThreadCluster;
- static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
@@ -59,10 +52,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
- static_assert(
-     K % (KPerThread) == 0 &&
-         (N * H * W) % (HPerThread * WPerThread * HThreadCluster * WThreadCluster) == 0,
-     "wrong! Cannot evenly divide work among\n");
+ static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
+               "wrong! Cannot evenly divide work among threads\n");
+ constexpr auto HThreadCluster = H / HPerThread;
+ constexpr auto WThreadCluster = W / WPerThread;
+ static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
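The new checks replace the two removed template parameters: the cluster shape is now derived from the block-matrix extents and the per-thread tile instead of being passed in. For illustration, a standalone sketch with hypothetical tile sizes (not the kernel's actual configuration):

```cpp
#include <cstddef>

// Assumed block-tile extents and per-thread tile (hypothetical values).
constexpr std::size_t H = 16, W = 8;
constexpr std::size_t HPerThread = 4, WPerThread = 2;

// Cluster extents are derived, not configured.
constexpr std::size_t HThreadCluster = H / HPerThread; // 4
constexpr std::size_t WThreadCluster = W / WPerThread; // 4
constexpr std::size_t BlockSize      = 16;

static_assert(H % HPerThread == 0 && W % WPerThread == 0,
              "wrong! Cannot evenly divide work among threads");
static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize");
```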
@@ -117,23 +117,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
- constexpr auto h_num_threads = HPerBlock / HPerThread;
- constexpr auto w_num_threads = WPerBlock / WPerThread;
- static_assert(KPerBlock == KPerThread, "");
- const auto k_thread_id = 0;
- const auto h_thread_id = get_thread_local_1d_id() / w_num_threads;
- const auto w_thread_id = get_thread_local_1d_id() % w_num_threads;
- const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
- const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
- const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
- const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread;
- const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
- const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
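The deleted lines split the linear thread id row-major into (h, w) cluster coordinates; after this commit that mapping lives inside blockwise_gemm.GetBeginOfThreadMatrixC. In isolation, with hypothetical values:

```cpp
// Row-major decomposition of a linear thread id, as the removed code did it.
constexpr int w_num_threads = 4;                 // WPerBlock / WPerThread
constexpr int tid           = 7;                 // get_thread_local_1d_id()
constexpr int h_thread_id   = tid / w_num_threads; // 1
constexpr int w_thread_id   = tid % w_num_threads; // 3
static_assert(h_thread_id == 1 && w_thread_id == 3, "tid 7 -> (1, 3)");
```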
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, Number<KPerThread>{});
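The lcm-based alignment lets a single LDS buffer satisfy both access granularities at once. A minimal sketch, using std::lcm as a stand-in for the library's math::lcm (an assumption, not confirmed by this diff):

```cpp
#include <numeric>

// An alignment of lcm(4, 6) = 12 elements is the smallest that is valid
// for both a 4-wide and a 6-wide vectorized access.
static_assert(std::lcm(4, 6) == 12,
              "12-element alignment covers both 4-wide and 6-wide accesses");
```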
@@ -149,6 +134,39 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<CYXPerBlock>{}, Number<1>{}, Number<HPerBlock>{}, Number<WPerBlock>{}));
+ // c_thread_mtx definition: this is a mess
+ // TODO: a more elegant way of defining c_thread_mtx
+ constexpr auto c_k_n_h_w_thread_desc =
+     make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+         Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
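Assuming a "packed" descriptor means row-major suffix-product strides with no padding (an assumption about make_dynamic_naive_tensor_descriptor_packed_v2, not something this diff states), the 4-d thread layout works out as:

```cpp
// Hypothetical per-thread tile extents for illustration.
constexpr int KPerThread = 8, HPerThread = 4, WPerThread = 2;

// Packed layout: each stride is the product of the trailing lengths.
constexpr int lengths[4] = {KPerThread, 1, HPerThread, WPerThread};
constexpr int strides[4] = {1 * HPerThread * WPerThread, // 8
                            HPerThread * WPerThread,     // 8
                            WPerThread,                  // 2
                            1};
static_assert(strides[0] == 8 && strides[3] == 1, "packed, row-major layout");
```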
+ const auto blockwise_gemm =
+     BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
+                                     decltype(a_cyx_k_block_desc),
+                                     decltype(b_cyx_n_h_w_block_desc),
+                                     decltype(c_k_n_h_w_thread_desc),
+                                     KPerThread,   // KPerThreadSubC
+                                     HPerThread,   // HPerThreadSubC
+                                     WPerThread,   // WPerThreadSubC
+                                     CYXPerThread, // CYXPerThreadLoop
+                                     1,            // ThreadGemmADataPerRead_K
+                                     1             // ThreadGemmBDataPerRead_W
+                                     >{};
+ auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+ const auto k_thread_id = c_thread_mtx_index.k;
+ const auto h_thread_id = c_thread_mtx_index.h;
+ const auto w_thread_id = c_thread_mtx_index.w;
+ const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
+ const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
+ const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
+ const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread;
+ const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id * HPerThread;
+ const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id * WPerThread;
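The offset arithmetic composes a block origin with an in-block thread offset. Checked on hypothetical values:

```cpp
// A thread's global origin is its block's origin plus its in-block offset.
constexpr int HPerBlock = 32, HPerThread = 4;
constexpr int h_block_work_id = 2, h_thread_id = 3;
constexpr int h_block_data_on_global  = h_block_work_id * HPerBlock;         // 64
constexpr int h_thread_data_on_global = h_block_data_on_global
                                      + h_thread_id * HPerThread;            // 76
static_assert(h_thread_data_on_global == 76, "block origin + thread offset");
```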
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
@@ -182,7 +200,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<CYXPerBlock>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
- using ThreadwiseTensorSliceTransferB = ThreadwiseDynamicTensorSliceTransfer_v2<
+ auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
Float,
Float,
decltype(b_cyx_n_h_w_global_desc),
@@ -195,34 +213,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
-     true>;
- ThreadwiseTensorSliceTransferB b_threadwise_transfer(
+     true>(
b_cyx_n_h_w_global_desc,
make_multi_index(0, 0, h_thread_data_on_global, w_thread_data_on_global));
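The change collapses the using-alias plus named constructor call into a single auto construction. The same pattern in miniature, with a hypothetical stand-in type (the real transfer class takes many more template arguments):

```cpp
// Hypothetical stand-in for the transfer class, for illustration only.
template <typename Desc>
struct TransferSketch
{
    TransferSketch(const Desc&, int /*origin*/) {}
};
struct Desc {};

// Before: name the specialization, then construct a variable of that type.
//   using TransferB = TransferSketch<Desc>;
//   TransferB b_transfer(Desc{}, 0);
// After: construct directly and let auto carry the type.
auto b_transfer = TransferSketch<Desc>(Desc{}, 0);
```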
- // c_thread_mtx definition: this is a mess
- // TODO:: more elegent way of defining c_thread_mtx
- constexpr auto c_k_n_h_w_thread_desc =
-     make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-         Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
- #if 1
- const auto blockwise_gemm =
-     BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
-                                     decltype(a_cyx_k_block_desc),
-                                     decltype(b_cyx_n_h_w_block_desc),
-                                     decltype(c_k_n_h_w_thread_desc),
-                                     KPerThread,    // KPerThreadSubC
-                                     HPerThread,    // HPerThreadSubC
-                                     WPerThread,    // WPerThreadSubC
-                                     CYXPerThread,  // CYXPerThreadLoop
-                                     h_num_threads, // HThreadCluster
-                                     w_num_threads, // WThreadCluster
-                                     1,             // ThreadGemmADataPerRead_K
-                                     1              // ThreadGemmBDataPerRead_W
-                                     >{};
- #endif
+     make_multi_index(
+         k_thread_data_on_global, 0, h_thread_data_on_global, w_thread_data_on_global));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
@@ -267,14 +261,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
b_cyx_n_h_w_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);
- #if 0
- __syncthreads();
- p_c_thread[0] += p_b_thread_double[0] + p_b_thread_double[1] + p_b_thread_double[2];
- p_c_thread[0] += p_b_thread_double[3] + p_b_thread_double[4] + p_b_thread_double[5];
- p_c_thread[0] += p_b_thread_double[6] + p_b_thread_double[7] + p_b_thread_double[8];
- #endif
}
#if 1