Commit fdcfae3a authored by Chao Liu

reimplement threadwise copy

parent adc10088
@@ -157,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(in_c_h_w_n_global_desc),
             decltype(in_c_h_w_n_block_desc),
             NormalTensorCoordinate<decltype(in_c_h_w_n_global_desc)>,
...
@@ -176,7 +176,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
 #else
         auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(in_e_n1_b_n2_global_merged_desc),
             decltype(in_e_n1_b_n2_block_desc),
             MergedTensorCoordinate<decltype(in_e_n1_b_n2_global_merged_desc)>,
@@ -219,7 +218,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
 #else
         auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(wei_e_k_global_desc),
             decltype(wei_e_k_block_desc),
             NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
@@ -373,7 +371,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
                             Number<1>{});
 #else
         ThreadwiseGenericTensorSliceCopy_v2<
-            Float,
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
             NormalTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc)>,
...
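
The change repeated across these gridwise kernels is that the element type (Float) is no longer a class template argument of BlockwiseGenericTensorSliceCopy_v2 / ThreadwiseGenericTensorSliceCopy_v2; the data type is now supplied only when Run (or the load/store helpers) is called. A minimal self-contained sketch of that pattern, using toy names rather than the actual ck classes:

#include <cstddef>

// Toy "descriptor" carrying only an element count (stand-in for the ck tensor descriptors).
template <std::size_t NElems>
struct ToyDesc
{
    static constexpr std::size_t GetElementSize() { return NElems; }
};

// The copy object is parameterized only by descriptors; the element type is a
// template parameter of Run(), deduced from the pointer arguments at the call site.
template <int BlockSize, class SrcDesc, class DstDesc>
struct BlockwiseCopySketch
{
    template <class TData>
    void Run(const TData* p_src, TData* p_dst) const
    {
        for(std::size_t i = 0; i < SrcDesc::GetElementSize(); ++i)
            p_dst[i] = p_src[i];
    }
};

int main()
{
    float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float dst[8] = {};

    // TData = float is deduced; the same copy type could also move double, half, etc.
    BlockwiseCopySketch<64, ToyDesc<8>, ToyDesc<8>>{}.Run(src, dst);

    return 0;
}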
@@ -131,7 +131,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // this copy operator already has blockwise offset built-in
         auto blockwise_in_copy =
             BlockwiseGenericTensorSliceCopy_v2<BlockSize,
-                                               Float,
                                                decltype(in_e_b_global_desc),
                                                decltype(in_e_b_block_desc),
                                                MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
@@ -158,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(wei_e_k_global_desc),
             decltype(wei_e_k_block_desc),
             NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
@@ -288,7 +286,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
             Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
         auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
-            Float,
             decltype(out_k0_k1_b_thread_desc),
             decltype(out_k0_k1_b_global_desc),
             NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
...
@@ -131,7 +131,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // this copy operator already has blockwise offset built-in
         auto blockwise_in_copy =
             BlockwiseGenericTensorSliceCopy_v2<BlockSize,
-                                               Float,
                                                decltype(in_e_b_global_desc),
                                                decltype(in_e_b_block_desc),
                                                MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
@@ -158,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(wei_e_k_global_desc),
             decltype(wei_e_k_block_desc),
             NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
@@ -352,7 +350,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
             Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
         auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
-            Float,
             decltype(out_k0_k1_b_thread_desc),
             decltype(out_k0_k1_b_global_desc),
             NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
...
@@ -65,11 +65,21 @@ struct ConstantMergedTensorDescriptor
         static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                       "wrong! stride of a merged dimension is undefined");

-        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
+        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();

         return OriginalTensorDesc::GetStride(Number<idim_original>{});
     }

+    // this is a hack to return the stride of the last original dimension of a merged dimension
+    // TODO: refactor this once the concept of "dimension" is used
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
+    {
+        constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
+
+        return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
+    }
+
     __host__ __device__ static constexpr auto GetLengths()
     {
         return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
...
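
What GetLastOriginalDimensionStride is meant to report can be worked out by hand for a packed NCHW tensor (an illustrative example, not the ck descriptor API): if H and W are merged into a single dimension, the merged dimension has no single stride, but its last original dimension (W) does, and that is the value the vectorized-access checks introduced below rely on.

#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    // Packed NCHW tensor: each stride is the product of the lengths to its right.
    constexpr std::array<std::size_t, 4> lengths = {2, 3, 4, 5}; // N, C, H, W
    constexpr std::array<std::size_t, 4> strides = {
        lengths[1] * lengths[2] * lengths[3], lengths[2] * lengths[3], lengths[3], 1};

    // Suppose H and W (original dims 2 and 3) are merged into one dimension.
    // The stride of the *last* original dimension, W, is 1, so vectorized
    // access along the merged dimension touches contiguous memory.
    constexpr std::size_t last_original_dim = 3; // W
    static_assert(strides[last_original_dim] == 1, "W is the contiguous dimension here");

    std::cout << "stride of last original dim of merged (H, W): " << strides[last_original_dim]
              << '\n';
    return 0;
}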
@@ -13,11 +13,13 @@
 namespace ck {

-// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
+// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
 // memory layout (ordering of dimensions) can be different between src and dst.
-// on a merged dimension that constains multiple original dimensions,
-// its sub-length need to evenly divide the length of the last original dimension
-// so each thread is effectively reading a normal (not merged) tensor
+// This function assumes each thread is reading and writing a normal (not merged) tensor,
+// to simplify index calculations. To satisfy this assumption, the user needs to make sure
+// that, on a merged dimension that contains multiple original dimensions, the length of
+// the last original dimension is evenly divisible by its sub-length. Also, the
+// repeat-length on the merged dimension needs to be 1.
 template <index_t BlockSize,
           class Float,
           class SrcDesc,
@@ -88,30 +90,55 @@ struct BlockwiseGenericTensorSliceCopy_v1
         constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};

         static_for<0, nDim, 1>{}([&](auto IDim) {
+            static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
+                          "wrong! cannot evenly divide sliced tensor into sub-tensor");
+
             static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
                           "wrong! cannot evenly divide sliced tensor into cluster");
         });

-        // on a merged dimension that constains multiple original dimensions,
-        // its sub-length need to evenly divide the length of the last original dimension,
-        // so each thread is effectively reading a normal (not merged) tensor
-        static_for<0, nDim, 1>{}([&](auto IDim) {
-            constexpr auto sub_length = SubLengths::Get(IDim);
-
-            constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
-
-            constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
+        constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
+
+        // additional check for merged dimension
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            // src
+            static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
+                constexpr auto IDim = decltype(IDim_){};
+
+                // on a merged dimension that contains multiple original dimensions,
+                // the length of the last original dimension needs to be evenly divisible
+                // by its sub-length,
+                // so each thread is effectively reading a normal (not merged) tensor
+                constexpr auto idim_last_original_src =
+                    SrcDesc::GetContainedOriginalDimensions(IDim).Back();
+
+                static_assert(
+                    SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
+                            SubLengths::Get(IDim) ==
+                        0,
+                    "wrong!");
+
+                // merged dimension should have repeat_lengths = 1
+                static_assert(repeat_lengths[IDim] == 1,
+                              "wrong! repeat_lengths should be 1 on merged dimension");
+            });
+
+            // dst
+            static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
+                constexpr auto IDim = decltype(IDim_){};
+
+                // on a merged dimension that contains multiple original dimensions,
+                // the length of the last original dimension needs to be evenly divisible
+                // by its sub-length,
+                // so each thread is effectively reading a normal (not merged) tensor
+                constexpr auto idim_last_original_dst =
+                    DstDesc::GetContainedOriginalDimensions(IDim).Back();
+
+                static_assert(
+                    DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
+                            SubLengths::Get(IDim) ==
+                        0,
+                    "wrong!");
+
+                // merged dimension should have repeat_lengths = 1
+                static_assert(repeat_lengths[IDim] == 1,
+                              "wrong! repeat_lengths should be 1 on merged dimension");
+            });
         });

         // calculate mThreadSrcOffset, mThreadDstOffset
@@ -376,7 +403,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
 };

 template <index_t BlockSize,
-          class TData,
           class SrcDesc,
           class DstDesc,
           class SrcCoordinate,
@@ -428,16 +454,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
         return RegisterBufferDesc::GetElementSpace();
     }

+    template <class TData>
     __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
     {
         mThreadwiseLoad.Run(p_src, p_buffer);
     }

+    template <class TData>
     __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
     {
         mThreadwiseStore.Run(p_buffer, p_dst);
     }

+    template <class TData>
     __device__ void Run(const TData* p_src, TData* p_dst) const
     {
         TData p_buffer[GetRegisterBufferSize()];
@@ -466,16 +495,14 @@ struct BlockwiseGenericTensorSliceCopy_v2
     using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));

     using ThreadwiseLoad =
-        ThreadwiseGenericTensorSliceCopy_v2<TData,
-                                            SrcDesc,
+        ThreadwiseGenericTensorSliceCopy_v2<SrcDesc,
                                             RegisterBufferDesc,
                                             SrcCoordinate,
                                             NormalTensorCoordinate<RegisterBufferDesc>,
                                             SubLengths>;

     using ThreadwiseStore =
-        ThreadwiseGenericTensorSliceCopy_v2<TData,
-                                            RegisterBufferDesc,
+        ThreadwiseGenericTensorSliceCopy_v2<RegisterBufferDesc,
                                             DstDesc,
                                             NormalTensorCoordinate<RegisterBufferDesc>,
                                             DstCoordinate,
...
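
The new static_asserts in BlockwiseGenericTensorSliceCopy_v1 encode the constraints stated in the comment at the top of the file: the per-thread sub-tensor must tile the slice, the thread cluster must tile the slice, the per-thread read must not straddle the last original dimension of a merged dimension, and a merged dimension must not be repeated. A plain C++ restatement of that arithmetic, with illustrative numbers rather than a real kernel configuration:

#include <cstddef>

int main()
{
    // Hypothetical configuration for one merged dimension IDim.
    constexpr std::size_t slice_length          = 32; // SliceLengths[IDim]
    constexpr std::size_t sub_length            = 4;  // SubLengths[IDim], work per thread
    constexpr std::size_t thread_cluster_length = 8;  // ThreadClusterLengths[IDim]
    constexpr std::size_t last_original_length  = 28; // length of the merged dim's last original dim

    constexpr std::size_t data_per_cluster = sub_length * thread_cluster_length; // 32
    constexpr std::size_t repeat_length    = slice_length / data_per_cluster;    // 1

    // The checks added in this commit, restated:
    static_assert(slice_length % sub_length == 0, "sub-tensor must tile the slice");
    static_assert(slice_length % data_per_cluster == 0, "thread cluster must tile the slice");
    static_assert(last_original_length % sub_length == 0,
                  "per-thread read must fit inside the last original dimension");
    static_assert(repeat_length == 1, "no repeat allowed on a merged dimension");

    return 0;
}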
@@ -106,8 +106,107 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
 #endif
 }

-template <class TData,
-          class SrcDesc,
+#if 0
+template <class SrcDesc,
+          class DstDesc,
+          class SliceLengths,
+          class SrcDimAccessOrder,
+          class DstDimAccessOrder,
+          index_t SrcVectorAccessDim,
+          index_t DstVectorAccessDim,
+          index_t SrcDataPerAccess,
+          index_t DstDataPerAccess>
+struct ThreadwiseGenericTensorSliceCopy_v1
+{
+    static constexpr index_t nDim = SliceLengths::GetNumOfDimension();
+
+    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_slice_origin,
+                                                             Array<index_t, nDim> dst_slice_origin)
+        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
+    {
+        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
+                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
+                          nDim == SrcDimAccessOrder::GetSize() &&
+                          nDim == DstDimAccessOrder::GetSize(),
+                      "wrong! # of dimensions not the same");
+
+        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
+                          is_valid_sequence_map<DstDimAccessOrder>::value,
+                      "wrong! map is not valid");
+
+        static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
+                          SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
+                      "wrong! cannot evenly divide");
+
+        // check vectorized memory access
+        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
+        constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
+
+        static_if<!SrcDesc::ContainMultipleOriginalDimensions(
+            src_vector_access_dim)>{}([&](auto fwd) {
+            static_assert(
+                (fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
+                "wrong! vectorized access is allowed only if stride == 1");
+        }).Else([&](auto fwd) {
+            static_assert((SrcDesc::GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
+                           SrcDataPerAccess == 1),
+                          "wrong! vectorized access is allowed only if stride == 1");
+        });
+
+        static_if<!DstDesc::ContainMultipleOriginalDimensions(
+            dst_vector_access_dim)>{}([&](auto fwd) {
+            static_assert(
+                (fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
+                "wrong! vectorized access is allowed only if stride == 1");
+        }).Else([&](auto fwd) {
+            static_assert((DstDesc::GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
+                           DstDataPerAccess == 1),
+                          "wrong! vectorized access is allowed only if stride == 1");
+        });
+    }
+
+    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1()
+        : ThreadwiseGenericTensorSliceCopy_v1(make_zero_array<index_t, nDim>(),
+                                              make_zero_array<index_t, nDim>())
+    {
+    }
+
+    __device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
+    {
+        mSrcSliceOrigin = src_slice_origin;
+    }
+
+    __device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
+    {
+        mDstSliceOrigin = dst_slice_origin;
+    }
+
+    template <class TData>
+    __device__ void Run(const TData* p_src, TData* p_dst) const
+    {
+        constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
+
+        TData p_buffer[buffer_desc.GetElementSpace()];
+
+        // copy data from src into buffer
+        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
+
+        constexpr auto src_access_lengths = SliceLengths::Modify(
+            src_vector_access_dim, SliceLengths::Get(src_vector_access_dim) / SrcDataPerAccess);
+
+        constexpr auto src_access_lengths_in_src_access_order =
+            src_access_lengths.ReorderGivenNew2Old(SrcDimAccessOrder{});
+
+        // TODO: the actual copy loop is not implemented yet
+        static_ford<decltype(src_access_lengths_in_src_access_order)>{}([&](auto src_access_id) {});
+    }
+
+    private:
+    Array<index_t, nDim> mSrcSliceOrigin;
+    Array<index_t, nDim> mDstSliceOrigin;
+};
+#endif
+
+template <class SrcDesc,
           class DstDesc,
           class SrcCoordinate,
           class DstCoordinate,
@@ -116,18 +215,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2
 {
     static constexpr index_t nDim = SrcDesc::GetNumOfDimension();

-    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
-        : mSrcSliceOrigin(make_zero_array<index_t, nDim>()),
-          mDstSliceOrigin(make_zero_array<index_t, nDim>())
-    {
-    }
-
     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2(SrcCoordinate src_slice_origin,
                                                              DstCoordinate dst_slice_origin)
         : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
     {
     }

+    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
+        : ThreadwiseGenericTensorSliceCopy_v2(make_zero_array<index_t, nDim>(),
+                                              make_zero_array<index_t, nDim>())
+    {
+    }
+
     __device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
     {
         mSrcSliceOrigin = src_slice_origin;
@@ -148,6 +247,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
         }
     };

+    template <class TData>
     __device__ void Run(const TData* p_src, TData* p_dst) const
     {
         constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
@@ -216,6 +316,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
         });
     }

+    // T can be Sequence or Array
     template <class T, bool PositiveDirection>
     __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
     {
@@ -232,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
         }).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
     }

-    // private:
+    private:
     SrcCoordinate mSrcSliceOrigin;
     DstCoordinate mDstSliceOrigin;
 };
...
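
The disabled ThreadwiseGenericTensorSliceCopy_v1 above reorders the per-dimension access lengths by SrcDimAccessOrder (via ReorderGivenNew2Old) before iterating over them with static_ford. The sketch below restates that traversal with ordinary runtime loops for a toy 2-D slice; the helper names are mine and the ck compile-time machinery is deliberately left out:

#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    // Slice lengths in the tensor's own dimension order (dim 0, dim 1).
    constexpr std::array<std::size_t, 2> slice_lengths = {2, 3};

    // Access order: visit dim 1 in the outer loop, dim 0 in the inner loop
    // (runtime analogue of SrcDimAccessOrder = Sequence<1, 0>).
    constexpr std::array<std::size_t, 2> access_order = {1, 0};

    // Lengths reordered into access order (analogue of ReorderGivenNew2Old).
    const std::array<std::size_t, 2> ordered = {slice_lengths[access_order[0]],
                                                slice_lengths[access_order[1]]};

    // Nested loops over the reordered lengths; map loop indices back to tensor indices.
    for(std::size_t a0 = 0; a0 < ordered[0]; ++a0)
        for(std::size_t a1 = 0; a1 < ordered[1]; ++a1)
        {
            std::array<std::size_t, 2> idx{};
            idx[access_order[0]] = a0;
            idx[access_order[1]] = a1;
            std::cout << "visit (" << idx[0] << ", " << idx[1] << ")\n";
        }

    return 0;
}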
@@ -6,9 +6,12 @@
 namespace ck {

-template <class Seq>
+template <class>
 struct is_valid_sequence_map;

+template <class>
+struct sequence_map_inverse;
+
 template <index_t... Is>
 struct Sequence
 {
@@ -34,6 +37,8 @@ struct Sequence
         return Number<GetImpl(Number<I>{})>{};
     }

+    __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }
+
     template <index_t I>
     __host__ __device__ constexpr auto operator[](Number<I>) const
     {
@@ -54,6 +59,18 @@ struct Sequence
         return Sequence<Type::Get(Number<IRs>{})...>{};
     }

+    // MapOld2New is Sequence<...>
+    template <class MapOld2New>
+    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
+    {
+        static_assert(MapOld2New::GetSize() == GetSize(),
+                      "wrong! reorder map should have the same size as the Sequence to be reordered");
+
+        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
+
+        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
+    }
+
     __host__ __device__ static constexpr auto Reverse();

     __host__ __device__ static constexpr auto Front()
@@ -253,6 +270,7 @@ struct sequence_reverse<Sequence<I0, I1>>
 template <class Seq>
 struct is_valid_sequence_map
 {
+    // not implemented yet, always return true
     static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};

     // TODO: add proper check for is_valid, something like:
@@ -261,6 +279,33 @@ struct is_valid_sequence_map
     //                              typename sequence_sort<Seq>::SortedSeqType>{};
 };

+template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
+struct sequence_map_inverse_impl
+{
+    private:
+    static constexpr auto new_y2x = WorkingY2X::Modify(X2Y{}[XBegin], XBegin);
+
+    public:
+    using type =
+        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
+};
+
+template <class X2Y, class WorkingY2X, index_t XBegin>
+struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
+{
+    using type = WorkingY2X;
+};
+
+template <class X2Y>
+struct sequence_map_inverse
+{
+    using type =
+        typename sequence_map_inverse_impl<X2Y,
+                                           typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
+                                           0,
+                                           X2Y::GetSize()>::type;
+};
+
 template <index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
 {
...
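
ReorderGivenOld2New is implemented by inverting the permutation map with sequence_map_inverse and reusing ReorderGivenNew2Old. The identity it relies on can be checked with plain arrays; this is illustrative code with my own helper names, not the ck Sequence API:

#include <array>
#include <cassert>
#include <cstddef>

// Invert a permutation: if x2y[x] == y, then the inverse satisfies y2x[y] == x.
template <std::size_t N>
std::array<std::size_t, N> invert(const std::array<std::size_t, N>& x2y)
{
    std::array<std::size_t, N> y2x{};
    for(std::size_t x = 0; x < N; ++x)
        y2x[x2y[x]] = x;
    return y2x;
}

// reorder_new2old: out[i] = seq[map[i]] (map says which old element lands at new position i).
template <std::size_t N>
std::array<int, N> reorder_new2old(const std::array<int, N>& seq,
                                   const std::array<std::size_t, N>& map)
{
    std::array<int, N> out{};
    for(std::size_t i = 0; i < N; ++i)
        out[i] = seq[map[i]];
    return out;
}

// reorder_old2new: out[map[i]] = seq[i] (map says where old element i goes).
template <std::size_t N>
std::array<int, N> reorder_old2new(const std::array<int, N>& seq,
                                   const std::array<std::size_t, N>& map)
{
    std::array<int, N> out{};
    for(std::size_t i = 0; i < N; ++i)
        out[map[i]] = seq[i];
    return out;
}

int main()
{
    const std::array<int, 3> seq             = {10, 20, 30};
    const std::array<std::size_t, 3> old2new = {2, 0, 1}; // element 0 -> pos 2, 1 -> 0, 2 -> 1

    // ReorderGivenOld2New(map) == ReorderGivenNew2Old(inverse(map))
    assert(reorder_old2new(seq, old2new) == reorder_new2old(seq, invert(old2new)));

    return 0;
}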
@@ -132,7 +132,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

     constexpr auto gridwise_conv =
-#if 1
+#if 0
         GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
 #else
         GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
 #elif 0
     device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
+#elif 1
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
...