"git@developer.sourcefind.cn:OpenDAS/rodnet.git" did not exist on "80ded822a2b6f0411764305fc93d01184770c627"
Commit 734a12da authored by Po-Yen, Chen

Rename 'DevicePermute' to 'DevicePermuteImpl'

parent 16b116a9
@@ -15,8 +15,8 @@
 #include <utility>
 
 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/device/device_permute.hpp"
 #include "ck/utility/type.hpp"
 
 #include "ck/library/utility/check_err.hpp"
...
@@ -7,7 +7,7 @@ using InDataType = F16;
 using OutDataType = F16;
 
 // clang-format off
-using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
+using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
 // ######| NumDim| InData| OutData| Elementwise| Block|  NPer|  HPer|  WPer|   InBlock|      InBlockTransfer|           InBlockTransfer|       Src|       Dst|             Src|             Dst|
 // ######|       |   Type|    Type|   Operation|  Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
 // ######|       |       |        |            |      |      |      |      |          |                     |                          |          |          |                |                |
...
@@ -9,7 +9,7 @@ using BundleType = F64;
 static_assert(sizeof(BundleType) % sizeof(DataType) == 0);
 
 // clang-format off
-using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
+using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
 // ######| NumDim| InData| OutData| Elementwise| Block|  NPer|  HPer|  WPer|   InBlock|      InBlockTransfer|           InBlockTransfer|       Src|       Dst|             Src|             Dst|
 // ######|       |   Type|    Type|   Operation|  Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
 // ######|       |       |        |            |      |      |      |      |          |                     |                          |          |          |                |                |
...
@@ -7,7 +7,7 @@ using InDataType = F16;
 using OutDataType = F16;
 
 // clang-format off
-using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
+using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
 // ######| NumDim| InData| OutData| Elementwise| Block|  NPer|  HPer|  WPer|   InBlock|      InBlockTransfer|           InBlockTransfer|       Src|       Dst|             Src|             Dst|
 // ######|       |   Type|    Type|   Operation|  Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
 // ######|       |       |        |            |      |      |      |      |          |                     |                          |          |          |                |                |
...
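The hunks above change only the class name; the template-argument lists (elided here) keep the layout spelled out in the column comments. Below is a minimal sketch of a fully-specified instance. Every concrete value, the F16/PassThrough aliases, and the header choices are illustrative assumptions, not part of this commit; only the 'DevicePermuteImpl' spelling is.

// Hypothetical fully-specified instance; all values below are illustrative.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"

using F16         = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl<
    3,                      // NumDim (at least 3)
    F16,                    // InDataType
    F16,                    // OutDataType
    PassThrough,            // ElementwiseOperation
    256,                    // BlockSize
    1,                      // NPerBlock
    32,                     // HPerBlock
    32,                     // WPerBlock
    1,                      // InBlockLdsExtraW
    ck::Sequence<1, 32, 8>, // InBlockTransferThreadClusterLengths (1 * 32 * 8 == BlockSize)
    ck::Sequence<0, 1, 2>,  // InBlockTransferThreadClusterArrangeOrder
    2,                      // SrcVectorDim (must lie in [NumDim - 2, NumDim))
    1,                      // DstVectorDim (must differ from SrcVectorDim)
    8,                      // SrcScalarPerVector
    1>;                     // DstScalarPerVector
// clang-format on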
ck/tensor_operation/gpu/device/device_permute.hpp
@@ -4,228 +4,98 @@
 #pragma once
 
 #include <array>
-#include <cmath>
 #include <memory>
-#include <utility>
+#include <type_traits>
 
-#include "ck/utility/math.hpp"
-#include "ck/utility/sequence.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/tensor_operation/gpu/device/device_permute_base.hpp"
-#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
-#include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/host_utility/kernel_launch.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 
-// Swap last 2 dimensions
-// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
-//                                                   ^^^^^^^^^^^
-// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
-//                                                                 ^^^^^^^^^^^
-template <index_t NumDim,
-          typename InDataType,
-          typename OutDataType,
-          typename ElementwiseOperation,
-          index_t BlockSize,
-          index_t NPerBlock,
-          index_t HPerBlock,
-          index_t WPerBlock,
-          index_t InBlockLdsExtraW,
-          typename InBlockTransferThreadClusterLengths,
-          typename InBlockTransferThreadClusterArrangeOrder,
-          index_t SrcVectorDim,
-          index_t DstVectorDim,
-          index_t SrcScalarPerVector,
-          index_t DstScalarPerVector>
-struct DevicePermute : DevicePermuteBaseCRTP<NumDim,
-                                             InDataType,
-                                             OutDataType,
-                                             ElementwiseOperation,
-                                             DevicePermute<NumDim,
-                                                           InDataType,
-                                                           OutDataType,
-                                                           ElementwiseOperation,
-                                                           BlockSize,
-                                                           NPerBlock,
-                                                           HPerBlock,
-                                                           WPerBlock,
-                                                           InBlockLdsExtraW,
-                                                           InBlockTransferThreadClusterLengths,
-                                                           InBlockTransferThreadClusterArrangeOrder,
-                                                           SrcVectorDim,
-                                                           DstVectorDim,
-                                                           SrcScalarPerVector,
-                                                           DstScalarPerVector>>
-{
-    static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
-    static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
-    static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
-    static_assert(SrcVectorDim != DstVectorDim);
-
-    template <index_t N = NumDim>
-    static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
-    {
-        static_assert(1 <= N && N <= NumDim);
-
-        return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
-    }
-
-    static auto MakeDescriptor_N_H_W(const std::array<index_t, NumDim>& lengths,
-                                     const std::array<index_t, NumDim>& stride)
-    {
-        // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
-        // d[NumDim-1]]
-        const auto desc =
-            make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));
-
-        // merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
-        // d[NumDim-1]]
-        // => [N, H, W]
-        const index_t H = *std::next(rbegin(lengths));
-        const index_t W = *rbegin(lengths);
-        const auto desc_n_h_w = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
-                       make_pass_through_transform(H),
-                       make_pass_through_transform(W)),
-            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
-                       Sequence<NumDim - 2>{},
-                       Sequence<NumDim - 1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-        return PadTensorDescriptor(
-            desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
-    }
-
-    using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
-    using OutGridDesc = InGridDesc;
-
-    using GridwisePermute = GridwisePermute<
-        InGridDesc,
-        OutGridDesc,
-        InDataType,
-        OutDataType,
-        ElementwiseOperation,
-        BlockSize,
-        NPerBlock,
-        HPerBlock,
-        WPerBlock,
-        InBlockLdsExtraW,
-        InBlockTransferThreadClusterLengths,
-        InBlockTransferThreadClusterArrangeOrder,
-        SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
-        DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
-        SrcScalarPerVector,
-        DstScalarPerVector>;
-
-    struct Argument : public BaseArgument
-    {
-        Argument(const std::array<index_t, NumDim> inLengths,
-                 const std::array<index_t, NumDim> inStrides,
-                 const std::array<index_t, NumDim> outLengths,
-                 const std::array<index_t, NumDim> outStrides,
-                 const void* in_dev_buffer,
-                 void* out_dev_buffer,
-                 ElementwiseOperation elementwise_op)
-            : in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
-              out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
-              in_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
-              out_grid_desc_(MakeDescriptor_N_H_W(outLengths, outStrides)),
-              inLengths_(inLengths),
-              inStrides_(inStrides),
-              outLengths_(outLengths),
-              outStrides_(outStrides),
-              elementwise_op_(elementwise_op),
-              block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
-        {
-        }
-
-        const InDataType* in_dev_buffer_;
-        OutDataType* out_dev_buffer_;
-        InGridDesc in_grid_desc_;
-        OutGridDesc out_grid_desc_;
-        std::array<index_t, NumDim> inLengths_;
-        std::array<index_t, NumDim> inStrides_;
-        std::array<index_t, NumDim> outLengths_;
-        std::array<index_t, NumDim> outStrides_;
-        ElementwiseOperation elementwise_op_;
-        typename GridwisePermute::DefaultBlock2TileMap block_2_tile_map_;
-    };
-
-    struct Invoker : BaseInvokerCRTP<Invoker, Argument>
-    {
-        static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
-        {
-            const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
-
-            const auto kernel = kernel_nd_permute<GridwisePermute,
-                                                  InGridDesc,
-                                                  OutGridDesc,
-                                                  InDataType,
-                                                  OutDataType,
-                                                  ElementwiseOperation,
-                                                  typename GridwisePermute::DefaultBlock2TileMap>;
-
-            float elapsed_time = launch_and_time_kernel(stream_config,
-                                                        kernel,
-                                                        dim3(grid_size),
-                                                        dim3(BlockSize),
-                                                        0,
-                                                        arg.in_grid_desc_,
-                                                        arg.out_grid_desc_,
-                                                        arg.in_dev_buffer_,
-                                                        arg.out_dev_buffer_,
-                                                        arg.elementwise_op_,
-                                                        arg.block_2_tile_map_);
-            return elapsed_time;
-        }
-    };
-
-    static bool IsSupportedArgument(const Argument& arg)
-    {
-        constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
-            return math::integer_divide_ceil(length, tile_length) * tile_length;
-        };
-
-        constexpr auto IsScalarPerVectorValid =
-            [](index_t length, index_t stride, index_t scalar_per_vector) {
-                if(stride == 1 && length % scalar_per_vector == 0)
-                {
-                    return true;
-                }
-                else if(stride != 1 && scalar_per_vector == 1)
-                {
-                    return true;
-                }
-
-                return false;
-            };
-
-        return IsScalarPerVectorValid(arg.inLengths_[SrcVectorDim],
-                                      arg.inStrides_[SrcVectorDim],
-                                      SrcScalarPerVector) &&
-               IsScalarPerVectorValid(
-                   GetPaddedLength(arg.inLengths_[SrcVectorDim],
-                                   (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
-                   arg.inStrides_[SrcVectorDim],
-                   SrcScalarPerVector) &&
-               IsScalarPerVectorValid(arg.outLengths_[DstVectorDim],
-                                      arg.outStrides_[DstVectorDim],
-                                      DstScalarPerVector) &&
-               IsScalarPerVectorValid(
-                   GetPaddedLength(arg.outLengths_[DstVectorDim],
-                                   (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
-                   arg.inStrides_[DstVectorDim],
-                   DstScalarPerVector) &&
-               GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
-    };
-};
+template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
+struct DevicePermute : BaseOperator
+{
+    using Lengths = std::array<index_t, NumDim>;
+    using Strides = Lengths;
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const Lengths inLengths,
+                        const Strides inStrides,
+                        const Lengths outLengths,
+                        const Strides outStrides,
+                        const void* in_dev_buffer,
+                        void* out_dev_buffer,
+                        ElementwiseOperation elementwise_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <index_t NumDim,
+          typename InDataType,
+          typename OutDataType,
+          typename ElementwiseOperation,
+          typename DerivedDeviceOperator>
+struct DevicePermuteCRTP : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
+{
+    private:
+    using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
+
+    public:
+    // override methods inherited from 'BaseOperator'
+    bool IsSupportedArgument(const BaseArgument* arg) override final
+    {
+        const auto* const argument =
+            dynamic_cast<const typename DerivedDeviceOperator::Argument*>(arg);
+        if(!argument)
+        {
+            return false;
+        }
+
+        return DerivedDeviceOperator::IsSupportedArgument(*argument);
+    }
+
+    // override methods inherited from 'DevicePermute'
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const typename BaseType::Lengths inLengths,
+                        const typename BaseType::Strides inStrides,
+                        const typename BaseType::Lengths outLengths,
+                        const typename BaseType::Strides outStrides,
+                        const void* in_dev_buffer,
+                        void* out_dev_buffer,
+                        ElementwiseOperation elementwise_op) override final
+    {
+        return std::make_unique<typename DerivedDeviceOperator::Argument>(inLengths,
+                                                                          inStrides,
+                                                                          outLengths,
+                                                                          outStrides,
+                                                                          in_dev_buffer,
+                                                                          out_dev_buffer,
+                                                                          elementwise_op);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
+    {
+        return std::make_unique<typename DerivedDeviceOperator::Invoker>();
+    };
+
+    // generate other utility methods
+    template <typename... Args>
+    static auto MakeArgument(Args&&... args)
+    {
+        static_assert(std::is_constructible_v<typename DerivedDeviceOperator::Argument, Args...>);
+
+        return typename DerivedDeviceOperator::Argument{std::forward<Args>(args)...};
+    }
+
+    static auto MakeInvoker() noexcept(
+        std::is_nothrow_default_constructible_v<typename DerivedDeviceOperator::Invoker>)
+    {
+        static_assert(std::is_default_constructible_v<typename DerivedDeviceOperator::Invoker>);
+
+        return typename DerivedDeviceOperator::Invoker{};
+    }
+};
 
 } // namespace device
...
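The new device_permute.hpp thus splits the operator into an abstract DevicePermute interface and a DevicePermuteCRTP helper that generates the virtual boilerplate. Below is a minimal sketch of the contract a derived operator has to satisfy; 'MyPermute' is hypothetical, and the BaseArgument/BaseInvoker signatures are assumed from CK's device_base.hpp. The derived type supplies nested Argument/Invoker types and a static IsSupportedArgument(); the CRTP base derives the type-erased virtual overrides from them.

#include <array>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"

namespace ckdev = ck::tensor_operation::device;

template <ck::index_t NumDim, typename InDataType, typename OutDataType, typename Elementwise>
struct MyPermute : ckdev::DevicePermuteCRTP<NumDim,
                                            InDataType,
                                            OutDataType,
                                            Elementwise,
                                            MyPermute<NumDim, InDataType, OutDataType, Elementwise>>
{
    struct Argument : ckdev::BaseArgument
    {
        // must be constructible from exactly what MakeArgumentPointer() forwards
        Argument(std::array<ck::index_t, NumDim> /*inLengths*/,
                 std::array<ck::index_t, NumDim> /*inStrides*/,
                 std::array<ck::index_t, NumDim> /*outLengths*/,
                 std::array<ck::index_t, NumDim> /*outStrides*/,
                 const void* /*in_dev_buffer*/,
                 void* /*out_dev_buffer*/,
                 Elementwise /*elementwise_op*/)
        {
        }
    };

    // default-constructible, as MakeInvokerPointer() requires; this sketch
    // plugs in at the BaseInvoker level rather than through BaseInvokerCRTP
    struct Invoker : ckdev::BaseInvoker
    {
        float Run(const ckdev::BaseArgument*, const StreamConfig& = StreamConfig{}) override
        {
            return 0.f; // a real invoker would launch the permute kernel here
        }
    };

    static bool IsSupportedArgument(const Argument&) { return true; }
};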
ck/tensor_operation/gpu/device/device_permute_base.hpp (deleted; its classes move into device_permute.hpp with the 'Base' suffix dropped)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <cmath>
#include <memory>
#include <type_traits>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
struct DevicePermuteBase : BaseOperator
{
    using Lengths = std::array<index_t, NumDim>;
    using Strides = Lengths;

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const Lengths inLengths,
                        const Strides inStrides,
                        const Lengths outLengths,
                        const Strides outStrides,
                        const void* in_dev_buffer,
                        void* out_dev_buffer,
                        ElementwiseOperation elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t NumDim,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          typename DerivedDeviceOperator>
struct DevicePermuteBaseCRTP
    : DevicePermuteBase<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
    private:
    using BaseType = DevicePermuteBase<NumDim, InDataType, OutDataType, ElementwiseOperation>;

    public:
    // override methods inherited from 'BaseOperator'
    bool IsSupportedArgument(const BaseArgument* arg) override final
    {
        const auto* const argument =
            dynamic_cast<const typename DerivedDeviceOperator::Argument*>(arg);
        if(!argument)
        {
            return false;
        }

        return DerivedDeviceOperator::IsSupportedArgument(*argument);
    }

    // override methods inherited from 'DevicePermuteBase'
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const typename BaseType::Lengths inLengths,
                        const typename BaseType::Strides inStrides,
                        const typename BaseType::Lengths outLengths,
                        const typename BaseType::Strides outStrides,
                        const void* in_dev_buffer,
                        void* out_dev_buffer,
                        ElementwiseOperation elementwise_op) override final
    {
        return std::make_unique<typename DerivedDeviceOperator::Argument>(inLengths,
                                                                          inStrides,
                                                                          outLengths,
                                                                          outStrides,
                                                                          in_dev_buffer,
                                                                          out_dev_buffer,
                                                                          elementwise_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
    {
        return std::make_unique<typename DerivedDeviceOperator::Invoker>();
    };

    // generate other utility methods
    template <typename... Args>
    static auto MakeArgument(Args&&... args)
    {
        static_assert(std::is_constructible_v<typename DerivedDeviceOperator::Argument, Args...>);

        return typename DerivedDeviceOperator::Argument{std::forward<Args>(args)...};
    }

    static auto MakeInvoker() noexcept(
        std::is_nothrow_default_constructible_v<typename DerivedDeviceOperator::Invoker>)
    {
        static_assert(std::is_default_constructible_v<typename DerivedDeviceOperator::Invoker>);

        return typename DerivedDeviceOperator::Invoker{};
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp (new file; the former 'DevicePermute' implementation, renamed)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <memory>
#include <utility>

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Swap last 2 dimensions
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
//                                                   ^^^^^^^^^^^
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
//                                                                 ^^^^^^^^^^^
template <index_t NumDim,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          index_t BlockSize,
          index_t NPerBlock,
          index_t HPerBlock,
          index_t WPerBlock,
          index_t InBlockLdsExtraW,
          typename InBlockTransferThreadClusterLengths,
          typename InBlockTransferThreadClusterArrangeOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector>
struct DevicePermuteImpl
    : DevicePermuteCRTP<NumDim,
                        InDataType,
                        OutDataType,
                        ElementwiseOperation,
                        DevicePermuteImpl<NumDim,
                                          InDataType,
                                          OutDataType,
                                          ElementwiseOperation,
                                          BlockSize,
                                          NPerBlock,
                                          HPerBlock,
                                          WPerBlock,
                                          InBlockLdsExtraW,
                                          InBlockTransferThreadClusterLengths,
                                          InBlockTransferThreadClusterArrangeOrder,
                                          SrcVectorDim,
                                          DstVectorDim,
                                          SrcScalarPerVector,
                                          DstScalarPerVector>>
{
    static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
    static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
    static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
    static_assert(SrcVectorDim != DstVectorDim);

    template <index_t N = NumDim>
    static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
    {
        static_assert(1 <= N && N <= NumDim);

        return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
    }

    static auto MakeDescriptor_N_H_W(const std::array<index_t, NumDim>& lengths,
                                     const std::array<index_t, NumDim>& stride)
    {
        // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
        // d[NumDim-1]]
        const auto desc =
            make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));

        // merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
        // d[NumDim-1]]
        // => [N, H, W]
        const index_t H = *std::next(rbegin(lengths));
        const index_t W = *rbegin(lengths);
        const auto desc_n_h_w = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
                       make_pass_through_transform(H),
                       make_pass_through_transform(W)),
            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
                       Sequence<NumDim - 2>{},
                       Sequence<NumDim - 1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return PadTensorDescriptor(
            desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
    }

    using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
    using OutGridDesc = InGridDesc;

    using GridwisePermute = GridwisePermute<
        InGridDesc,
        OutGridDesc,
        InDataType,
        OutDataType,
        ElementwiseOperation,
        BlockSize,
        NPerBlock,
        HPerBlock,
        WPerBlock,
        InBlockLdsExtraW,
        InBlockTransferThreadClusterLengths,
        InBlockTransferThreadClusterArrangeOrder,
        SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
        DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
        SrcScalarPerVector,
        DstScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, NumDim> inLengths,
                 const std::array<index_t, NumDim> inStrides,
                 const std::array<index_t, NumDim> outLengths,
                 const std::array<index_t, NumDim> outStrides,
                 const void* in_dev_buffer,
                 void* out_dev_buffer,
                 ElementwiseOperation elementwise_op)
            : in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
              out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
              in_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
              out_grid_desc_(MakeDescriptor_N_H_W(outLengths, outStrides)),
              inLengths_(inLengths),
              inStrides_(inStrides),
              outLengths_(outLengths),
              outStrides_(outStrides),
              elementwise_op_(elementwise_op),
              block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
        {
        }

        const InDataType* in_dev_buffer_;
        OutDataType* out_dev_buffer_;
        InGridDesc in_grid_desc_;
        OutGridDesc out_grid_desc_;
        std::array<index_t, NumDim> inLengths_;
        std::array<index_t, NumDim> inStrides_;
        std::array<index_t, NumDim> outLengths_;
        std::array<index_t, NumDim> outStrides_;
        ElementwiseOperation elementwise_op_;
        typename GridwisePermute::DefaultBlock2TileMap block_2_tile_map_;
    };

    struct Invoker : BaseInvokerCRTP<Invoker, Argument>
    {
        static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);

            const auto kernel = kernel_nd_permute<GridwisePermute,
                                                  InGridDesc,
                                                  OutGridDesc,
                                                  InDataType,
                                                  OutDataType,
                                                  ElementwiseOperation,
                                                  typename GridwisePermute::DefaultBlock2TileMap>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(grid_size),
                                                        dim3(BlockSize),
                                                        0,
                                                        arg.in_grid_desc_,
                                                        arg.out_grid_desc_,
                                                        arg.in_dev_buffer_,
                                                        arg.out_dev_buffer_,
                                                        arg.elementwise_op_,
                                                        arg.block_2_tile_map_);
            return elapsed_time;
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
            return math::integer_divide_ceil(length, tile_length) * tile_length;
        };

        constexpr auto IsScalarPerVectorValid =
            [](index_t length, index_t stride, index_t scalar_per_vector) {
                if(stride == 1 && length % scalar_per_vector == 0)
                {
                    return true;
                }
                else if(stride != 1 && scalar_per_vector == 1)
                {
                    return true;
                }

                return false;
            };

        return IsScalarPerVectorValid(arg.inLengths_[SrcVectorDim],
                                      arg.inStrides_[SrcVectorDim],
                                      SrcScalarPerVector) &&
               IsScalarPerVectorValid(
                   GetPaddedLength(arg.inLengths_[SrcVectorDim],
                                   (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
                   arg.inStrides_[SrcVectorDim],
                   SrcScalarPerVector) &&
               IsScalarPerVectorValid(arg.outLengths_[DstVectorDim],
                                      arg.outStrides_[DstVectorDim],
                                      DstScalarPerVector) &&
               IsScalarPerVectorValid(
                   GetPaddedLength(arg.outLengths_[DstVectorDim],
                                   (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
                   arg.inStrides_[DstVectorDim],
                   DstScalarPerVector) &&
               GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
    };
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
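Putting the pieces together, a host-side call goes through the polymorphic base API kept in device_permute.hpp. A minimal usage sketch follows: the shapes and strides are illustrative, 'DevicePermuteInstance' and 'PassThrough' refer to the hypothetical aliases sketched earlier, and StreamConfig is assumed to be CK's {stream, time_kernel} aggregate; only the MakeArgumentPointer -> IsSupportedArgument -> MakeInvokerPointer -> Run flow itself comes from this commit.

#include <array>

float RunPermuteExample(const void* in_dev_buffer, void* out_dev_buffer)
{
    using ck::index_t;

    // packed [N, H, W] = [8, 16, 32] input; the output swaps the last two dims
    const std::array<index_t, 3> inLengths{8, 16, 32};
    const std::array<index_t, 3> inStrides{16 * 32, 32, 1};
    const std::array<index_t, 3> outLengths{8, 32, 16};
    const std::array<index_t, 3> outStrides{32 * 16, 16, 1};

    DevicePermuteInstance permute;
    auto argument = permute.MakeArgumentPointer(inLengths,
                                                inStrides,
                                                outLengths,
                                                outStrides,
                                                in_dev_buffer,
                                                out_dev_buffer,
                                                PassThrough{});

    if(!permute.IsSupportedArgument(argument.get()))
    {
        return -1.f; // rejected, e.g. the vector-access checks above failed
    }

    auto invoker = permute.MakeInvokerPointer();

    // StreamConfig{stream, time_kernel}: time the kernel on the default stream
    return invoker->Run(argument.get(), StreamConfig{nullptr, true}); // elapsed ms
}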