Commit 5ae42120 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Only accept single-input-single-output for 'DervicePermute'

parent 179092df
......@@ -7,7 +7,7 @@ using ADataType = F16;
using BDataType = F16;
using DevicePermuteInstance = ck::tensor_operation::device::
DevicePermute<ck::Tuple<ADataType>, ck::Tuple<BDataType>, PassThrough, 4, 8, S<8>, S<1>>;
DevicePermute<ADataType, BDataType, PassThrough, 4, 8, S<8>, S<1>>;
#include "run_permute_example.inc"
......
......@@ -51,52 +51,30 @@ struct DevicePermuteBase : BaseOperator
};
} // namespace detail
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
template <typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
index_t NumDim,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataTypeTuple,
OutDataTypeTuple,
struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutDataType,
ElementwiseOperation,
NumDim,
MPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>>
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr int NumInput = 1;
static constexpr int NumOutput = 1;
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
using InDataTypePointerTuple = Tuple<const InDataType*>;
using OutDataTypePointerTuple = Tuple<OutDataType*>;
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
......@@ -187,14 +165,14 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataTypeTuple,
{
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
using DataType = InDataType;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
using DataType = OutDataType;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment