"...composable_kernel-1.git" did not exist on "b491ebf38480bc0d6cb329ba6825dee610c59097"
Commit b56ddad3 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Create new base type for 'DevicePermute' implementations

parent b4e2b28c
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp" #include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute_base.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -20,55 +21,6 @@ namespace ck { ...@@ -20,55 +21,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace detail {
// CRTP helper that implements the type-erased 'BaseOperator' entry point and
// the factory helpers for a concrete permute operator.
//
// 'Derived' must provide nested 'Argument' and 'Invoker' types and a static
// 'IsSupportedArgument(const Argument&)' function.
template <typename Derived>
struct DevicePermuteBase : BaseOperator
{
    // Type-erased support check. Returns false when 'arg' is not a
    // 'Derived::Argument' (a null pointer also fails the cast); otherwise
    // forwards the down-casted argument to 'Derived::IsSupportedArgument'.
    bool IsSupportedArgument(const BaseArgument* arg) override final
    {
        const auto* argument = dynamic_cast<const typename Derived::Argument*>(arg);
        if(!argument)
        {
            return false;
        }

        return Derived::IsSupportedArgument(*argument);
    }

    // Builds a 'Derived::Argument' by value, perfectly forwarding 'args' to
    // its list-initialization.
    template <typename... Args>
    static auto MakeArgument(Args&&... args)
    {
        return typename Derived::Argument{std::forward<Args>(args)...};
    }

    // Same as 'MakeArgument', but allocates the argument through
    // 'std::make_unique' and returns the owning pointer.
    template <typename... Args>
    static auto MakeArgumentPointer(Args&&... args)
    {
        return std::make_unique<typename Derived::Argument>(std::forward<Args>(args)...);
    }

    // Returns a default-constructed 'Derived::Invoker' by value.
    static auto MakeInvoker() { return typename Derived::Invoker{}; }

    // Returns a default-constructed 'Derived::Invoker' owned by a unique_ptr.
    // (No trailing ';' after the body: it would be an empty declaration.)
    static auto MakeInvokerPointer() { return std::make_unique<typename Derived::Invoker>(); }
};
// CRTP helper that implements the type-erased 'BaseInvoker::Run' on top of a
// static, strongly-typed 'Derived::Run(const Argument&, const StreamConfig&)'.
template <typename Derived, typename Argument>
struct InvokerBase : BaseInvoker
{
    // Down-casts 'arg' to the concrete 'Argument' type and forwards it to
    // 'Derived::Run'. When the cast fails (wrong dynamic type or null pointer)
    // nothing is launched and 0.f is returned.
    float Run(const BaseArgument* arg,
              const StreamConfig& stream_config = StreamConfig{}) override final
    {
        if(const auto* const typed_arg = dynamic_cast<const Argument*>(arg))
        {
            return Derived::Run(*typed_arg, stream_config);
        }

        return 0.f;
    }
};
} // namespace detail
// Swap last 2 dimensions // Swap last 2 dimensions
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]] // input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// ^^^^^^^^^^^ // ^^^^^^^^^^^
...@@ -89,22 +41,25 @@ template <typename InDataType, ...@@ -89,22 +41,25 @@ template <typename InDataType,
index_t DstVectorDim, index_t DstVectorDim,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t DstScalarPerVector> index_t DstScalarPerVector>
struct DevicePermute struct DevicePermute : DevicePermuteBaseCRTP<NumDim,
: detail::DevicePermuteBase<DevicePermute<InDataType, InDataType,
OutDataType, OutDataType,
ElementwiseOperation, ElementwiseOperation,
NumDim, DevicePermute<InDataType,
BlockSize, OutDataType,
NPerBlock, ElementwiseOperation,
HPerBlock, NumDim,
WPerBlock, BlockSize,
InBlockLdsExtraW, NPerBlock,
InBlockTransferThreadClusterLengths, HPerBlock,
InBlockTransferThreadClusterArrangeOrder, WPerBlock,
SrcVectorDim, InBlockLdsExtraW,
DstVectorDim, InBlockTransferThreadClusterLengths,
SrcScalarPerVector, InBlockTransferThreadClusterArrangeOrder,
DstScalarPerVector>> SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector>>
{ {
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor"); static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim); static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
...@@ -204,7 +159,7 @@ struct DevicePermute ...@@ -204,7 +159,7 @@ struct DevicePermute
typename GridwisePermute::DefaultBlock2TileMap block_2_tile_map_; typename GridwisePermute::DefaultBlock2TileMap block_2_tile_map_;
}; };
struct Invoker : detail::InvokerBase<Invoker, Argument> struct Invoker : BaseInvokerCRTP<Invoker, Argument>
{ {
static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Abstract, type-erased interface for permute device operators.
//
// 'InDataType' and 'OutDataType' are not referenced inside this struct; they
// only take part in the identity of the instantiated interface type.
template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
struct DevicePermuteBase : BaseOperator
{
    // Fixed-size containers holding one entry per tensor dimension.
    using Lengths = std::array<index_t, NumDim>;
    using Strides = std::array<index_t, NumDim>;

    // Creates the operator-specific argument object from the input/output
    // lengths and strides, the two buffers and the element-wise functor.
    // NOTE(review): the '*_dev_buffer' names suggest device memory pointers;
    // confirm against the concrete operators.
    virtual auto MakeArgumentPointer(const Lengths inLengths,
                                     const Strides inStrides,
                                     const Lengths outLengths,
                                     const Strides outStrides,
                                     const void* in_dev_buffer,
                                     void* out_dev_buffer,
                                     ElementwiseOperation elementwise_op)
        -> std::unique_ptr<BaseArgument> = 0;

    // Creates the operator-specific invoker object.
    virtual auto MakeInvokerPointer() -> std::unique_ptr<BaseInvoker> = 0;
};
// CRTP layer between the abstract 'DevicePermuteBase' interface and a concrete
// permute operator. It implements every virtual method of the interface in
// terms of the nested types and static functions of 'DerivedDeviceOperator',
// and adds by-value factory helpers.
//
// 'DerivedDeviceOperator' must provide:
//   - a nested 'Argument' type constructible from
//     (Lengths, Strides, Lengths, Strides, const void*, void*, ElementwiseOperation)
//   - a nested, default-constructible 'Invoker' type
//   - a static 'IsSupportedArgument(const Argument&)'
template <index_t NumDim,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          typename DerivedDeviceOperator>
struct DevicePermuteBaseCRTP
    : DevicePermuteBase<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
    private:
    using BaseType = DevicePermuteBase<NumDim, InDataType, OutDataType, ElementwiseOperation>;

    public:
    // override methods inherited from 'BaseOperator'

    // Returns false when 'arg' is not a 'DerivedDeviceOperator::Argument'
    // (a null pointer also fails the cast); otherwise forwards the
    // down-casted argument to 'DerivedDeviceOperator::IsSupportedArgument'.
    bool IsSupportedArgument(const BaseArgument* arg) override final
    {
        const auto* const argument =
            dynamic_cast<const typename DerivedDeviceOperator::Argument*>(arg);
        if(!argument)
        {
            return false;
        }

        return DerivedDeviceOperator::IsSupportedArgument(*argument);
    }

    // override methods inherited from 'DevicePermuteBase'

    // Allocates a 'DerivedDeviceOperator::Argument' from the given values and
    // returns it through the type-erased 'BaseArgument' pointer.
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const typename BaseType::Lengths inLengths,
                        const typename BaseType::Strides inStrides,
                        const typename BaseType::Lengths outLengths,
                        const typename BaseType::Strides outStrides,
                        const void* in_dev_buffer,
                        void* out_dev_buffer,
                        ElementwiseOperation elementwise_op) override final
    {
        return std::make_unique<typename DerivedDeviceOperator::Argument>(inLengths,
                                                                          inStrides,
                                                                          outLengths,
                                                                          outStrides,
                                                                          in_dev_buffer,
                                                                          out_dev_buffer,
                                                                          elementwise_op);
    }

    // Allocates a default-constructed 'DerivedDeviceOperator::Invoker' and
    // returns it through the type-erased 'BaseInvoker' pointer.
    // (No trailing ';' after the body: it would be an empty declaration.)
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
    {
        return std::make_unique<typename DerivedDeviceOperator::Invoker>();
    }

    // generate other utility methods

    // Builds a 'DerivedDeviceOperator::Argument' by value, perfectly
    // forwarding 'args'. The static_assert rejects argument lists the
    // 'Argument' type cannot be constructed from.
    template <typename... Args>
    static auto MakeArgument(Args&&... args)
    {
        static_assert(std::is_constructible_v<typename DerivedDeviceOperator::Argument, Args...>);

        return typename DerivedDeviceOperator::Argument{std::forward<Args>(args)...};
    }

    // Returns a default-constructed 'DerivedDeviceOperator::Invoker' by value.
    // The function is noexcept exactly when that default construction is.
    static auto MakeInvoker() noexcept(
        std::is_nothrow_default_constructible_v<typename DerivedDeviceOperator::Invoker>)
    {
        static_assert(std::is_default_constructible_v<typename DerivedDeviceOperator::Invoker>);

        return typename DerivedDeviceOperator::Invoker{};
    }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment