"vscode:/vscode.git/clone" did not exist on "5be3c06485003425e0a6892fac4fc33157d47ab3"
Commit f8510368 authored by Po-Yen, Chen
Browse files

Use data type directly as template argument

parent 62d7361f
......@@ -99,9 +99,6 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
{
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
using InDataTypePointer = const InDataType*;
using OutDataTypePointer = OutDataType*;
template <index_t N = NumDim>
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
{
......@@ -142,8 +139,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
using GridwisePermute = GridwisePermute<InGridDesc,
OutGridDesc,
InDataTypePointer,
OutDataTypePointer,
InDataType,
OutDataType,
ElementwiseOperation,
BlockSize,
NPerBlock,
......@@ -162,8 +159,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op)
: in_dev_buffer_(static_cast<InDataTypePointer>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataTypePointer>(out_dev_buffer)),
: in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
in_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
out_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
inLengths_(inLengths),
......@@ -175,8 +172,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
{
}
InDataTypePointer in_dev_buffer_;
OutDataTypePointer out_dev_buffer_;
const InDataType* in_dev_buffer_;
OutDataType* out_dev_buffer_;
InGridDesc in_grid_desc_;
OutGridDesc out_grid_desc_;
......@@ -199,8 +196,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
const auto kernel = kernel_nd_permute<GridwisePermute,
InGridDesc,
OutGridDesc,
InDataTypePointer,
OutDataTypePointer,
InDataType,
OutDataType,
ElementwiseOperation,
typename GridwisePermute::DefaultBlock2TileMap>;
......
......@@ -68,14 +68,14 @@ struct Block2TileMap
template <typename GridwisePermute,
typename InGridDesc,
typename OutGridDesc,
typename InDataTypePointer,
typename OutDataTypePointer,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc,
const InDataTypePointer p_in_global,
const OutDataTypePointer p_out_global,
const InDataType* p_in_global,
OutDataType* p_out_global,
const ElementwiseOperation elementwise_op,
const Block2TileMap block_2_tile_map)
{
......@@ -92,8 +92,8 @@ __global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
template <typename InGridDesc,
typename OutGridDesc,
typename InDataTypePointer,
typename OutDataTypePointer,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
index_t BlockSize,
index_t NPerBlock,
......@@ -129,8 +129,6 @@ struct GridwisePermute
{
constexpr auto in_block_desc = GetInBlockDesc();
using InDataType = remove_cv_t<remove_pointer_t<InDataTypePointer>>;
return in_block_desc.GetElementSpaceSize() * sizeof(InDataType);
}
......@@ -142,15 +140,12 @@ struct GridwisePermute
template <typename Block2TileMap>
__device__ static void Run(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc,
const InDataTypePointer p_in_global,
const OutDataTypePointer p_out_global,
const InDataType* p_in_global,
OutDataType* p_out_global,
void* __restrict__ p_shared,
const ElementwiseOperation elementwise_op,
const Block2TileMap& block_2_tile_map)
{
using InDataType = remove_cv_t<remove_pointer_t<InDataTypePointer>>;
using OutDataType = remove_cv_t<remove_pointer_t<OutDataTypePointer>>;
auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment