Commit 08848bb6 authored by Astha Rai's avatar Astha Rai
Browse files

fixed 2d thread indexing

parent 5f01c06f
......@@ -107,14 +107,7 @@ struct GridwiseElementwise_2D
const index_t block_1d = get_block_1d_id();
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I1].GetLength(I1);
const auto M0 = math::integer_divide_ceil(M, MPerBlock); //define MPerBlock and NPerBlock
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
block_1d = block_1d % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d % N0;
index_t idx_M0 = block_1d / N0;
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
const index_t loop_step_m = blockPerGrid_m * blockSize * MPerThread;
const index_t loop_step_n = blockPerGrid_n * blockSize * NPerThread;
......@@ -122,10 +115,33 @@ struct GridwiseElementwise_2D
// const auto thread_global_id_2d =
// thread_buffer_desc_mn.CalculateBottomIndex(make_multi_index(block_1d));
const auto blockId_m = thread_global_id_2d[I0];
const auto blockId_n = thread_global_id_2d[I1];
const auto thread_global_offset =
make_multi_index(thread_global_id_2d * MPerThread, thread_global_id_2d * NPerThread);
auto thread_1d_id = get_thread_local_1d_id();
index_t M01_ = 8;
const auto M0 = math::integer_divide_ceil(M, MPerThread);
const auto N0 = math::integer_divide_ceil(N, NPerThread);
thread_1d_id = thread_1d_id % (M0 * N0); // swallow batch index
index_t idx_N0 = thread_1d_id % N0;
index_t idx_M0 = thread_1d_id / N0;
// const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
// index_t idx_M00 = idx_M0 / M01_;
// index_t idx_M01 = idx_M0 % M01_;
// index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
// const auto thread_global_id_2d =make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
// idx_N0_M01_local /
//M01_adapt);
index_t tid_m = get_thread_global_1d_id();
index_t tid_n = blockDim.y * blockIdx.y + threadIdx.y;
const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);
// make_multi_index(thread_global_id_2d[I0] * MPerThread, thread_global_id_2d[I1] *
// NPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment