Commit ee40f5a9 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use type alias to reduce code

parent f17fa4d7
......@@ -21,10 +21,10 @@ struct DevicePermute : BaseOperator
using Strides = Lengths;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths inLengths,
const Strides inStrides,
const Lengths outLengths,
const Strides outStrides,
MakeArgumentPointer(const Lengths& inLengths,
const Strides& inStrides,
const Lengths& outLengths,
const Strides& outStrides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) = 0;
......@@ -32,73 +32,6 @@ struct DevicePermute : BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t NumDim,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
typename DerivedDeviceOperator>
struct DevicePermuteCRTP : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
private:
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
public:
// override methods inherited from 'BaseOperator'
bool IsSupportedArgument(const BaseArgument* arg) override final
{
const auto* const argument =
dynamic_cast<const typename DerivedDeviceOperator::Argument*>(arg);
if(!argument)
{
return false;
}
return DerivedDeviceOperator::IsSupportedArgument(*argument);
}
// override methods inherited from 'DevicePermute'
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const typename BaseType::Lengths inLengths,
const typename BaseType::Strides inStrides,
const typename BaseType::Lengths outLengths,
const typename BaseType::Strides outStrides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) override final
{
return std::make_unique<typename DerivedDeviceOperator::Argument>(inLengths,
inStrides,
outLengths,
outStrides,
in_dev_buffer,
out_dev_buffer,
elementwise_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
{
return std::make_unique<typename DerivedDeviceOperator::Invoker>();
};
// generate other utility methods
template <typename... Args>
static auto MakeArgument(Args&&... args) noexcept(
std::is_nothrow_constructible_v<typename DerivedDeviceOperator::Argument, Args...>)
{
static_assert(std::is_constructible_v<typename DerivedDeviceOperator::Argument, Args...>);
return typename DerivedDeviceOperator::Argument{std::forward<Args>(args)...};
}
static auto MakeInvoker() noexcept(
std::is_nothrow_default_constructible_v<typename DerivedDeviceOperator::Invoker>)
{
static_assert(std::is_default_constructible_v<typename DerivedDeviceOperator::Invoker>);
return typename DerivedDeviceOperator::Invoker{};
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -41,27 +41,12 @@ template <index_t NumDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector>
struct DevicePermuteImpl
: DevicePermuteCRTP<NumDim,
InDataType,
OutDataType,
ElementwiseOperation,
DevicePermuteImpl<NumDim,
InDataType,
OutDataType,
ElementwiseOperation,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector>>
struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
using typename BaseType::Lengths;
using typename BaseType::Strides;
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
......@@ -75,8 +60,7 @@ struct DevicePermuteImpl
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
}
static auto MakeDescriptor_N_H_W(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride)
static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides stride)
{
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]]
......@@ -123,12 +107,14 @@ struct DevicePermuteImpl
SrcScalarPerVector,
DstScalarPerVector>;
using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> inLengths,
const std::array<index_t, NumDim> inStrides,
const std::array<index_t, NumDim> outLengths,
const std::array<index_t, NumDim> outStrides,
Argument(const Lengths& inLengths,
const Strides& inStrides,
const Lengths& outLengths,
const Strides& outStrides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op)
......@@ -150,14 +136,14 @@ struct DevicePermuteImpl
InGridDesc in_grid_desc_;
OutGridDesc out_grid_desc_;
std::array<index_t, NumDim> inLengths_;
std::array<index_t, NumDim> inStrides_;
std::array<index_t, NumDim> outLengths_;
std::array<index_t, NumDim> outStrides_;
Lengths inLengths_;
Strides inStrides_;
Lengths outLengths_;
Strides outStrides_;
ElementwiseOperation elementwise_op_;
typename GridwisePermute::DefaultBlock2TileMap block_2_tile_map_;
Block2TileMap block_2_tile_map_;
};
struct Invoker : BaseInvokerCRTP<Invoker, Argument>
......@@ -172,7 +158,7 @@ struct DevicePermuteImpl
InDataType,
OutDataType,
ElementwiseOperation,
typename GridwisePermute::DefaultBlock2TileMap>;
Block2TileMap>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -227,6 +213,56 @@ struct DevicePermuteImpl
DstScalarPerVector) &&
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
};
// override methods inherited from 'BaseOperator'
bool IsSupportedArgument(const BaseArgument* arg) override final
{
const auto* const argument = dynamic_cast<const Argument*>(arg);
if(!argument)
{
return false;
}
return IsSupportedArgument(*argument);
}
// override methods inherited from 'DevicePermute'
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& inLengths,
const Strides& inStrides,
const Lengths& outLengths,
const Strides& outStrides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) override final
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
in_dev_buffer,
out_dev_buffer,
elementwise_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
{
return std::make_unique<Invoker>();
};
// other constructor methods
template <typename... Args>
static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
{
return Argument{std::forward<Args>(args)...};
}
static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
{
return Invoker{};
}
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment