"docs/en/get_started/faq.md" did not exist on "6738247142db175faee67ec1f3a0efbb5016d1a9"
Commit facdb52e authored by Astha Rai
Browse files

changed blockID to 2D

parent 76b44c60
...@@ -105,14 +105,23 @@ struct GridwiseElementwise_2D ...@@ -105,14 +105,23 @@ struct GridwiseElementwise_2D
const index_t blockPerGrid_m = get_grid_size(); const index_t blockPerGrid_m = get_grid_size();
const index_t blockPerGrid_n = gridDim.y; const index_t blockPerGrid_n = gridDim.y;
const index_t block_1d = get_block_1d_id(); const index_t block_1d = get_block_1d_id();
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0); const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I1].GetLength(I1); const auto N = in_grid_2d_desc_tuple[I1].GetLength(I1);
const auto M0 = math::integer_divide_ceil(M, MPerBlock); //define MPerBlock and NPerBlock
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
block_1d = block_1d % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d % N0;
index_t idx_M0 = block_1d / N0;
const index_t loop_step_m = blockPerGrid_m * blockSize * MPerThread; const index_t loop_step_m = blockPerGrid_m * blockSize * MPerThread;
const index_t loop_step_n = blockPerGrid_n * blockSize * NPerThread; const index_t loop_step_n = blockPerGrid_n * blockSize * NPerThread;
const auto loop_step_index = make_multi_index(loop_step_m, loop_step_n); const auto loop_step_index = make_multi_index(loop_step_m, loop_step_n);
const auto thread_global_id_2d = // const auto thread_global_id_2d =
thread_buffer_desc_mn.CalculateBottomIndex(make_multi_index(block_1d)); // thread_buffer_desc_mn.CalculateBottomIndex(make_multi_index(block_1d));
const auto blockId_m = thread_global_id_2d[I0]; const auto blockId_m = thread_global_id_2d[I0];
const auto blockId_n = thread_global_id_2d[I1]; const auto blockId_n = thread_global_id_2d[I1];
const auto thread_global_offset = const auto thread_global_offset =
...@@ -182,13 +191,17 @@ struct GridwiseElementwise_2D ...@@ -182,13 +191,17 @@ struct GridwiseElementwise_2D
// get reference to in data // get reference to in data
const auto in_data_refs = generate_tie( const auto in_data_refs = generate_tie(
// return type should be lvalue // return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(offset); }, [&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumInput>{}); Number<NumInput>{});
// get reference to dst data auto out_data_refs = generate_tie(
auto out_data_refs = generate_tie( auto out_data_refs = generate_tie(
// return type should be lvalue // return type should be lvalue
[&](auto I) -> auto& { return out_thread_buf_tuple(offset); }, [&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumOutput>{}); Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs); unpack2(elementwise_op, out_data_refs, in_data_refs);
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment