Commit 71467cfc authored by ltqin

add kbatch to CalculateBottomIndex

parent 3279fca1
@@ -310,15 +310,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
     const auto N0 = N / N1;

 #if 1
-    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
-                                         make_tuple(Sequence<0, 1>{}),
-                                         make_tuple(Sequence<0>{}));
+    const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
+        make_tuple(make_merge_transform(make_tuple(KBatch, M0, N0))),
+        make_tuple(Sequence<0, 1, 2>{}),
+        make_tuple(Sequence<0>{}));
 #elif 1
-    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-        make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
-                                         make_tuple(Sequence<1, 0>{}),
-                                         make_tuple(Sequence<0>{}));
+    const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
+        make_tuple(make_merge_transform(make_tuple(KBatch, N0, M0))),
+        make_tuple(Sequence<0, 2, 1>{}),
+        make_tuple(Sequence<0>{}));
 #endif

     return c_blockid_to_m0_n0_block_cluster_adaptor;
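The adaptor now merges a third, slowest-varying KBatch dimension into the flat block id, so a single CalculateBottomIndex call recovers (k_batch_id, m0, n0) directly. Below is a minimal standalone sketch of that decomposition, assuming the merge transform linearizes left to right (leftmost dimension slowest, which is what the old and new index maps suggest); the function name and sizes are invented for illustration and are not library code:

// Standalone illustration (not code from this commit): how a flat block id
// decomposes under a row-major merge of [KBatch, M0, N0].
#include <array>
#include <cstdio>

std::array<int, 3> bottom_index(int block_id, int M0, int N0)
{
    // block_id = (k_batch_id * M0 + m0) * N0 + n0
    const int n0         = block_id % N0;
    const int m0         = (block_id / N0) % M0;
    const int k_batch_id = block_id / (N0 * M0);
    return {k_batch_id, m0, n0};
}

int main()
{
    // hypothetical tiling: 8 x 8 block tiles, 4 k-batches -> 256 blocks
    const auto idx = bottom_index(/*block_id=*/200, /*M0=*/8, /*N0=*/8);
    std::printf("k_batch_id=%d, m0=%d, n0=%d\n", idx[0], idx[1], idx[2]);
    return 0; // prints k_batch_id=3, m0=1, n0=0
}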
@@ -345,32 +345,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
     auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

     const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
-    const auto M  = a_b_k0_m_k1_grid_desc.GetLength(I2);
-    const auto N  = b_b_k0_n_k1_grid_desc.GetLength(I2);
-    const auto b_grid_size = CalculateBatchGridSize(M, N);
-    const auto k_batch_id  = get_block_1d_id() / b_grid_size;
-    const auto block_id_in_batch = get_block_1d_id() % b_grid_size;
-    if(get_block_1d_id() == 200000)
-        printf("grid size: %d, k0: %d, blockid: %d, threadid %d, Batch: %d block_id: %d \n",
-               b_grid_size,
-               K0,
-               get_block_1d_id(),
-               get_thread_local_1d_id(),
-               k_batch_id,
-               block_id_in_batch);

-    // divide block work by [M, N]
+    // divide block work by [B, M, N]
     const auto block_work_idx =
-        c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(block_id_in_batch));
+        c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+    const index_t k_batch_id = block_work_idx[I0];

     // HACK: this force m/n_block_data_idx_on_grid into SGPR
     const index_t m_block_data_idx_on_grid =
-        __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+        __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

     const index_t n_block_data_idx_on_grid =
-        __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+        __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

     // lds max alignment
     constexpr auto max_lds_align = K1;
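The __builtin_amdgcn_readfirstlane calls kept in the hunk above are what the HACK comment refers to: the builtin broadcasts lane 0's value across the wavefront, which lets the compiler treat the tile offset as wavefront-uniform and keep it in a scalar register (SGPR) rather than replicating it into per-lane vector registers. A hedged sketch of the pattern in a trivial HIP kernel follows; the kernel and buffer names are invented for illustration, and only the builtin itself comes from the diff:

#include <hip/hip_runtime.h>

// Illustrative kernel, not from this commit: blockIdx.x is uniform across a
// wavefront, but the compiler cannot always prove it after arithmetic.
// readfirstlane makes the uniformity explicit so the base offset stays in an
// SGPR instead of a per-lane VGPR.
__global__ void copy_tile(const float* src, float* dst, int m_per_block)
{
    const int m_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(static_cast<int>(blockIdx.x) * m_per_block);

    // every lane indexes off the same scalar base
    const int i = m_block_data_idx_on_grid + static_cast<int>(threadIdx.x);
    dst[i]      = src[i];
}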
@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;

-    constexpr index_t KBatch = 64;
+    constexpr index_t KBatch = 96;
 #elif 1
     // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
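Raising KBatch from 64 to 96 splits the GEMM K dimension across more k-batches, and with the adaptor change above the launch grid must grow by the same factor, since each k-batch owns a full [M0, N0] grid of block tiles (in the removed code this showed up as k_batch_id = get_block_1d_id() / b_grid_size). A rough sketch of the implied grid-size arithmetic, with made-up problem sizes; only KBatch = 96 comes from the commit itself:

#include <cstdio>

int main()
{
    // hypothetical GEMM extents and block tile sizes for illustration
    const int M = 1024, N = 1024;
    const int MPerBlock = 128, NPerBlock = 128;
    const int KBatch = 96; // value set in this commit

    const int M0 = M / MPerBlock; // 8 block tiles along M
    const int N0 = N / NPerBlock; // 8 block tiles along N

    // one [M0, N0] tile grid per k-batch
    const int grid_size = KBatch * M0 * N0;
    std::printf("grid_size = %d\n", grid_size); // 96 * 8 * 8 = 6144
    return 0;
}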