Commit ee40f5a9 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use type alias to reduce code

parent f17fa4d7
...@@ -21,10 +21,10 @@ struct DevicePermute : BaseOperator ...@@ -21,10 +21,10 @@ struct DevicePermute : BaseOperator
using Strides = Lengths; using Strides = Lengths;
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths inLengths, MakeArgumentPointer(const Lengths& inLengths,
const Strides inStrides, const Strides& inStrides,
const Lengths outLengths, const Lengths& outLengths,
const Strides outStrides, const Strides& outStrides,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
ElementwiseOperation elementwise_op) = 0; ElementwiseOperation elementwise_op) = 0;
...@@ -32,73 +32,6 @@ struct DevicePermute : BaseOperator ...@@ -32,73 +32,6 @@ struct DevicePermute : BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t NumDim,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
typename DerivedDeviceOperator>
struct DevicePermuteCRTP : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
private:
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
public:
// override methods inherited from 'BaseOperator'
bool IsSupportedArgument(const BaseArgument* arg) override final
{
const auto* const argument =
dynamic_cast<const typename DerivedDeviceOperator::Argument*>(arg);
if(!argument)
{
return false;
}
return DerivedDeviceOperator::IsSupportedArgument(*argument);
}
// override methods inherited from 'DevicePermute'
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const typename BaseType::Lengths inLengths,
const typename BaseType::Strides inStrides,
const typename BaseType::Lengths outLengths,
const typename BaseType::Strides outStrides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) override final
{
return std::make_unique<typename DerivedDeviceOperator::Argument>(inLengths,
inStrides,
outLengths,
outStrides,
in_dev_buffer,
out_dev_buffer,
elementwise_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
{
return std::make_unique<typename DerivedDeviceOperator::Invoker>();
};
// generate other utility methods
template <typename... Args>
static auto MakeArgument(Args&&... args) noexcept(
std::is_nothrow_constructible_v<typename DerivedDeviceOperator::Argument, Args...>)
{
static_assert(std::is_constructible_v<typename DerivedDeviceOperator::Argument, Args...>);
return typename DerivedDeviceOperator::Argument{std::forward<Args>(args)...};
}
static auto MakeInvoker() noexcept(
std::is_nothrow_default_constructible_v<typename DerivedDeviceOperator::Invoker>)
{
static_assert(std::is_default_constructible_v<typename DerivedDeviceOperator::Invoker>);
return typename DerivedDeviceOperator::Invoker{};
}
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -41,27 +41,12 @@ template <index_t NumDim, ...@@ -41,27 +41,12 @@ template <index_t NumDim,
index_t DstVectorDim, index_t DstVectorDim,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t DstScalarPerVector> index_t DstScalarPerVector>
struct DevicePermuteImpl struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
: DevicePermuteCRTP<NumDim,
InDataType,
OutDataType,
ElementwiseOperation,
DevicePermuteImpl<NumDim,
InDataType,
OutDataType,
ElementwiseOperation,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector>>
{ {
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
using typename BaseType::Lengths;
using typename BaseType::Strides;
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor"); static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim); static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim); static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
...@@ -75,8 +60,7 @@ struct DevicePermuteImpl ...@@ -75,8 +60,7 @@ struct DevicePermuteImpl
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{}); return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
} }
static auto MakeDescriptor_N_H_W(const std::array<index_t, NumDim>& lengths, static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides stride)
const std::array<index_t, NumDim>& stride)
{ {
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]] // d[NumDim-1]]
...@@ -123,12 +107,14 @@ struct DevicePermuteImpl ...@@ -123,12 +107,14 @@ struct DevicePermuteImpl
SrcScalarPerVector, SrcScalarPerVector,
DstScalarPerVector>; DstScalarPerVector>;
using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::array<index_t, NumDim> inLengths, Argument(const Lengths& inLengths,
const std::array<index_t, NumDim> inStrides, const Strides& inStrides,
const std::array<index_t, NumDim> outLengths, const Lengths& outLengths,
const std::array<index_t, NumDim> outStrides, const Strides& outStrides,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
ElementwiseOperation elementwise_op) ElementwiseOperation elementwise_op)
...@@ -150,14 +136,14 @@ struct DevicePermuteImpl ...@@ -150,14 +136,14 @@ struct DevicePermuteImpl
InGridDesc in_grid_desc_; InGridDesc in_grid_desc_;
OutGridDesc out_grid_desc_; OutGridDesc out_grid_desc_;
std::array<index_t, NumDim> inLengths_; Lengths inLengths_;
std::array<index_t, NumDim> inStrides_; Strides inStrides_;
std::array<index_t, NumDim> outLengths_; Lengths outLengths_;
std::array<index_t, NumDim> outStrides_; Strides outStrides_;
ElementwiseOperation elementwise_op_; ElementwiseOperation elementwise_op_;
typename GridwisePermute::DefaultBlock2TileMap block_2_tile_map_; Block2TileMap block_2_tile_map_;
}; };
struct Invoker : BaseInvokerCRTP<Invoker, Argument> struct Invoker : BaseInvokerCRTP<Invoker, Argument>
...@@ -172,7 +158,7 @@ struct DevicePermuteImpl ...@@ -172,7 +158,7 @@ struct DevicePermuteImpl
InDataType, InDataType,
OutDataType, OutDataType,
ElementwiseOperation, ElementwiseOperation,
typename GridwisePermute::DefaultBlock2TileMap>; Block2TileMap>;
float elapsed_time = launch_and_time_kernel(stream_config, float elapsed_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -227,6 +213,56 @@ struct DevicePermuteImpl ...@@ -227,6 +213,56 @@ struct DevicePermuteImpl
DstScalarPerVector) && DstScalarPerVector) &&
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_); GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
}; };
// override methods inherited from 'BaseOperator'
bool IsSupportedArgument(const BaseArgument* arg) override final
{
const auto* const argument = dynamic_cast<const Argument*>(arg);
if(!argument)
{
return false;
}
return IsSupportedArgument(*argument);
}
// override methods inherited from 'DevicePermute'
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& inLengths,
const Strides& inStrides,
const Lengths& outLengths,
const Strides& outStrides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) override final
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
in_dev_buffer,
out_dev_buffer,
elementwise_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
{
return std::make_unique<Invoker>();
};
// other constructor methods
template <typename... Args>
static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
{
return Argument{std::forward<Args>(args)...};
}
static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
{
return Invoker{};
}
}; };
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment