Commit 03f7892a authored by Chao Liu

replacing array with vector for tensor data

parent e8421cca
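The change is mechanical but broad: thread-local tensor data that used to live in raw C arrays is moved into typed buffer objects (the StaticBuffer/DynamicBuffer hunks further down), so the same storage can be addressed per scalar or per short vector. Below is a minimal, self-contained sketch of that before/after using made-up names; the real StaticBuffer derives from ck's vector_type instead.

// Illustrative sketch only, not code from this commit.
template <typename T, int N>
struct StaticBufferSketch
{
    T data_[N] = {};

    constexpr const T& operator[](int i) const { return data_[i]; }
    constexpr T& operator()(int i) { return data_[i]; }
};

void accumulate_sketch()
{
    // before: thread-private accumulator as a raw C array
    float c_thread_array[16] = {};
    c_thread_array[0] += 1.0f;

    // after: the same storage behind a small buffer type, which gives the
    // library one place to hang vectorized views and offset calculations
    StaticBufferSketch<float, 16> c_thread_buf;
    c_thread_buf(0) += 1.0f;
}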
@@ -12,6 +12,9 @@ namespace ck {
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
@@ -104,7 +107,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
@@ -150,7 +152,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
FloatB,
FloatC,
decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
// loop over k
@@ -180,7 +185,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
});
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
@@ -243,7 +247,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
FloatB,
FloatC,
decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
@@ -331,7 +338,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
@@ -540,7 +546,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
FloatB,
FloatC,
decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
......
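The hunks above all make the same move: FloatA, FloatB and FloatC migrate from per-method template parameters of the Run* functions onto the BlockwiseGemm_km_kn_m0m1n0n1_v1 class template itself, and are then passed through explicitly when the inner ThreadwiseGemm_km_kn_mn_v1 is instantiated. A reduced illustration of that refactor, with hypothetical names rather than the library's real signatures:

// before: element types deduced independently at every call site
struct BlockwiseGemmBefore
{
    template <typename FloatA, typename FloatB, typename FloatC>
    void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) const;
};

// after: element types fixed once at instantiation, so inner helpers such as
// the threadwise GEMM can be constructed with exactly the same types
template <typename FloatA, typename FloatB, typename FloatC>
struct BlockwiseGemmAfter
{
    void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) const;
};

Pinning the types on the class keeps the blockwise and threadwise GEMMs from ever being instantiated with mismatched element types.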
@@ -1429,6 +1429,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// position in slice window
#if 0 // debug
// TODO: unable to compile
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
......
@@ -57,7 +57,10 @@ struct ThreadwiseMatrixSliceCopy_v2
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template <typename ADesc,
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -65,7 +68,6 @@ template <typename ADesc,
bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1
{
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -94,7 +96,6 @@ struct ThreadwiseGemm_km_kn_mn_v1
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
@@ -157,7 +158,6 @@ struct ThreadwiseGemm_km_kn_mn_v1
}
#endif
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
......
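ThreadwiseGemm_km_kn_mn_v1 keeps its enable_if guard requiring every matrix descriptor to be known at compile time; only the element types move up into the template header. A compilable sketch of that guard pattern, using a stand-in descriptor type (names here are illustrative, not the library's):

#include <type_traits>

// stand-in for a tensor descriptor whose layout is a compile-time constant
struct KnownDesc
{
    static constexpr bool IsKnownAtCompileTime() { return true; }
};

// same SFINAE gate as in the hunk above: the struct only instantiates when
// all three descriptors report compile-time-known layouts
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ADesc,
          typename BDesc,
          typename CDesc,
          typename std::enable_if<ADesc::IsKnownAtCompileTime() &&
                                      BDesc::IsKnownAtCompileTime() &&
                                      CDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct ThreadwiseGemmSketch
{
    static void Run(const FloatA*, const FloatB*, FloatC*) {}
};

// instantiation succeeds because KnownDesc satisfies the guard
using gemm_f32 = ThreadwiseGemmSketch<float, float, float, KnownDesc, KnownDesc, KnownDesc>;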
@@ -5,10 +5,14 @@
namespace ck {
template <typename T, index_t N>
struct StaticBuffer : public vector_type_maker<T, N>::type
template <
typename ScalarType,
index_t N,
typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
bool>::type = false>
struct StaticBuffer : public vector_type<ScalarType, N>
{
using base = typename vector_type_maker<T, N>::type;
using base = vector_type<ScalarType, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
};
@@ -16,7 +20,60 @@ struct StaticBuffer : public vector_type_maker<T, N>::type
template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<T, N>{};
using scalar_t = typename scalar_type<T>::type;
constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
return StaticBuffer<scalar_t, N * scalar_per_vector>{};
}
template <
typename ScalarType,
typename std::enable_if<is_same<typename scalar_type<ScalarType>::type, ScalarType>::value,
bool>::type = false>
struct DynamicBuffer
{
template <typename T>
struct PointerWrapper
{
T* p_;
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_[i]; }
};
ScalarType* p_scalar_;
__host__ __device__ constexpr DynamicBuffer(ScalarType* p_scalar) : p_scalar_{p_scalar} {}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
ScalarType>::value,
bool>::type = false>
__host__ __device__ constexpr auto AsType() const // return by value: the wrapper is a temporary
{
return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
ScalarType>::value,
bool>::type = false>
__host__ __device__ constexpr auto AsType() // return by value: the wrapper is a temporary
{
return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
}
};
template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
using scalar_t = typename scalar_type<T>::type;
constexpr index_t scalar_per_vector = scalar_type<T>::vector_size;
return DynamicBuffer<scalar_t>{p};
}
} // namespace ck
......
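The new buffer layer above splits into a compile-time-sized StaticBuffer, now deriving from vector_type<ScalarType, N> rather than vector_type_maker, and a DynamicBuffer that wraps a raw scalar pointer and hands out reinterpreted views via AsType<X>(). A self-contained stand-in for the DynamicBuffer/AsType idea follows; the names are illustrative, and it mirrors the library's reinterpret_cast approach, which assumes the target tolerates this kind of type punning.

#include <cstdio>

// simplified stand-in for DynamicBuffer: wrap a scalar pointer once, then
// reinterpret it as a wider element type for vectorized access
template <typename Scalar>
struct DynamicBufferSketch
{
    Scalar* p_scalar_;

    template <typename X>
    struct PointerWrapper
    {
        X* p_;
        const X& operator[](int i) const { return p_[i]; }
        X& operator()(int i) { return p_[i]; }
    };

    template <typename X>
    PointerWrapper<X> AsType() const
    {
        // same pattern as the hunk above; strict-aliasing caveats apply
        return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
    }
};

struct float2_sketch // stand-in for a 2-wide vector element type
{
    float x, y;
};

int main()
{
    float data[4] = {0.f, 1.f, 2.f, 3.f};
    DynamicBufferSketch<float> buf{data};

    auto view = buf.AsType<float2_sketch>(); // view the floats as pairs
    std::printf("%g %g\n", view[1].x, view[1].y); // prints 2 3
    return 0;
}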
@@ -28,11 +28,11 @@
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 0
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1
#define CK_MIN_BLOCK_PER_CU 2
#endif
// buffer resource
......
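Switching CK_USE_LAUNCH_BOUNDS on and raising CK_MIN_BLOCK_PER_CU from 1 to 2 asks the compiler to budget registers so that at least two 256-thread blocks (CK_MAX_THREAD_PER_BLOCK stays 256) can be resident per compute unit. In HIP/CUDA code such macros are typically consumed through the __launch_bounds__ attribute; the sketch below shows that common wiring with a hypothetical CK_KERNEL_ATTR helper and a dummy kernel, not this repository's actual declarations. The macro values are repeated from the hunk above so the sketch stands alone.

#define CK_USE_LAUNCH_BOUNDS 1
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 2

#if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(max threads per block, min resident blocks per CU/SM)
// caps per-thread register usage so the requested occupancy is achievable
#define CK_KERNEL_ATTR __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#else
#define CK_KERNEL_ATTR
#endif

__global__ void CK_KERNEL_ATTR copy_kernel(const float* in, float* out, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = in[i];
}

Requesting two resident blocks per CU roughly halves the register budget per thread, trading per-thread registers for latency hiding.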
@@ -728,19 +728,18 @@ int main(int argc, char* argv[])
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size,
acc_data_t,
out_data_t>
(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
out_data_t>(
in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
......