Commit f8510368 authored by Po-Yen, Chen
Browse files

Use data type directly as template argument

parent 62d7361f
...@@ -99,9 +99,6 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -99,9 +99,6 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
{ {
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor"); static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
using InDataTypePointer = const InDataType*;
using OutDataTypePointer = OutDataType*;
template <index_t N = NumDim> template <index_t N = NumDim>
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array) static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
{ {
...@@ -142,8 +139,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -142,8 +139,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
using GridwisePermute = GridwisePermute<InGridDesc, using GridwisePermute = GridwisePermute<InGridDesc,
OutGridDesc, OutGridDesc,
InDataTypePointer, InDataType,
OutDataTypePointer, OutDataType,
ElementwiseOperation, ElementwiseOperation,
BlockSize, BlockSize,
NPerBlock, NPerBlock,
...@@ -162,8 +159,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -162,8 +159,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
ElementwiseOperation elementwise_op) ElementwiseOperation elementwise_op)
: in_dev_buffer_(static_cast<InDataTypePointer>(in_dev_buffer)), : in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataTypePointer>(out_dev_buffer)), out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
in_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)), in_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
out_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)), out_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)),
inLengths_(inLengths), inLengths_(inLengths),
...@@ -175,8 +172,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -175,8 +172,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
{ {
} }
InDataTypePointer in_dev_buffer_; const InDataType* in_dev_buffer_;
OutDataTypePointer out_dev_buffer_; OutDataType* out_dev_buffer_;
InGridDesc in_grid_desc_; InGridDesc in_grid_desc_;
OutGridDesc out_grid_desc_; OutGridDesc out_grid_desc_;
...@@ -199,8 +196,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType, ...@@ -199,8 +196,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
const auto kernel = kernel_nd_permute<GridwisePermute, const auto kernel = kernel_nd_permute<GridwisePermute,
InGridDesc, InGridDesc,
OutGridDesc, OutGridDesc,
InDataTypePointer, InDataType,
OutDataTypePointer, OutDataType,
ElementwiseOperation, ElementwiseOperation,
typename GridwisePermute::DefaultBlock2TileMap>; typename GridwisePermute::DefaultBlock2TileMap>;
......
...@@ -68,14 +68,14 @@ struct Block2TileMap ...@@ -68,14 +68,14 @@ struct Block2TileMap
template <typename GridwisePermute, template <typename GridwisePermute,
typename InGridDesc, typename InGridDesc,
typename OutGridDesc, typename OutGridDesc,
typename InDataTypePointer, typename InDataType,
typename OutDataTypePointer, typename OutDataType,
typename ElementwiseOperation, typename ElementwiseOperation,
typename Block2TileMap> typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc, __global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc, const OutGridDesc out_grid_desc,
const InDataTypePointer p_in_global, const InDataType* p_in_global,
const OutDataTypePointer p_out_global, OutDataType* p_out_global,
const ElementwiseOperation elementwise_op, const ElementwiseOperation elementwise_op,
const Block2TileMap block_2_tile_map) const Block2TileMap block_2_tile_map)
{ {
...@@ -92,8 +92,8 @@ __global__ void kernel_nd_permute(const InGridDesc in_grid_desc, ...@@ -92,8 +92,8 @@ __global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
template <typename InGridDesc, template <typename InGridDesc,
typename OutGridDesc, typename OutGridDesc,
typename InDataTypePointer, typename InDataType,
typename OutDataTypePointer, typename OutDataType,
typename ElementwiseOperation, typename ElementwiseOperation,
index_t BlockSize, index_t BlockSize,
index_t NPerBlock, index_t NPerBlock,
...@@ -129,8 +129,6 @@ struct GridwisePermute ...@@ -129,8 +129,6 @@ struct GridwisePermute
{ {
constexpr auto in_block_desc = GetInBlockDesc(); constexpr auto in_block_desc = GetInBlockDesc();
using InDataType = remove_cv_t<remove_pointer_t<InDataTypePointer>>;
return in_block_desc.GetElementSpaceSize() * sizeof(InDataType); return in_block_desc.GetElementSpaceSize() * sizeof(InDataType);
} }
...@@ -142,15 +140,12 @@ struct GridwisePermute ...@@ -142,15 +140,12 @@ struct GridwisePermute
template <typename Block2TileMap> template <typename Block2TileMap>
__device__ static void Run(const InGridDesc in_grid_desc, __device__ static void Run(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc, const OutGridDesc out_grid_desc,
const InDataTypePointer p_in_global, const InDataType* p_in_global,
const OutDataTypePointer p_out_global, OutDataType* p_out_global,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const ElementwiseOperation elementwise_op, const ElementwiseOperation elementwise_op,
const Block2TileMap& block_2_tile_map) const Block2TileMap& block_2_tile_map)
{ {
using InDataType = remove_cv_t<remove_pointer_t<InDataTypePointer>>;
using OutDataType = remove_cv_t<remove_pointer_t<OutDataTypePointer>>;
auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize()); p_in_global, in_grid_desc.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment