Commit d179a12a authored by Astha Rai
Browse files

Removed most of the extraneous code; testing with different dimensions.

parent dab61372
......@@ -48,7 +48,7 @@ int main()
bool do_verification = true;
bool time_kernel = true;
const int N = 16;
const int N = 120;
const int H = 32;
const int W = 64;
......@@ -72,7 +72,6 @@ int main()
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
// std::copy(nhwc.begin(), nhwc.end(), ab_lengths.begin());
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
......
......@@ -69,7 +69,11 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn, index_t gridSize, index_t blockSize, index_t num_threads_m, index_t num_threads_n)
static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
std::ignore = blockSize;
std::ignore = gridSize;
......@@ -172,8 +176,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
elementwise_op_(elementwise_op),
blockSize_(256),
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
num_threads_m_((gridSize_*blockSize_)/8),
num_threads_n_(8)
num_threads_m_((gridSize_ * blockSize_) / 16),
num_threads_n_(16)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
......@@ -194,15 +198,23 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
in_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(
lengths, inStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_);
return MakeDescriptor_MN(lengths,
inStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
},
Number<NumInput>{});
out_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(
lengths, outStridesArray[I.value], gridSize_, blockSize_, num_threads_m_, num_threads_n_);
return MakeDescriptor_MN(lengths,
outStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
},
Number<NumOutput>{});
}
......@@ -284,10 +296,8 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
std::cout << "running: " << I << std::endl;
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
......@@ -296,7 +306,6 @@ struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
});
static_for<0, NumOutput, 1>{}([&](auto I) {
std::cout << "running 2: " << I << std::endl;
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment