Commit 03f7892a authored by Chao Liu

replacing array with vector for tensor data

parent e8421cca
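Note on the change set below: the element types FloatA/FloatB/FloatC are hoisted from member-function templates into the class templates of the blockwise and threadwise GEMMs, and tensor data moves from plain arrays to a vector_type-backed StaticBuffer plus a new pointer-wrapping DynamicBuffer. As a rough, self-contained analogue of the array-to-vector part (a sketch with made-up type names, not the commit's code; ext_vector_type is the clang/HIP vector extension that CK's vector types are typically built on):

// Illustrative only: plain-array storage vs. vector-type-backed storage for thread-local data.
struct PlainArrayBuffer4 // old style: four independent scalars
{
    float data[4];
};

typedef float float4_t __attribute__((ext_vector_type(4))); // clang/HIP vector type

struct VectorBackedBuffer4 // new style: storage is one native 4-wide vector value
{
    float4_t data;
};

int main()
{
    PlainArrayBuffer4 old_buf{};
    old_buf.data[2] = 1.0f;

    VectorBackedBuffer4 new_buf{};
    new_buf.data[2] = 1.0f; // per-element access still works; whole-vector loads/stores become possible
    return static_cast<int>(new_buf.data[2]);
}

The point of vector-backed storage is that a whole sub-tile can be loaded or stored as one hardware vector instead of element by element.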
@@ -12,6 +12,9 @@ namespace ck {
 // MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
 // MLevel1ThreadCluster, NLevel1ThreadCluster
 template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
           typename BlockMatrixA,
           typename BlockMatrixB,
           typename ThreadMatrixC,
@@ -104,7 +107,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
                 level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void
     Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
@@ -150,7 +152,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
                                              NPerThreadSubC,
                                              ThreadGemmBDataPerRead_N>{};

-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
+                                                                    FloatB,
+                                                                    FloatC,
+                                                                    decltype(a_thread_mtx),
                                                                     decltype(b_thread_mtx),
                                                                     decltype(c_thread_mtx)>{};
         // loop over k
@@ -180,7 +185,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         });
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void
     Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
@@ -243,7 +247,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
                                                 NPerThreadSubC,
                                                 ThreadGemmBDataPerRead_N>{};

-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
+                                                                    FloatB,
+                                                                    FloatC,
+                                                                    decltype(a_thread_sub_mtx),
                                                                     decltype(b_thread_sub_mtx),
                                                                     decltype(c_thread_sub_mtx)>{};
@@ -331,7 +338,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
             p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
@@ -540,7 +546,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
         FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];

-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
+                                                                    FloatB,
+                                                                    FloatC,
+                                                                    decltype(a_thread_sub_mtx),
                                                                     decltype(b_thread_sub_mtx),
                                                                     decltype(c_thread_sub_mtx)>{};
...
@@ -1429,6 +1429,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
             // position in slice window
 #if 0 // debug
+       // TODO: unable to compile
             constexpr auto data_to_origin_disp_idx =
                 container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                 src_scalar_per_access;
...
@@ -57,7 +57,10 @@ struct ThreadwiseMatrixSliceCopy_v2
 // C[M, N] += transpose(A[K, M]) * B[K, N]
 //   Element of matrix can be vectorized data
-template <typename ADesc,
+template <typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename ADesc,
           typename BDesc,
           typename CDesc,
           typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -65,7 +68,6 @@ template <typename ADesc,
                                   bool>::type = false>
 struct ThreadwiseGemm_km_kn_mn_v1
 {
-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
     {
         static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -94,7 +96,6 @@ struct ThreadwiseGemm_km_kn_mn_v1
     }

 #if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
     {
         static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -157,7 +158,6 @@ struct ThreadwiseGemm_km_kn_mn_v1
     }
 #endif

-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
     {
 #if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
...
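The two hunks above and the blockwise call sites earlier are the same refactor seen from both ends: FloatA/FloatB/FloatC move from per-function templates to the class template, so one instantiation fixes the element types for Run_source, Run_amd_asm and Run alike. A minimal stand-alone sketch of that pattern (toy names, not CK code):

// Before: element types deduced independently at every call.
struct ThreadGemmOld
{
    template <typename FloatA, typename FloatB, typename FloatC>
    static void Run(const FloatA* a, const FloatB* b, FloatC* c) { c[0] += a[0] * b[0]; }
};

// After: element types are class template parameters, fixed once at instantiation.
template <typename FloatA, typename FloatB, typename FloatC>
struct ThreadGemmNew
{
    static void Run(const FloatA* a, const FloatB* b, FloatC* c) { c[0] += a[0] * b[0]; }
};

void example(const float* a, const float* b, float* c)
{
    ThreadGemmOld::Run(a, b, c);                      // FloatA/FloatB/FloatC deduced here
    ThreadGemmNew<float, float, float>::Run(a, b, c); // fixed at instantiation
}

Fixing the types at the class level lets the blockwise GEMMs pass their own FloatA/FloatB/FloatC straight through, as the ThreadwiseGemm_km_kn_mn_v1<FloatA, FloatB, FloatC, ...> instantiations above show.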
@@ -5,10 +5,14 @@
 namespace ck {

-template <typename T, index_t N>
-struct StaticBuffer : public vector_type_maker<T, N>::type
+template <
+    typename ScalarType,
+    index_t N,
+    typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
+                            bool>::type = false>
+struct StaticBuffer : public vector_type<ScalarType, N>
 {
-    using base = typename vector_type_maker<T, N>::type;
+    using base = vector_type<ScalarType, N>;

     __host__ __device__ constexpr StaticBuffer() : base{} {}
 };
@@ -16,7 +20,60 @@ struct StaticBuffer : public vector_type_maker<T, N>::type
 template <typename T, index_t N>
 __host__ __device__ constexpr auto make_static_buffer(Number<N>)
 {
-    return StaticBuffer<T, N>{};
+    using scalar_t = scalar_type<T>;
+
+    constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
+
+    return StaticBuffer<scalar_t, N * scalar_per_vector>{};
+}
+
+template <
+    typename ScalarType,
+    typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
+                            bool>::type = false>
+struct DynamicBuffer
+{
+    template <typename T>
+    struct PointerWrapper
+    {
+        T* p_;
+
+        __host__ __device__ constexpr const T& operator[](index_t i) const { return p_[i]; }
+
+        __host__ __device__ constexpr T& operator()(index_t i) { return p_[i]; }
+    };
+
+    ScalarType* p_scalar_;
+
+    __host__ __device__ constexpr DynamicBuffer(ScalarType* p_scalar) : p_scalar_{p_scalar} {}
+
+    template <typename X,
+              typename std::enable_if<
+                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
+                          ScalarType>::value,
+                  bool>::type = false>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
+    }
+
+    template <typename X,
+              typename std::enable_if<
+                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
+                          ScalarType>::value,
+                  bool>::type = false>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
+    }
+};
+
+template <typename T>
+__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
+{
+    using scalar_t = scalar_type<T>;
+
+    constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
+
+    return DynamicBuffer<scalar_t>{p};
 }

 } // namespace ck
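The DynamicBuffer added above is essentially a non-owning view over a scalar pointer whose AsType<X>() accessor re-addresses the same memory as another (typically vector) type. A stand-alone sketch of that idea (simplified: the view is returned by value and the example sticks to the scalar type itself, so no aliasing concerns; names are illustrative, not CK's):

#include <cstddef>

// Simplified analogue of DynamicBuffer/PointerWrapper: wrap a raw scalar pointer and hand out
// typed views, with operator[] for const reads and operator() for writes, as in the diff above.
template <typename Scalar>
struct ScalarSpanView
{
    Scalar* p;

    template <typename X>
    struct View
    {
        X* p;
        const X& operator[](std::size_t i) const { return p[i]; }
        X& operator()(std::size_t i) { return p[i]; }
    };

    template <typename X>
    View<X> AsType() const { return View<X>{reinterpret_cast<X*>(p)}; }
};

void zero_first_two(float* p_global)
{
    ScalarSpanView<float> buf{p_global};
    buf.AsType<float>()(0) = 0.0f;
    buf.AsType<float>()(1) = 0.0f;
}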
...
@@ -28,11 +28,11 @@
 #endif

 // launch bounds
-#define CK_USE_LAUNCH_BOUNDS 0
+#define CK_USE_LAUNCH_BOUNDS 1

 #ifdef CK_USE_LAUNCH_BOUNDS
 #define CK_MAX_THREAD_PER_BLOCK 256
-#define CK_MIN_BLOCK_PER_CU 1
+#define CK_MIN_BLOCK_PER_CU 2
 #endif

 // buffer resourse
...
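For context on the config change above: turning on CK_USE_LAUNCH_BOUNDS and raising CK_MIN_BLOCK_PER_CU from 1 to 2 tightens the occupancy hint given to the compiler, asking it to budget registers so at least two blocks can be resident per compute unit. These macros are typically fed into the standard CUDA/HIP __launch_bounds__ attribute on the kernel entry point, roughly as sketched below (the kernel and CK_KERNEL_ATTR macro names are illustrative, not taken from this commit):

#ifdef CK_USE_LAUNCH_BOUNDS
// max threads per block, min resident blocks per CU that the compiler should plan for
#define CK_KERNEL_ATTR __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#else
#define CK_KERNEL_ATTR
#endif

template <typename GridwiseOp>
__global__ void CK_KERNEL_ATTR run_gridwise_operation() // hypothetical wrapper kernel
{
    GridwiseOp{}.Run();
}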
@@ -728,9 +728,8 @@ int main(int argc, char* argv[])
         device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
                                                                              in_vector_size,
                                                                              acc_data_t,
-                                                                             out_data_t>
-            (in_nchw_desc,
+                                                                             out_data_t>(
+            in_nchw_desc,
             in_nchw,
             wei_kcyx_desc,
             wei_kcyx,
...