Commit b4abe4e2 authored by Astha Rai's avatar Astha Rai
Browse files

integrated variable for thread distribution into device elementwise and added...

integrated a variable for thread distribution into device elementwise and added it as a parameter for gridwise elementwise
parent b61bb071
...@@ -69,12 +69,14 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -69,12 +69,14 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN> template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn, index_t gridSize, index_t blockSize) static auto PadDescriptor_MN_2d(Desc_MN desc_mn, index_t gridSize, index_t blockSize, index_t num_threads_m, index_t num_threads_n)
{ {
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mn.GetLength(I0); const auto m = desc_mn.GetLength(I0);
const auto n = desc_mn.GetLength(I1); const auto n = desc_mn.GetLength(I1);
const index_t loop_step_m = MPerThread; const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = gridSize * blockSize * NPerThread; const index_t loop_step_n = num_threads_n * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m; const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n; const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
std::cout << NumDim_m << " m: " << m << " loop_step_m: " << loop_step_m std::cout << NumDim_m << " m: " << m << " loop_step_m: " << loop_step_m
...@@ -92,7 +94,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -92,7 +94,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths, static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride, const std::array<index_t, NumDim>& stride,
index_t gridSize, index_t gridSize,
index_t blockSize) index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{ {
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{}); auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{}); auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
...@@ -117,10 +121,10 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -117,10 +121,10 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
make_tuple(mDimIds, nDimIds), make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize); return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
} }
else else
return PadDescriptor_MN_2d(desc, gridSize, blockSize); return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
} }
template <index_t TupleSize> template <index_t TupleSize>
...@@ -130,11 +134,11 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -130,11 +134,11 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
[&](auto) { [&](auto) {
if constexpr(NumDim > 2) if constexpr(NumDim > 2)
{ {
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1); return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
} }
else else
{ {
return MakeDescriptor_MN({1}, {1}, 1, 1); return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
}; };
}, },
Number<TupleSize>{}); Number<TupleSize>{});
...@@ -150,8 +154,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -150,8 +154,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
ElementwiseOperation, ElementwiseOperation,
MPerThread, MPerThread,
NPerThread, NPerThread,
//num_threads_m,
//num_threads_n,
InScalarPerVectorSeq, InScalarPerVectorSeq,
OutScalarPerVectorSeq>; OutScalarPerVectorSeq>;
...@@ -169,7 +171,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -169,7 +171,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
outStridesArray_(outStridesArray), outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op), elementwise_op_(elementwise_op),
blockSize_(256), blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
num_threads_m_((gridSize_*blockSize_)/8),
num_threads_n_(8)
{ {
static_assert(NumDim_m > 0, ""); static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, ""); static_assert(NumDim_n > 0, "");
...@@ -191,19 +195,16 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -191,19 +195,16 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
in_grid_2d_desc_tuple_ = generate_tuple( in_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) { [&](auto I) {
return MakeDescriptor_MN( return MakeDescriptor_MN(
lengths, inStridesArray[I.value], gridSize_, blockSize_); lengths, inStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_);
}, },
Number<NumInput>{}); Number<NumInput>{});
out_grid_2d_desc_tuple_ = generate_tuple( out_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) { [&](auto I) {
return MakeDescriptor_MN( return MakeDescriptor_MN(
lengths, outStridesArray[I.value], gridSize_, blockSize_); lengths, outStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_);
}, },
Number<NumOutput>{}); Number<NumOutput>{});
//num_threads_m = 1;
//num_threads_n = gridSize_ * blockSize_;
} }
InDataTypePointerTuple in_dev_buffers_; InDataTypePointerTuple in_dev_buffers_;
...@@ -218,6 +219,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -218,6 +219,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
ElementwiseOperation elementwise_op_; ElementwiseOperation elementwise_op_;
index_t blockSize_; index_t blockSize_;
index_t gridSize_; index_t gridSize_;
index_t num_threads_m_;
index_t num_threads_n_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -240,7 +243,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -240,7 +243,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
arg.out_grid_2d_desc_tuple_, arg.out_grid_2d_desc_tuple_,
arg.in_dev_buffers_, arg.in_dev_buffers_,
arg.out_dev_buffers_, arg.out_dev_buffers_,
arg.elementwise_op_); arg.elementwise_op_,
arg.num_threads_m_,
arg.num_threads_n_);
return elapsed_time; return elapsed_time;
} }
...@@ -258,9 +263,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -258,9 +263,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
if(pArg == nullptr) if(pArg == nullptr)
return false; return false;
std::cout << "made it here" << std::endl;
std::cout << "lengths back: " << pArg->lengths_.back() << std::endl;
if(pArg->lengths_.back() % MPerThread != 0) if(pArg->lengths_.back() % MPerThread != 0)
return false; return false;
...@@ -273,20 +275,12 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -273,20 +275,12 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
(lengths[vectorDim] % scalarPerVector == 0 || (lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim])) lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{ {
// std::cout << "Check 1 passed" << std::endl;
return true; return true;
} }
// std::cout << "Check 1 failed " << std::endl;
// std::cout << "ISPVV Check 2 starting" << std::endl;
// std::cout << "strides[vectorDim]: " << strides[vectorDim] << std::endl;
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim]) if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{ {
// std::cout << "Check 2 passed " << std::endl;
return true; return true;
} }
// std::cout << "Check 2 failed" << std::endl;
return false; return false;
}; };
...@@ -300,7 +294,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -300,7 +294,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
NumDim_m - 1)) NumDim_m - 1))
valid = false; valid = false;
}); });
std::cout << "valid after loop through input: " << valid << std::endl;
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
std::cout << "running 2: " << I << std::endl; std::cout << "running 2: " << I << std::endl;
...@@ -310,7 +303,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple, ...@@ -310,7 +303,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
NumDim - 1)) NumDim - 1))
valid = false; valid = false;
}); });
std::cout << "valid after loop through output: " << valid << std::endl;
return valid; return valid;
}; };
......
...@@ -20,13 +20,17 @@ __global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tu ...@@ -20,13 +20,17 @@ __global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tu
const OutGrid2dDescTuple out_grid_2d_desc_tuple, const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple, const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple, const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op) const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{ {
GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple, GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
out_grid_2d_desc_tuple, out_grid_2d_desc_tuple,
p_in_global_tuple, p_in_global_tuple,
p_out_global_tuple, p_out_global_tuple,
elementwise_op); elementwise_op,
num_threads_m,
num_threads_n);
} }
template <typename InGrid2dDescTuple, template <typename InGrid2dDescTuple,
...@@ -61,7 +65,9 @@ struct GridwiseElementwise_2D ...@@ -61,7 +65,9 @@ struct GridwiseElementwise_2D
const OutGrid2dDescTuple out_grid_2d_desc_tuple, const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple, const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple, const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op) const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{ {
auto in_thread_buf_tuple = generate_tuple( auto in_thread_buf_tuple = generate_tuple(
[&](auto I) { [&](auto I) {
...@@ -101,14 +107,6 @@ struct GridwiseElementwise_2D ...@@ -101,14 +107,6 @@ struct GridwiseElementwise_2D
}, },
Number<NumOutput>{}); Number<NumOutput>{});
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const index_t totalNumThread = blockSize * blockPerGrid;
const index_t num_threads_m = 4;
const index_t num_threads_n = totalNumThread/4;
//static_assert(num_threads_m * num_threads_n == totalNumThread, "error: threads per dimension not equal to total threads");
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0); const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1); const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
...@@ -201,9 +199,6 @@ struct GridwiseElementwise_2D ...@@ -201,9 +199,6 @@ struct GridwiseElementwise_2D
}); });
}); });
// static_for<0, MPerThread * NPerThread, 1>{}(
//[&](auto i) { out_thread_buf_tuple(I0)(i) = 1; });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mn, out_global_store_tuple(I).Run(thread_buffer_desc_mn,
make_tuple(I0, I0), make_tuple(I0, I0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment