Commit 88b77181 authored by Chao Liu

rename files, added header guard, added namespace

parent 05e04665
-#pragma once
+#ifndef CK_BLOCKWISE_4D_TENSOR_OP_HPP
+#define CK_BLOCKWISE_4D_TENSOR_OP_HPP
#include "ConstantTensorDescriptor.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
+namespace ck {
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
@@ -235,7 +239,7 @@ struct Blockwise4dTensorCopy1
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t L3 = CopyLengths{}.Get(I3);
-constexpr index_t read_per_d3 = mod_conv::integer_divide_ceil(L3, DataPerRead);
+constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
@@ -256,7 +260,7 @@ struct Blockwise4dTensorCopy1
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
-constexpr index_t read_per_d3 = mod_conv::integer_divide_ceil(L3, DataPerRead);
+constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
constexpr auto ref_desc =
make_ConstantTensorDescriptor_packed(Sequence<L0, L1, L2, read_per_d3>{});
@@ -488,7 +492,7 @@ struct Blockwise4dTensorCopy3
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate next line in dst
-constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
@@ -500,7 +504,7 @@ struct Blockwise4dTensorCopy3
"wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr index_t num_active_thread =
-accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
+accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
if(BlockSize > num_active_thread)
{
@@ -556,7 +560,7 @@ struct Blockwise4dTensorCopy3
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
-constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
@@ -613,7 +617,7 @@ struct Blockwise4dTensorCopy3
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
-constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
}
@@ -650,7 +654,7 @@ struct Blockwise4dTensorCopy3
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
-constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
@@ -717,7 +721,7 @@ struct Blockwise4dTensorCopy3
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
-constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
@@ -768,3 +772,7 @@ struct Blockwise4dTensorCopyReorder1
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}
};
+} // namespace
+#endif
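
The out-of-bound-write checks above boil down to a small piece of compile-time arithmetic: read_per_d3 rounds the innermost copy length up to whole vector reads, so each line of the copy may touch read_per_d3 * DataPerRead elements of dst. A minimal sketch of that check, using hypothetical values for L3, DataPerRead and the dst stride (none of these numbers come from the commit):

#include <cstddef>

// ceil(a / b), in the spirit of math::integer_divide_ceil
constexpr std::size_t integer_divide_ceil(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

constexpr std::size_t L3          = 7; // innermost copy length (hypothetical)
constexpr std::size_t DataPerRead = 4; // vector width of each read (hypothetical)
constexpr std::size_t dst_stride2 = 8; // stride of the next-higher dst dimension (hypothetical)

// 7 elements are copied as ceil(7/4) = 2 vector reads, touching 8 elements,
// so the next line in dst must start at least 8 elements away.
static_assert(integer_divide_ceil(L3, DataPerRead) * DataPerRead <= dst_stride2,
              "out-of-bound write would contaminate the next line in dst");
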
-#pragma once
+#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
+#define CK_BLOCKWISE_BATCHED_GEMM_HPP
#include "threadwise_gemm.hpp"
+namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
@@ -287,7 +291,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
-#if USE_AMD_INLINE_ASM
+#if CK_USE_AMD_INLINE_ASM
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
@@ -518,3 +522,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
};
+} // namespace
+#endif
-#pragma once
+#ifndef CK_BLOCKWISE_GEMM_HPP
+#define CK_BLOCKWISE_GEMM_HPP
#include "common.hpp"
#include "threadwise_gemm.hpp"
+namespace ck {
// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <index_t BlockSize,
@@ -109,7 +113,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
-#if USE_AMD_INLINE_ASM
+#if CK_USE_AMD_INLINE_ASM
// TODO: this is not working correctly
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
@@ -423,3 +427,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
}
};
+} // namespace ck
+#endif
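
The power-of-2 comment near the top of this file refers to ordinary strength reduction: when the cluster sizes are compile-time powers of two, the divisions and modulos used to decompose a thread id can be lowered to shifts and masks. A small sketch under that assumption (the value 8 and the function names below are illustrative, not taken from the file):

#include <cstdint>

constexpr std::uint32_t NPerLevel0Cluster = 8; // hypothetical cluster size, a power of 2

// generic form: one integer division and one modulo per decomposition
inline std::uint32_t cluster_id(std::uint32_t thread_id) { return thread_id / NPerLevel0Cluster; }
inline std::uint32_t lane_id(std::uint32_t thread_id) { return thread_id % NPerLevel0Cluster; }

// what the compiler can emit when the divisor is known to be 2^3: a shift and a mask
inline std::uint32_t cluster_id_pow2(std::uint32_t thread_id) { return thread_id >> 3; }
inline std::uint32_t lane_id_pow2(std::uint32_t thread_id) { return thread_id & 7u; }
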
-#pragma once
-#include "threadwise_tensor_slice_op.hpp"
+#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
+#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
+#include "threadwise_generic_tensor_slice_copy.hpp"
+namespace ck {
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst
@@ -142,10 +146,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
// complete offset
mThreadSrcOffset = accumulate_on_array(
-mThreadSrcPartialOffsets, mod_conv::plus<index_t>{}, static_cast<index_t>(0));
+mThreadSrcPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
mThreadDstOffset = accumulate_on_array(
-mThreadDstPartialOffsets, mod_conv::plus<index_t>{}, static_cast<index_t>(0));
+mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
#if 0
if(get_block_1d_id() == 0)
@@ -388,3 +392,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
});
}
};
+} // namespace ck
+#endif
-#pragma once
-#include "threadwise_tensor_slice_op.hpp"
+#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
+#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
+#include "threadwise_tensor_slice_copy.hpp"
+namespace ck {
template <index_t BlockSize,
class Float,
@@ -120,7 +124,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
// compiler: will it really compute index here, or be merged with
// GetOffsetFromMultiIndex and
// optimized away???
-src_data_multi_id[i] *= src_sub_lengths.Get(I);
+src_data_multi_id(i) *= src_sub_lengths.Get(I);
});
// compiler: will it really compute index here, or be merged with GetOffsetFromMultiIndex
@@ -167,10 +171,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
-constexpr auto repeat_lengths =
-transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
-SrcLengths{},
-src_data_per_cluster_per_dims);
+constexpr auto repeat_lengths = transform_sequences(
+math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
@@ -188,10 +190,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
-constexpr auto repeat_lengths =
-transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
-SrcLengths{},
-src_data_per_cluster_per_dims);
+constexpr auto repeat_lengths = transform_sequences(
+math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
@@ -226,10 +226,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
-constexpr auto repeat_lengths =
-transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
-SrcLengths{},
-src_data_per_cluster_per_dims);
+constexpr auto repeat_lengths = transform_sequences(
+math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
@@ -294,3 +292,6 @@ struct BlockwiseTensorSliceReorderCopy_v3
}).Else([&](auto fwd) { mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim); });
}
};
+} // namespace ck
+#endif
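
repeat_lengths above is an elementwise ceiling division: per dimension, it is the number of times the per-iteration footprint (thread sub-lengths times cluster lengths) must be repeated to cover the source slice. A scalar sketch with made-up 2-D lengths (none of these numbers come from a real kernel configuration):

#include <array>
#include <cstddef>

constexpr std::size_t integer_divide_ceil(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// hypothetical slice lengths and per-iteration coverage for a 2-D copy
constexpr std::array<std::size_t, 2> src_lengths              = {8, 128};
constexpr std::array<std::size_t, 2> data_per_cluster_per_dim = {4, 64}; // sub-lengths * cluster lengths

constexpr std::array<std::size_t, 2> repeat_lengths = {
    integer_divide_ceil(src_lengths[0], data_per_cluster_per_dim[0]), // 2
    integer_divide_ceil(src_lengths[1], data_per_cluster_per_dim[1])  // 2
};

static_assert(repeat_lengths[0] == 2 && repeat_lengths[1] == 2, "each dimension is covered in 2 passes");
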
-#pragma once
-#include "base.hpp"
+#ifndef CK_COMMON_HPP
+#define CK_COMMON_HPP
+#include "utility.hpp"
#include "vector_type.hpp"
#include "integral_constant.hpp"
#include "Sequence.hpp"
@@ -8,6 +10,8 @@
#include "functional2.hpp"
#include "functional3.hpp"
-#if USE_AMD_INLINE_ASM
+#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif
+#endif
-#pragma once
+#ifndef CK_CONFIG_HPP
+#define CK_CONFIG_HPP
#cmakedefine01 DEVICE_BACKEND_HIP
#cmakedefine01 DEVICE_BACKEND_CUDA
#if DEVICE_BACKEND_HIP
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
-#define USE_AMD_INLINE_ASM 1
+#define CK_USE_AMD_INLINE_ASM 1
-// For some reason, HIP compiler need this definition to generate optimal load and store
-// instruction
-typedef float float2_t __attribute__((ext_vector_type(2)));
-typedef float float4_t __attribute__((ext_vector_type(4)));
#elif DEVICE_BACKEND_CUDA
#include "cuda_runtime.h"
#include "cuda_fp16.h"
#include "nvToolsExt.h"
#include "helper_cuda.h"
-#define USE_AMD_INLINE_ASM 0
+#define CK_USE_AMD_INLINE_ASM 0
+#endif
+namespace ck {
+#if DEVICE_BACKEND_HIP
+// For some reason, HIP compiler need this definition to generate optimal load and store
+// instruction
+typedef float float2_t __attribute__((ext_vector_type(2)));
+typedef float float4_t __attribute__((ext_vector_type(4)));
+#else
// For some reason, CUDA need this definition, otherwise
// compiler won't generate optimal load and store instruction, and
// kernel would produce wrong result, indicating the compiler fail to generate correct
@@ -58,3 +65,7 @@ __device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const i
#endif
}
#endif
+} // namespace ck
+#endif
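
config.hpp is a CMake template: configure_file() rewrites each #cmakedefine01 line into a plain #define whose value is 1 when the corresponding CMake variable is true and 0 otherwise. As a sketch, a HIP build would be expected to produce roughly the following generated prefix (the exact output depends on the project's CMake options):

/* generated config.hpp for a build with DEVICE_BACKEND_HIP=ON, DEVICE_BACKEND_CUDA=OFF (hypothetical) */
#define DEVICE_BACKEND_HIP 1
#define DEVICE_BACKEND_CUDA 0
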
-#pragma once
+#ifndef CK_CONV_COMMON_HPP
+#define CK_CONV_COMMON_HPP
#include "ConstantTensorDescriptor.hpp"
+using namespace ck;
// this is ugly, only for 4d
template <class InDesc, class WeiDesc>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDesc)
@@ -117,3 +121,5 @@ constexpr std::size_t calculate_convolution_memory_size(Float, InDesc, WeiDesc,
return sizeof(Float) *
(InDesc::GetElementSpace() + WeiDesc::GetElementSpace() + OutDesc::GetElementSpace());
}
+#endif
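
The body of the default output descriptor is not shown in this hunk, but the kernels below size their input tiles as HiPerBlock = HoPerBlock + Y - 1, which corresponds to a stride-1 convolution with no padding. Under that assumption the output spatial size is the input size minus the filter size plus one; a tiny sketch with hypothetical sizes:

#include <cstddef>

// Ho = Hi - Y + 1 for a stride-1, zero-padding ("valid") convolution (assumed here)
constexpr std::size_t output_length(std::size_t input_length, std::size_t filter_length)
{
    return input_length - filter_length + 1;
}

static_assert(output_length(34, 3) == 32, "Hi = 34 with a 3-wide filter gives Ho = 32 (hypothetical sizes)");
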
-#pragma once
+#ifndef CK_DEVICE_HPP
+#define CK_DEVICE_HPP
#include <memory>
-#include "config.h"
+#include "config.hpp"
+using namespace ck;
struct DeviceMem
{
@@ -56,3 +60,5 @@ float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt
return timer.GetElapsedTime();
}
+#endif
-#pragma once
+#ifndef CK_FUNCTIONAL_HPP
+#define CK_FUNCTIONAL_HPP
#include "integral_constant.hpp"
#include "Sequence.hpp"
+namespace ck {
struct forwarder
{
template <typename T>
@@ -70,3 +74,6 @@ struct static_if<false>
return Type{};
}
};
+} // namespace ck
+#endif
-#pragma once
+#ifndef CK_FUNCTIONAL2_HPP
+#define CK_FUNCTIONAL2_HPP
#include "functional.hpp"
#include "Sequence.hpp"
+namespace ck {
template <class>
struct static_for_impl;
@@ -59,3 +63,6 @@ accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
return result;
}
+} // namespace ck
+#endif
-#pragma once
+#ifndef CK_FUNCTIONAL3_HPP
+#define CK_FUNCTIONAL3_HPP
#include "functional.hpp"
#include "functional2.hpp"
#include "Sequence.hpp"
#include "Array.hpp"
+namespace ck {
// RemainLengths: Sequence<...>
template <class RemainLengths>
struct static_ford_impl
@@ -107,3 +111,6 @@ struct ford
}
}
};
+} // namespace ck
+#endif
-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
+#define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_direct_convolution.hpp"
+namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
@@ -245,3 +249,6 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
Number<1>{});
}
};
+} // namespace ck
+#endif
-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
+namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
@@ -103,10 +107,10 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
// tensor view of blockwise input and weight in LDS
// be careful of alignment
-constexpr index_t max_align = mod_conv::lcm(InBlockCopyDataPerRead_N,
+constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{},
@@ -119,11 +123,11 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{},
-Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
+Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{},
-Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
+Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
@@ -390,3 +394,6 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
});
}
};
+} // namespace ck
+#endif
-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_3d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
+namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
@@ -104,10 +108,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// LDS tensor view
// be careful of alignment
-constexpr index_t max_align = mod_conv::lcm(InBlockCopyDataPerRead_N,
+constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{},
@@ -120,7 +124,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, X, KPerBlock>{},
-Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
+Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
@@ -426,3 +430,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
});
}
};
+} // namespace ck
+#endif
#pragma once
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_3d_tensor_op.hpp"
#include "blockwise_tensor_slice_op.hpp"
#include "threadwise_tensor_slice_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockReorderSrcSubLengths_NCHW,
class InBlockReorderSrcClusterLengths_NCHW,
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
index_t InBlockReorderDataPerRead_W,
index_t InBlockReorderDataPerWrite_N,
class WeiBlockCopyClusterLengths_CXK,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerBlock % NPerThread == 0 &&
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
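// Illustration with hypothetical block counts (not from this kernel's configuration):
// with KBlockWork = 4, HBlockWork = 2, WBlockWork = 2, NBlockWork = 8, block id 57 decomposes as
//   k_block_work_id = 57 / (2 * 2 * 8) = 1, itmp = 57 - 32 = 25,
//   h_block_work_id = 25 / (2 * 8)     = 1, itmp = 25 - 16 = 9,
//   w_block_work_id = 9 / 8            = 1, n_block_work_id = 9 - 8 = 1,
// i.e. this workgroup owns the (k, h, w, n) = (1, 1, 1, 1) tile of the [K, Ho, Wo, N] block grid.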
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// global tensor view
constexpr auto wei_c_x_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, X, K>{}, Sequence<Y * X * K, K, 1>{});
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::lcm(InBlockReorderDataPerWrite_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, X, KPerBlock>{}, Number<max_align>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
BlockSize,
Float,
decltype(in_n_c_h_w_global_desc),
decltype(in_c_h_w_n_block_desc),
Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
decltype(map_chwn2nchw),
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N>{};
// blockwise wei copy
// format is [CPerBlock, X * KPerBlock]
const auto blockwise_wei_copy =
Blockwise3dTensorCopy3<BlockSize,
Float,
decltype(wei_c_x_k_global_desc),
decltype(wei_c_x_k_block_desc),
decltype(wei_c_x_k_block_desc.GetLengths()),
WeiBlockCopyClusterLengths_CXK,
WeiBlockCopyDataPerRead_K>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
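// In terms of shapes: each workgroup multiplies
//   A : [CPerBlock, KPerBlock]              (weight tile in LDS, used transposed)
//   B : [CPerBlock, WoPerBlock * NPerBlock] (input tile in LDS)
// and each thread accumulates its own
//   C : [KPerThread, WoPerThread * NPerThread] tile in registers,
// with the batch index of the batched GEMM running over the HoPerBlock output rows.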
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_x_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space =
wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register
Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
#if 0
const Float* p_in_global_block_offset =
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
for(index_t y = 0; y < Y; ++y)
{
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
p_wei_block);
__syncthreads();
for(index_t x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
p_in_block +
in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
p_out_thread);
}
__syncthreads();
}
}
#else
for(index_t y = 0; y < Y; ++y)
{
const Float* p_in_global_block_offset =
p_in_global +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, k_block_data_begin);
for(index_t
c_block_data_begin = 0;
c_block_data_begin < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_clipboard);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
__syncthreads();
for(index_t x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(
p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
p_out_thread);
}
__syncthreads();
}
}
#endif
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
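// The 10d view simply factorizes each output dimension so the per-thread tile lands on its
// strided positions in global memory:
//   K = (K / (K1 * K2)) * K1 * K2,  Wo = (Wo / (W1 * W2)) * W1 * W2,  N = (N / (N1 * N2)) * N1 * N2,
// while Ho stays whole; the thread-side descriptor uses length 1 in the factors that are
// covered by other threads of the GEMM cluster.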
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
}
};
-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
+namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
@@ -74,10 +78,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
-constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
-constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
-constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
-constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
+constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
+constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
+constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
@@ -99,10 +103,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// LDS tensor view
// be careful of alignment
-constexpr index_t max_align = mod_conv::lcm(InBlockCopyDataPerRead_N,
+constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
@@ -115,7 +119,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{},
-Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
+Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
@@ -416,3 +420,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
});
}
};
+} // namespace ck
+#endif
-#pragma once
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
-#include "threadwise_tensor_slice_op.hpp"
+#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
+namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
@@ -36,7 +40,7 @@ template <index_t GridSize,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
-struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
+struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -80,10 +84,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
-constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
-constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
-constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
-constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
+constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
+constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
+constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
@@ -104,10 +108,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
// LDS tensor view
// be careful of alignment
-constexpr index_t max_align = mod_conv::lcm(InBlockCopyDataPerRead_N,
+constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
@@ -120,7 +124,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{},
-Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
+Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
@@ -466,3 +470,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
});
}
};
+} // namespace ck
+#endif
#pragma once
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_op.hpp"
#include "threadwise_tensor_slice_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockReorderSrcSubLengths_NCHW,
class InBlockReorderSrcClusterLengths_NCHW,
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
index_t InBlockReorderDataPerRead_W,
index_t InBlockReorderDataPerWrite_N,
class WeiBlockCopyClusterLengths_CK, // not used
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerBlock % NPerThread == 0 &&
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
// assert for LDS double buffer
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// global tensor view
constexpr auto wei_c_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::lcm(InBlockReorderDataPerWrite_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockReorderDataPerWrite_N>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not meet");
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
BlockSize,
Float,
decltype(in_n_c_h_w_global_desc),
decltype(in_c_h_w_n_block_desc),
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
decltype(map_chwn2nchw),
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0});
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// choose GEMM implementation here
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 1
return blockwise_batch_gemm.Run(Xs...);
#elif 0
return blockwise_batch_gemm.Run_asm(Xs...);
#else
return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
};
// LDS: be careful of alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
// LDS double buffer
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register
// C++ lambda doesn't capture array, use pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
const Float* p_in_global_block_offset =
p_in_global +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
const Float* p_wei_global_block_offset =
p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
// LDS double buffer: preload data into LDS
{
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double);
}
// LDS double buffer: main body
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
c_block_data_begin += 2 * CPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
p_in_global_block_offset +=
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
p_wei_global_block_offset +=
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
}
}
// LDS double buffer: tail
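// the last two CPerBlock chunks are peeled out of the main loop: the even step prefetches
// the final chunk while computing on the one already in LDS, the odd step only computes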
{
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
// even iteration
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(
p_in_register_clipboard, p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
}
}
// output: copy from register to global memory
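// GetBeginOfThreadMatrixC gives this thread's origin in the C matrix: row indexes K,
// batch indexes Ho, and col is a flattened Wo*N index that is split into (wo, n) below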
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
// fwd does nothing but perfect forwarding. It is only there to make this a
// generic lambda, so the body is not compiled until it is instantiated.
static_assert(
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
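// fold [K, Ho, Wo, N] into a 10d view so that the data owned by one thread becomes a
// regular slice of the global tensor and can be written back with a single
// threadwise_tensor_slice_copy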
#if 0
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / fwd(N1 * N2),
N1,
N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#else
constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
.Fold(I3, Number<N1>{}, Number<N2>{})
.Fold(I2, Number<W1>{}, Number<W2>{})
.Fold(I0, Number<K1>{}, Number<K2>{});
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
.Fold(I3, Number<1>{}, Number<N2>{})
.Fold(I2, Number<W1>{}, Number<1>{})
.Fold(I0, Number<1>{}, Number<K2>{});
#endif
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
}).Else([&](auto fwd) {
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
// output is a 10d tensor
constexpr index_t N1 = NPerBlock;
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
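// in this branch one GEMM sub-tile already spans the whole NPerBlock, so N is folded
// only once and the extra factor W3 = GemmNPerThreadSubC / NPerBlock is taken out of Wo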
#if 0
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor_packed(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
#else
constexpr auto out_10d_global_desc =
fwd(out_k_h_w_n_global_desc)
.Fold(I3, Number<N1>{})
.Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
.Fold(I0, Number<K1>{}, Number<K2>{});
constexpr auto out_10d_thread_desc =
fwd(out_k_h_w_n_thread_desc)
.Fold(I3, Number<N1>{})
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
.Fold(I0, Number<1>{}, Number<K2>{});
#endif
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
for(index_t i = 0; i < 64; ++i)
{
printf("out %f, ", p_out_thread[i]);
}
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
});
}
};
#pragma once
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_op.hpp"
#include "threadwise_tensor_slice_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockReorderSrcSubLengths_NCHW,
class InBlockReorderSrcClusterLengths_NCHW,
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
index_t InBlockReorderDataPerRead_W,
index_t InBlockReorderDataPerWrite_N,
class WeiBlockCopyClusterLengths_CK, // not used
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerBlock % NPerThread == 0 &&
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
// divide block work: [K, Ho, Wo, N]
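// the 1d block id is decoded in K-major, N-minor order:
// block_id = ((k * HBlockWork + h) * WBlockWork + w) * NBlockWork + n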
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
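// the input window starts at the same (h, w) as the output window; the y/x filter
// offsets are added when the input is indexed inside the loops below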
// global tensor view
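// 2d [C, K] view of the [C, Y, X, K] weight tensor: the row stride Y*X*K walks C, and a
// particular (y, x) slice is selected by the base offset added at the copy call sites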
constexpr auto wei_c_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
// LDS tensor view
// be careful of alignment
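// max_align is the least common multiple of every vectorized access width used on these
// LDS buffers; the element space is padded to it so vector accesses stay aligned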
constexpr index_t max_align = mod_conv::lcm(InBlockReorderDataPerWrite_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockReorderDataPerWrite_N>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not meet");
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: reorder from [N, C, Hi, Wi] (global) to [C, Hi, Wi, N] (LDS block)
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
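// copies an [NPerBlock, CPerBlock, HoPerBlock, WoPerBlock] slice of the NCHW global
// tensor into the CHWN LDS block; map_chwn2nchw lists, for each destination dimension,
// the corresponding source dimension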
const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
BlockSize,
Float,
decltype(in_n_c_h_w_global_desc),
decltype(in_c_h_w_n_block_desc),
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
decltype(map_chwn2nchw),
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N>{};
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
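// per ho (HoPerBlock batches): transpose(A[CPerBlock, KPerBlock]) * B[CPerBlock, WoPerBlock*NPerBlock],
// with each thread accumulating its own [KPerThread, WoPerThread*NPerThread] tile of C in registers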
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// choose GEMM implementation here
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 0
return blockwise_batch_gemm.Run(Xs...);
#elif 0
return blockwise_batch_gemm.Run_asm(Xs...);
#else
return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
};
// LDS: be careful of alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register
// a C++ lambda doesn't capture an array conveniently; use a pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
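// main loop: this kernel does not double-buffer LDS; each (c, y, x) step loads one input
// tile and one weight tile into LDS, synchronizes, runs the blockwise batched GEMM, and
// synchronizes again before the buffers are overwritten. The enabled path below keeps c
// in the outermost loop; the disabled alternative nests c innermost.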
#if 1
const Float* p_in_global_block_offset =
p_in_global +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
#if 1
blockwise_in_copy_reorder.Run(
p_in_global_block_offset +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_block);
blockwise_wei_copy.Run(
p_wei_global_block_offset +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_block);
#else
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard(
p_in_global_block_offset +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(
p_wei_global_block_offset +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_clipboard);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
#endif
__syncthreads();
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
}
}
}
#else
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
const Float* p_in_global_block_offset =
p_in_global +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
const Float* p_wei_global_block_offset =
p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset +=
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
p_wei_global_block_offset +=
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
#if 0
blockwise_in_copy_reorder.Run(p_in_global_block_offset,
p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset,
p_wei_block);
#else
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_clipboard);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
#endif
__syncthreads();
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
}
}
}
#endif
// output: copy from register to global memory
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
// f_dummy does nothing but perfect forwarding. It is only there to make this a
// generic lambda, so the body is not compiled until it is instantiated.
static_assert(
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / f_dummy(N1 * N2),
N1,
N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
}).Else([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
// output is a 10d tensor
constexpr index_t N1 = NPerBlock;
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
for(index_t i = 0; i < 64; ++i)
{
printf("out %f, ", p_out_thread[i]);
}
}
#endif
threadwise_tensor_slice_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
});
}
};