Commit d356c871 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Remove no-longer used type argument

parent 7a6dbadc
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using DevicePermuteInstance = ck::tensor_operation::device:: using DevicePermuteInstance =
DevicePermute<ADataType, BDataType, PassThrough, 4, 8, S<8>, S<1>>; ck::tensor_operation::device::DevicePermute<ADataType, BDataType, PassThrough, 4, 8, 8, 1>;
#include "run_permute_example.inc" #include "run_permute_example.inc"
......
...@@ -73,23 +73,16 @@ template <typename InDataType, ...@@ -73,23 +73,16 @@ template <typename InDataType,
typename ElementwiseOperation, typename ElementwiseOperation,
index_t NumDim, index_t NumDim,
index_t MPerThread, index_t MPerThread,
typename InScalarPerVectorSeq, index_t InScalarPerVector,
typename OutScalarPerVectorSeq> index_t OutScalarPerVector>
struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutDataType, OutDataType,
ElementwiseOperation, ElementwiseOperation,
NumDim, NumDim,
MPerThread, MPerThread,
InScalarPerVectorSeq, InScalarPerVector,
OutScalarPerVectorSeq>> OutScalarPerVector>>
{ {
static constexpr int NumInput = 1;
static constexpr int NumOutput = 1;
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
using InDataTypePointer = const InDataType*; using InDataTypePointer = const InDataType*;
using OutDataTypePointer = OutDataType*; using OutDataTypePointer = OutDataType*;
...@@ -156,8 +149,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -156,8 +149,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutDataTypePointer, OutDataTypePointer,
ElementwiseOperation, ElementwiseOperation,
MPerThread, MPerThread,
InScalarPerVectorSeq, InScalarPerVector,
OutScalarPerVectorSeq>; OutScalarPerVector>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -243,12 +236,12 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -243,12 +236,12 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
}; };
bool valid = true; bool valid = true;
if(!IsScalarPerVectorValid(arg.inLengths_, arg.inStrides_, InScalarPerVectorSeq::At(0))) if(!IsScalarPerVectorValid(arg.inLengths_, arg.inStrides_, InScalarPerVector))
{ {
valid = false; valid = false;
} }
if(!IsScalarPerVectorValid(arg.inLengths_, arg.outStrides_, OutScalarPerVectorSeq::At(0))) if(!IsScalarPerVectorValid(arg.inLengths_, arg.outStrides_, OutScalarPerVector))
{ {
valid = false; valid = false;
} }
......
...@@ -32,17 +32,10 @@ template <typename InGrid1dDesc, ...@@ -32,17 +32,10 @@ template <typename InGrid1dDesc,
typename OutDataTypePointer, typename OutDataTypePointer,
typename ElementwiseOperation, typename ElementwiseOperation,
index_t MPerThread, index_t MPerThread,
typename InScalarPerVectorSeq, index_t InScalarPerVector,
typename OutScalarPerVectorSeq> index_t OutScalarPerVector>
struct GridwisePermute struct GridwisePermute
{ {
static constexpr index_t NumInput = 1;
static constexpr index_t NumOutput = 1;
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto thread_buffer_desc_m = static constexpr auto thread_buffer_desc_m =
...@@ -86,7 +79,7 @@ struct GridwisePermute ...@@ -86,7 +79,7 @@ struct GridwisePermute
Sequence<MPerThread>, // SliceLengths Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder Sequence<0>, // DimAccessOrder
0, // SrcVectorDim 0, // SrcVectorDim
InScalarPerVectorSeq::At(0), // ScalarPerVector InScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector 1, // SrcScalarStrideInVector
false>{in_grid_1d_desc, thread_global_offset}; false>{in_grid_1d_desc, thread_global_offset};
...@@ -99,7 +92,7 @@ struct GridwisePermute ...@@ -99,7 +92,7 @@ struct GridwisePermute
Sequence<MPerThread>, // SliceLengths Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder Sequence<0>, // DimAccessOrder
0, // SrcVectorDim 0, // SrcVectorDim
OutScalarPerVectorSeq::At(0), OutScalarPerVector,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>( false>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment