Commit 0aef936b authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Transform descriptor into 3 dimensions

parent 692f9e0e
...@@ -109,29 +109,44 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -109,29 +109,44 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
return desc_m_pad; return desc_m_pad;
} }
static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths, template <index_t N = NumDim>
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
{
static_assert(1 <= N && N <= NumDim);
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
}
static auto MakeDescriptor_N_H_W(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride, const std::array<index_t, NumDim>& stride,
index_t gridSize, index_t gridSize,
index_t blockSize) index_t blockSize)
{ {
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{}); // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{}); // d[NumDim-1]]
const auto desc =
// nd desc - [s0, s1, s2, ...] make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
// merge nd to 1d desc - [s0 * s1 * ...] // d[NumDim-1]]
const auto desc_m = transform_tensor_descriptor( // => [N, H, W]
const index_t H = *std::next(rbegin(lengths));
const index_t W = *rbegin(lengths);
const auto desc_n_h_w = transform_tensor_descriptor(
desc, desc,
make_tuple(make_merge_transform(tupleOfShape)), make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})), make_pass_through_transform(H),
make_tuple(Sequence<0>{})); make_pass_through_transform(W)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
return PadDescriptor_M_1d(desc_m, gridSize, blockSize); Sequence<NumDim - 2>{},
Sequence<NumDim - 1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return PadDescriptor_M_1d(desc_n_h_w, gridSize, blockSize);
} }
using InGrid1dDesc = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); using InGrid1dDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}, 1, 1));
using OutGrid1dDesc = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); using OutGrid1dDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}, 1, 1));
using GridwisePermute = GridwisePermute<InGrid1dDesc, using GridwisePermute = GridwisePermute<InGrid1dDesc,
OutGrid1dDesc, OutGrid1dDesc,
...@@ -155,8 +170,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -155,8 +170,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
in_dev_buffer_(static_cast<InDataTypePointer>(in_dev_buffer)), in_dev_buffer_(static_cast<InDataTypePointer>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataTypePointer>(out_dev_buffer)), out_dev_buffer_(static_cast<OutDataTypePointer>(out_dev_buffer)),
in_grid_1d_desc_(MakeDescriptor_M(inLengths, inStrides, gridSize_, blockSize_)), in_grid_1d_desc_(MakeDescriptor_N_H_W(inLengths, inStrides, gridSize_, blockSize_)),
out_grid_1d_desc_(MakeDescriptor_M(inLengths, inStrides, gridSize_, blockSize_)), out_grid_1d_desc_(MakeDescriptor_N_H_W(inLengths, inStrides, gridSize_, blockSize_)),
inLengths_(inLengths), inLengths_(inLengths),
inStrides_(inStrides), inStrides_(inStrides),
outLengths_(outLengths), outLengths_(outLengths),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment