Commit 10947a54 authored by Astha Rai
Browse files

Added variables (num_threads_m, num_threads_n) to distribute threads across both the M and N dimensions.

parent 0bff049a
......@@ -115,7 +115,7 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
desc,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<1>{}, Sequence<0>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize);
}
......@@ -150,6 +150,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
//num_threads_m,
//num_threads_n,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
......@@ -199,6 +201,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
lengths, outStridesArray[I.value], gridSize_, blockSize_);
},
Number<NumOutput>{});
//num_threads_m = 1;
//num_threads_n = gridSize_ * blockSize_;
}
InDataTypePointerTuple in_dev_buffers_;
......@@ -264,14 +269,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
// std::cout << "scalarPerVector: " << scalarPerVector << std::endl;
// std::cout << "stride back: " << strides.back() << std::endl;
// std::cout << "len back: " << lengths.back() << std::endl;
// std::cout << "NumDim-1: " << NumDim - 1 << std::endl;
// std::cout << "stride[nd-1]: " << strides[NumDim - 1] << std::endl;
// std::cout << "NumDim_m-1: " << NumDim_m - 1 << std::endl;
// std::cout << std::endl;
// std::cout << "ISPVV Check 1 starting" << std::endl;
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
......@@ -293,18 +290,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
return false;
};
/**auto IsOutScalarPerVectorValid =
[&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector) {
std::cout << "ISPVV Check 1 starting" << std::endl;
if(strides.back() != 1 && lengths.back() % scalarPerVector == strides[NumDim - 1])
{
std::cout << "Check 1 passed " << std::endl;
return true;
}
std::cout << "Check 1 failed" << std::endl;
};**/
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
......
......@@ -104,18 +104,22 @@ struct GridwiseElementwise_2D
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const index_t totalNumThread = blockSize * blockPerGrid;
const index_t num_threads_m = 4;
const index_t num_threads_n = totalNumThread/4;
//static_assert(num_threads_m * num_threads_n == totalNumThread, "error: threads per dimension not equal to total threads");
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
const index_t loop_step_m = MPerThread;
const index_t loop_step_n = totalNumThread * NPerThread;
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
// index_t tid_m = thread_1d_id / (N / NPerThread);
// index_t tid_n = thread_1d_id % (N / NPerThread);
index_t tid_m = thread_1d_id / num_threads_n;
index_t tid_n = thread_1d_id % num_threads_n;
const auto thread_global_offset = make_multi_index(0, thread_1d_id * NPerThread);
const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment