Commit 1d97c3a4 authored by Astha Rai's avatar Astha Rai
Browse files

updated Grid Desc

parent facdb52e
...@@ -76,7 +76,7 @@ struct DeviceElementwise ...@@ -76,7 +76,7 @@ struct DeviceElementwise
const index_t loop_step_n = gridSize * blockSize * NPerThread; const index_t loop_step_n = gridSize * blockSize * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m; const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n; const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto desc_mn_pad = transform_tesor_descriptor( const auto desc_mn_pad = transform_tensor_descriptor(
desc_mn, desc_mn,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)), make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
...@@ -96,14 +96,14 @@ struct DeviceElementwise ...@@ -96,14 +96,14 @@ struct DeviceElementwise
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type(); constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDim_n, 1>::type(); constexpr auto nDimIds = typename arithmetic_sequence_gen< NumDim_m, NumDim_m + NumDim_n, 1>::type();
const auto mLengths = get_container_subset(lengths, mDimIds); const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(lengths, nDimIds); const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
// merge nd to 2d desc - [s0 * s1 * ...] // merge nd to 2d desc - [s0 * s1 * ...]
if constexpr(NumDim_m + NumDim_n > 2) if constexpr(NumDim > 2)
{ {
const auto desc_mn = transform_tensor_descriptor( const auto desc_mn = transform_tensor_descriptor(
desc, desc,
...@@ -118,11 +118,11 @@ struct DeviceElementwise ...@@ -118,11 +118,11 @@ struct DeviceElementwise
} }
template <index_t TupleSize> template <index_t TupleSize>
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>) static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
{ {
return generate_tuple( return generate_tuple(
[&](auto) { [&](auto) {
if constexpr(NumDim_m + NumDim_n > 2) if constexpr(NumDim > 2)
{ {
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1); return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1);
} }
...@@ -134,8 +134,9 @@ struct DeviceElementwise ...@@ -134,8 +134,9 @@ struct DeviceElementwise
Number<TupleSize>{}); Number<TupleSize>{});
}; };
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{})); using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
//using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple, using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
OutGrid2dDescTuple, OutGrid2dDescTuple,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment