Commit df29a7e0 authored by Chao Liu's avatar Chao Liu
Browse files

enabling vector load on merged dim

parent 37b82b7e
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
namespace ck { namespace ck {
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor // slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst // memory layout (ordering of dimensions) can be different between src and dst.
// For now, only support SubLengths[...] == 1 on a merged dimension // on a merged dimension that contains multiple original dimensions,
// its sub-length needs to evenly divide the length of the last original dimension
// so each thread is effectively reading a normal (not merged) tensor
template <index_t BlockSize, template <index_t BlockSize,
class Float, class Float,
class SrcDesc, class SrcDesc,
...@@ -75,7 +77,7 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -75,7 +77,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
// thread cluster // thread cluster
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed( constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
DataClusterLengths{}.ReorderGivenNew2Old(ThreadClusterArrangeOrder{})); DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
// BlockSize // BlockSize
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize"); static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
...@@ -91,13 +93,23 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -91,13 +93,23 @@ struct BlockwiseGenericTensorSliceCopy_v1
"wrong! cannot evenly divide sliced tensor into cluster"); "wrong! cannot evenly divide sliced tensor into cluster");
}); });
// for now, only support SubLengths == 1 on a merged dimension that contains // on a merged dimension that contains
// multiple original dimensions // its sub-length needs to evenly divide the length of the last original dimension,
// so each thread is effectively reading a normal (not merged) tensor
static_for<0, nDim, 1>{}([&](auto IDim) { static_for<0, nDim, 1>{}([&](auto IDim) {
static_assert(SubLengths::Get(IDim) == 1 || constexpr auto sub_length = SubLengths::Get(IDim);
(!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
!DstDesc::ContainMultipleOriginalDimensions(IDim)), constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
"wrong! only surpport Sub-Length == 1 on a merged dimension"); static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
sub_length ==
0,
"wrong!");
constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
sub_length ==
0,
"wrong!");
}); });
// calculate mThreadSrcOffset, mThreadDstOffset // calculate mThreadSrcOffset, mThreadDstOffset
...@@ -118,28 +130,24 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -118,28 +130,24 @@ struct BlockwiseGenericTensorSliceCopy_v1
// partial offset on each dimension // partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto IDim) { static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr index_t idim = IDim;
constexpr auto src_partial_original_dims = constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim); SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc = constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims); SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
mThreadSrcPartialOffsets(idim) = src_partial_original_desc.GetOffsetFromMultiIndex( mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims)); extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
}); });
static_for<0, nDim, 1>{}([&](auto IDim) { static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr index_t idim = IDim;
constexpr auto dst_partial_original_dims = constexpr auto dst_partial_original_dims =
DstDesc::GetContainedOriginalDimensions(IDim); DstDesc::GetContainedOriginalDimensions(IDim);
constexpr auto dst_partial_original_desc = constexpr auto dst_partial_original_desc =
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims); DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
mThreadDstPartialOffsets(idim) = dst_partial_original_desc.GetOffsetFromMultiIndex( mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims)); extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
}); });
...@@ -173,10 +181,8 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -173,10 +181,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr auto thread_tensor_desc = constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){}; static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
constexpr auto src_thread_data_multi_id_begin = constexpr auto src_thread_data_multi_id_begin =
repeat_multi_id * data_per_cluster_per_dims; repeat_multi_id * data_per_cluster_per_dims;
...@@ -189,14 +195,13 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -189,14 +195,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr index_t clipboard_offset = constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#else #else
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){}); ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims; const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
const index_t src_offset = const index_t src_offset =
SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin); SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
const index_t clipboard_offset = const index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
...@@ -233,10 +238,8 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -233,10 +238,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
constexpr auto thread_tensor_desc = constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 #if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){}; static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
constexpr auto clipboard_data_multi_id_begin = constexpr auto clipboard_data_multi_id_begin =
repeat_multi_id * thread_sub_tensor_lengths; repeat_multi_id * thread_sub_tensor_lengths;
...@@ -246,10 +249,9 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -246,10 +249,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
constexpr index_t dst_offset = constexpr index_t dst_offset =
DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin); DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#else #else
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){}); ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id) {
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths; const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims; const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
...@@ -257,7 +259,7 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -257,7 +259,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
const index_t clipboard_offset = const index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin); const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#endif #endif
// By position the origin of the per-thread window at the point, where multi-index // By position the origin of the per-thread window at the point, where multi-index
......
...@@ -59,7 +59,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -59,7 +59,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t B = (N * Ho * Wo) / (N1 * N2); constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1 #if 0
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16; constexpr index_t BPerBlock = 16;
...@@ -91,6 +91,40 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -91,6 +91,40 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 2, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 2;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif #endif
......
...@@ -454,7 +454,7 @@ int main(int argc, char* argv[]) ...@@ -454,7 +454,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 3x3 filter, 28x28 image // 3x3 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -464,12 +464,12 @@ int main(int argc, char* argv[]) ...@@ -464,12 +464,12 @@ int main(int argc, char* argv[])
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 512;
...@@ -479,6 +479,9 @@ int main(int argc, char* argv[]) ...@@ -479,6 +479,9 @@ int main(int argc, char* argv[])
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment