Commit b4abe4e2 authored by Astha Rai's avatar Astha Rai
Browse files

integrated variable for thread distribution into device elementwise and added...

integrated variable for thread distribution into device elementwise and added as parameter for gridwise elementwise
parent b61bb071
......@@ -69,12 +69,14 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn, index_t gridSize, index_t blockSize)
static auto PadDescriptor_MN_2d(Desc_MN desc_mn, index_t gridSize, index_t blockSize, index_t num_threads_m, index_t num_threads_n)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mn.GetLength(I0);
const auto n = desc_mn.GetLength(I1);
const index_t loop_step_m = MPerThread;
const index_t loop_step_n = gridSize * blockSize * NPerThread;
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
std::cout << NumDim_m << " m: " << m << " loop_step_m: " << loop_step_m
......@@ -92,7 +94,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize)
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
......@@ -117,10 +121,10 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize);
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
}
else
return PadDescriptor_MN_2d(desc, gridSize, blockSize);
return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
}
template <index_t TupleSize>
......@@ -130,11 +134,11 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
[&](auto) {
if constexpr(NumDim > 2)
{
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1);
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MN({1}, {1}, 1, 1);
return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
......@@ -150,8 +154,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
//num_threads_m,
//num_threads_n,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
......@@ -169,7 +171,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
num_threads_m_((gridSize_*blockSize_)/8),
num_threads_n_(8)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
......@@ -191,19 +195,16 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
in_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(
lengths, inStridesArray[I.value], gridSize_, blockSize_);
lengths, inStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_);
},
Number<NumInput>{});
out_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(
lengths, outStridesArray[I.value], gridSize_, blockSize_);
lengths, outStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_);
},
Number<NumOutput>{});
//num_threads_m = 1;
//num_threads_n = gridSize_ * blockSize_;
}
InDataTypePointerTuple in_dev_buffers_;
......@@ -218,6 +219,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
ElementwiseOperation elementwise_op_;
index_t blockSize_;
index_t gridSize_;
index_t num_threads_m_;
index_t num_threads_n_;
};
struct Invoker : public BaseInvoker
......@@ -240,7 +243,9 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
arg.out_grid_2d_desc_tuple_,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_);
arg.elementwise_op_,
arg.num_threads_m_,
arg.num_threads_n_);
return elapsed_time;
}
......@@ -258,9 +263,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
if(pArg == nullptr)
return false;
std::cout << "made it here" << std::endl;
std::cout << "lengths back: " << pArg->lengths_.back() << std::endl;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
......@@ -273,20 +275,12 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
// std::cout << "Check 1 passed" << std::endl;
return true;
}
// std::cout << "Check 1 failed " << std::endl;
// std::cout << "ISPVV Check 2 starting" << std::endl;
// std::cout << "strides[vectorDim]: " << strides[vectorDim] << std::endl;
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{
// std::cout << "Check 2 passed " << std::endl;
return true;
}
// std::cout << "Check 2 failed" << std::endl;
return false;
};
......@@ -300,7 +294,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
NumDim_m - 1))
valid = false;
});
std::cout << "valid after loop through input: " << valid << std::endl;
static_for<0, NumOutput, 1>{}([&](auto I) {
std::cout << "running 2: " << I << std::endl;
......@@ -310,7 +303,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
NumDim - 1))
valid = false;
});
std::cout << "valid after loop through output: " << valid << std::endl;
return valid;
};
......
......@@ -20,13 +20,17 @@ __global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tu
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op)
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{
GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
out_grid_2d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op);
elementwise_op,
num_threads_m,
num_threads_n);
}
template <typename InGrid2dDescTuple,
......@@ -61,7 +65,9 @@ struct GridwiseElementwise_2D
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op)
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
......@@ -101,14 +107,6 @@ struct GridwiseElementwise_2D
},
Number<NumOutput>{});
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const index_t totalNumThread = blockSize * blockPerGrid;
const index_t num_threads_m = 4;
const index_t num_threads_n = totalNumThread/4;
//static_assert(num_threads_m * num_threads_n == totalNumThread, "error: threads per dimension not equal to total threads");
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
......@@ -201,9 +199,6 @@ struct GridwiseElementwise_2D
});
});
// static_for<0, MPerThread * NPerThread, 1>{}(
//[&](auto i) { out_thread_buf_tuple(I0)(i) = 1; });
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mn,
make_tuple(I0, I0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment