Commit 71467cfc authored by ltqin's avatar ltqin
Browse files

add kbatch to CalculateBottomIndex

parent 3279fca1
......@@ -310,14 +310,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
const auto N0 = N / N1;
#if 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KBatch, M0, N0))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
#elif 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
make_tuple(Sequence<1, 0>{}),
const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KBatch, N0, M0))),
make_tuple(Sequence<0, 2, 1>{}),
make_tuple(Sequence<0>{}));
#endif
......@@ -346,31 +346,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
const auto b_grid_size = CalculateBatchGridSize(M, N);
const auto k_batch_id = get_block_1d_id() / b_grid_size;
const auto block_id_in_batch = get_block_1d_id() % b_grid_size;
if(get_block_1d_id() == 200000)
printf("grid size: %d, k0: %d, blockid: %d, threadid %d, Batch: %d block_id: %d \n",
b_grid_size,
K0,
get_block_1d_id(),
get_thread_local_1d_id(),
k_batch_id,
block_id_in_batch);
// divide block work by [M, N]
// divide block work by [B, M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(block_id_in_batch));
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t k_batch_id = block_work_idx[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
......
......@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
constexpr index_t KBatch = 64;
constexpr index_t KBatch = 96;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment