Commit 339e51d1 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Simplify 'DevicePermute' interface

parent 5ae42120
...@@ -22,14 +22,14 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem) ...@@ -22,14 +22,14 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf.ToDevice(a.mData.data()); a_device_buf.ToDevice(a.mData.data());
std::array<ck::index_t, 4> ab_lengths; std::array<ck::index_t, 4> ab_lengths;
std::array<std::array<ck::index_t, 4>, 1> a_strides, b_strides; std::array<ck::index_t, 4> a_strides, b_strides;
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()}; const void* input = a_device_buf.GetDeviceBuffer();
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()}; void* output = b_device_buf.GetDeviceBuffer();
std::copy(begin(shape), end(shape), begin(ab_lengths)); std::copy(begin(shape), end(shape), begin(ab_lengths));
std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(front(a_strides))); std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(front(b_strides))); std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>); static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
......
...@@ -118,25 +118,20 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -118,25 +118,20 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
return PadDescriptor_M_1d(desc, gridSize, blockSize); return PadDescriptor_M_1d(desc, gridSize, blockSize);
} }
template <index_t TupleSize> static auto GenerateInOutGrid1dDesc()
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
{ {
return generate_tuple( if constexpr(NumDim > 1)
[&](auto) { {
if constexpr(NumDim > 1) return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
{ }
return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1); else
} {
else return MakeDescriptor_M({1}, {1}, 1, 1);
{ };
return MakeDescriptor_M({1}, {1}, 1, 1);
};
},
Number<TupleSize>{});
}; };
using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{})); using InGrid1dDescTuple = Tuple<decltype(GenerateInOutGrid1dDesc())>;
using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{})); using OutGrid1dDescTuple = Tuple<decltype(GenerateInOutGrid1dDesc())>;
using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple, using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
OutGrid1dDescTuple, OutGrid1dDescTuple,
...@@ -150,48 +145,44 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -150,48 +145,44 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::array<index_t, NumDim> lengths, Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray, const std::array<index_t, NumDim> inStrides,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray, const std::array<index_t, NumDim> outStrides,
const std::array<const void*, NumInput> in_dev_buffers, const void* in_dev_buffer,
const std::array<void*, NumOutput> out_dev_buffers, void* out_dev_buffer,
ElementwiseOperation elementwise_op) ElementwiseOperation elementwise_op)
: blockSize_(256),
: lengths_(lengths), gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
inStridesArray_(inStridesArray), lengths_(lengths),
outStridesArray_(outStridesArray), inStridesArray_({inStrides}),
elementwise_op_(elementwise_op), outStridesArray_({outStrides}),
blockSize_(256), elementwise_op_(elementwise_op)
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{ {
in_dev_buffers_ = generate_tuple( in_dev_buffers_ = generate_tuple(
[&](auto I) { [&](auto) {
using DataType = InDataType; using DataType = InDataType;
return static_cast<const DataType*>(in_dev_buffers[I.value]); return static_cast<const DataType*>(in_dev_buffer);
}, },
Number<NumInput>{}); Number<NumInput>{});
out_dev_buffers_ = generate_tuple( out_dev_buffers_ = generate_tuple(
[&](auto I) { [&](auto) {
using DataType = OutDataType; using DataType = OutDataType;
return static_cast<DataType*>(out_dev_buffers[I.value]); return static_cast<DataType*>(out_dev_buffer);
}, },
Number<NumOutput>{}); Number<NumOutput>{});
in_grid_1d_desc_tuple_ = generate_tuple( in_grid_1d_desc_tuple_ = generate_tuple(
[&](auto I) { [&](auto) { return MakeDescriptor_M(lengths, inStrides, gridSize_, blockSize_); },
return MakeDescriptor_M(
lengths, inStridesArray[I.value], gridSize_, blockSize_);
},
Number<NumInput>{}); Number<NumInput>{});
out_grid_1d_desc_tuple_ = generate_tuple( out_grid_1d_desc_tuple_ = generate_tuple(
[&](auto I) { [&](auto) { return MakeDescriptor_M(lengths, outStrides, gridSize_, blockSize_); },
return MakeDescriptor_M(
lengths, outStridesArray[I.value], gridSize_, blockSize_);
},
Number<NumOutput>{}); Number<NumOutput>{});
} }
index_t blockSize_;
index_t gridSize_;
InDataTypePointerTuple in_dev_buffers_; InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_; OutDataTypePointerTuple out_dev_buffers_;
InGrid1dDescTuple in_grid_1d_desc_tuple_; InGrid1dDescTuple in_grid_1d_desc_tuple_;
...@@ -202,8 +193,6 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -202,8 +193,6 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_; std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_; ElementwiseOperation elementwise_op_;
index_t blockSize_;
index_t gridSize_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment