Commit 1d97c3a4 authored by Astha Rai
Browse files

updated Grid Desc

parent facdb52e
......@@ -76,7 +76,7 @@ struct DeviceElementwise
const index_t loop_step_n = gridSize * blockSize * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto desc_mn_pad = transform_tesor_descriptor(
const auto desc_mn_pad = transform_tensor_descriptor(
desc_mn,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
......@@ -96,14 +96,14 @@ struct DeviceElementwise
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDim_n, 1>::type();
constexpr auto nDimIds = typename arithmetic_sequence_gen< NumDim_m, NumDim_m + NumDim_n, 1>::type();
const auto mLengths = get_container_subset(lengths, mDimIds);
const auto nLengths = get_container_subset(lengths, nDimIds);
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
// merge nd to 2d desc - [s0 * s1 * ...]
if constexpr(NumDim_m + NumDim_n > 2)
if constexpr(NumDim > 2)
{
const auto desc_mn = transform_tensor_descriptor(
desc,
......@@ -118,11 +118,11 @@ struct DeviceElementwise
}
template <index_t TupleSize>
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim_m + NumDim_n > 2)
if constexpr(NumDim > 2)
{
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1);
}
......@@ -134,8 +134,9 @@ struct DeviceElementwise
Number<TupleSize>{});
};
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
//using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
OutGrid2dDescTuple,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment