"...resnet50_tensorflow.git" did not exist on "c5df7268120db7c7f9b3d40bdb059af663083efb"
Commit 8d15144c authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 7fde99f4
...@@ -71,13 +71,13 @@ template <index_t GridSize, ...@@ -71,13 +71,13 @@ template <index_t GridSize,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
index_t GemmNRepeat, index_t GemmNRepeat,
index_t GemmMPerThreadSubC, index_t GemmMPerThread,
index_t GemmNPerThreadSubC, index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster, index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster, index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA, index_t GemmDataPerReadA,
index_t GemmDataPerReadB, index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_N1_B_N2, typename InBlockCopySubLengths_E_N1_B_N2,
...@@ -114,11 +114,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -114,11 +114,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// this is a mess // this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters // TODO: find more elegent way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat; constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC; constexpr index_t N2 = GemmNPerThread;
static_assert((N1 * N2 * BPerBlock) % static_assert(
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == (N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
0,
"wrong!"); "wrong!");
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
...@@ -290,30 +289,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -290,30 +289,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
in_e_n1_b_n2_block_desc.GetStride(I0)); in_e_n1_b_n2_block_desc.GetStride(I0));
// sanity check // sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
0,
"wrong!"); "wrong!");
constexpr index_t GemmMRepeat = constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx // TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{}); Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
decltype(a_e_k_block_mtx_desc), decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k1_n1n2_thread_mtx_desc), decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThreadSubC, GemmMPerThread,
GemmNPerThreadSubC, GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB>{}; GemmDataPerReadB>{};
...@@ -432,13 +430,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -432,13 +430,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// copy output: register to global memory // copy output: register to global memory
{ {
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t K0 = K / K1; constexpr index_t K0 = K / K1;
// define output tensor descriptor for threadwise copy // define output tensor descriptor for threadwise copy
// thread output tensor, src of threadwise copy // thread output tensor, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed( constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{}); Sequence<GemmMRepeat, GemmMPerThread, N1, 1, N2>{});
// global output tensor // global output tensor
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor( constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
......
...@@ -159,7 +159,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -159,7 +159,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
1, 1,
GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>, Sequence<2, 3, 0, 1>,
3, 3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{}; GemmCThreadCopyDstDataPerWrite_GemmN1>{};
......
...@@ -18,11 +18,11 @@ template <index_t BlockSize, ...@@ -18,11 +18,11 @@ template <index_t BlockSize,
typename ThreadMatrixC, typename ThreadMatrixC,
index_t MPerThreadSubC, index_t MPerThreadSubC,
index_t NPerThreadSubC, index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster, index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster, index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster, index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster, index_t NLevel1ThreadCluster,
index_t KPerThreadLoop,
index_t ThreadGemmADataPerRead_M, index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N> index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
......
...@@ -186,7 +186,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -186,7 +186,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
"wrong!"); "wrong!");
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
...@@ -201,11 +200,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -201,11 +200,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
decltype(c_m0m1_n0n1_thread_mtx_desc), decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread, MPerThread,
NPerThread, NPerThread,
KPerThread,
MLevel0Cluster, MLevel0Cluster,
NLevel0Cluster, NLevel0Cluster,
MLevel1Cluster, MLevel1Cluster,
NLevel1Cluster, NLevel1Cluster,
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M, ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{}; ThreadGemmBThreadCopySrcDataPerRead_N>{};
......
...@@ -207,7 +207,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16> ...@@ -207,7 +207,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
__device__ void __device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{ {
static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) || static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
(MPerWave == 64 && NPerWave == 32), (MPerWave == 64 && NPerWave == 32),
...@@ -239,7 +239,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16> ...@@ -239,7 +239,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
__device__ void __device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{ {
static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm"); static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
...@@ -269,7 +269,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16> ...@@ -269,7 +269,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
__device__ void __device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{ {
static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm"); static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
...@@ -299,7 +299,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16> ...@@ -299,7 +299,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
__device__ void __device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{ {
static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16), static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
"unsupported xdlops gemm"); "unsupported xdlops gemm");
...@@ -330,7 +330,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16> ...@@ -330,7 +330,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
__device__ void __device__ void
run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
{ {
static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64, static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
"unsupported xdlops gemm"); "unsupported xdlops gemm");
...@@ -555,55 +555,55 @@ __device__ constexpr auto GetMFMAInfo<float, 4, 64>() ...@@ -555,55 +555,55 @@ __device__ constexpr auto GetMFMAInfo<float, 4, 64>()
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 64>() __device__ constexpr auto GetMFMAInfo<half_t, 64, 64>()
{ {
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{}; return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 32>() __device__ constexpr auto GetMFMAInfo<half_t, 64, 32>()
{ {
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{}; return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 64>() __device__ constexpr auto GetMFMAInfo<half_t, 32, 64>()
{ {
return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{}; return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 32>() __device__ constexpr auto GetMFMAInfo<half_t, 32, 32>()
{ {
return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{}; return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 16>() __device__ constexpr auto GetMFMAInfo<half_t, 16, 16>()
{ {
return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{}; return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 64>() __device__ constexpr auto GetMFMAInfo<half_t, 16, 64>()
{ {
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{}; return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 16>() __device__ constexpr auto GetMFMAInfo<half_t, 64, 16>()
{ {
return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{}; return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 4, 64>() __device__ constexpr auto GetMFMAInfo<half_t, 4, 64>()
{ {
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{}; return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
} }
template <> template <>
__device__ constexpr auto GetMFMAInfo<half, 8, 64>() __device__ constexpr auto GetMFMAInfo<half_t, 8, 64>()
{ {
return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{}; return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
} }
......
...@@ -84,37 +84,37 @@ struct vector_type<float, 4> ...@@ -84,37 +84,37 @@ struct vector_type<float, 4>
}; };
template <> template <>
struct vector_type<half, 1> struct vector_type<half_t, 1>
{ {
using MemoryType = half; using MemoryType = half_t;
template <index_t I> template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>) __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
{ {
static_assert(I < 1, "wrong"); static_assert(I < 1, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s; *(reinterpret_cast<half_t*>(&v) + I) = s;
} }
}; };
template <> template <>
struct vector_type<half, 2> struct vector_type<half_t, 2>
{ {
using MemoryType = half2_t; using MemoryType = half2_t;
union DataType union DataType
{ {
MemoryType vector; MemoryType vector;
half scalar[2]; half_t scalar[2];
}; };
template <index_t I> template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>) __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
{ {
static_assert(I < 2, "wrong"); static_assert(I < 2, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s; *(reinterpret_cast<half_t*>(&v) + I) = s;
} }
__host__ __device__ static MemoryType Pack(half s0, half s1) __host__ __device__ static MemoryType Pack(half_t s0, half_t s1)
{ {
DataType data; DataType data;
data.scalar[0] = s0; data.scalar[0] = s0;
...@@ -124,24 +124,24 @@ struct vector_type<half, 2> ...@@ -124,24 +124,24 @@ struct vector_type<half, 2>
}; };
template <> template <>
struct vector_type<half, 4> struct vector_type<half_t, 4>
{ {
using MemoryType = half4_t; using MemoryType = half4_t;
union DataType union DataType
{ {
MemoryType vector; MemoryType vector;
half scalar[4]; half_t scalar[4];
}; };
template <index_t I> template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>) __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
{ {
static_assert(I < 4, "wrong"); static_assert(I < 4, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s; *(reinterpret_cast<half_t*>(&v) + I) = s;
} }
__host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3) __host__ __device__ static MemoryType Pack(half_t s0, half_t s1, half_t s2, half_t s3)
{ {
DataType data; DataType data;
data.scalar[0] = s0; data.scalar[0] = s0;
...@@ -255,8 +255,8 @@ struct inner_product_with_conversion ...@@ -255,8 +255,8 @@ struct inner_product_with_conversion
__device__ T operator()(half2_t a, half2_t b) const __device__ T operator()(half2_t a, half2_t b) const
{ {
const half* p_a_half = reinterpret_cast<const half*>(&a); const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half* p_b_half = reinterpret_cast<const half*>(&b); const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0; T acc = 0;
for(index_t v = 0; v < 2; ++v) for(index_t v = 0; v < 2; ++v)
...@@ -269,8 +269,8 @@ struct inner_product_with_conversion ...@@ -269,8 +269,8 @@ struct inner_product_with_conversion
__device__ T operator()(half4_t a, half4_t b) const __device__ T operator()(half4_t a, half4_t b) const
{ {
const half* p_a_half = reinterpret_cast<const half*>(&a); const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half* p_b_half = reinterpret_cast<const half*>(&b); const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0; T acc = 0;
for(index_t v = 0; v < 4; ++v) for(index_t v = 0; v < 4; ++v)
......
...@@ -14,17 +14,15 @@ using float2_t = float2; ...@@ -14,17 +14,15 @@ using float2_t = float2;
using float4_t = float4; using float4_t = float4;
// float // float
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32))); typedef float float32_t __attribute__((ext_vector_type(32)));
// float16
// bfloat16 // bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2))); typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4))); typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8))); typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// float16 // fp16
using half_t = half;
using half2_t = half2; using half2_t = half2;
using half4_t = float2; using half4_t = float2;
...@@ -93,37 +91,37 @@ struct vector_type<float, 4> ...@@ -93,37 +91,37 @@ struct vector_type<float, 4>
}; };
template <> template <>
struct vector_type<half, 1> struct vector_type<half_t, 1>
{ {
using MemoryType = half; using MemoryType = half_t;
template <index_t I> template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>) __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
{ {
static_assert(I < 1, "wrong"); static_assert(I < 1, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s; *(reinterpret_cast<half_t*>(&v) + I) = s;
} }
}; };
template <> template <>
struct vector_type<half, 2> struct vector_type<half_t, 2>
{ {
using MemoryType = half2_t; using MemoryType = half2_t;
union DataType union DataType
{ {
MemoryType vector; MemoryType vector;
half scalar[2]; half_t scalar[2];
}; };
template <index_t I> template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>) __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
{ {
static_assert(I < 2, "wrong"); static_assert(I < 2, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s; *(reinterpret_cast<half_t*>(&v) + I) = s;
} }
__host__ __device__ static MemoryType Pack(half s0, half s1) __host__ __device__ static MemoryType Pack(half_t s0, half_t s1)
{ {
DataType data; DataType data;
data.scalar[0] = s0; data.scalar[0] = s0;
...@@ -152,8 +150,8 @@ struct inner_product_with_conversion ...@@ -152,8 +150,8 @@ struct inner_product_with_conversion
__device__ T operator()(half2_t a, half2_t b) const __device__ T operator()(half2_t a, half2_t b) const
{ {
const half* p_a_half = reinterpret_cast<const half*>(&a); const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half* p_b_half = reinterpret_cast<const half*>(&b); const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0; T acc = 0;
for(index_t v = 0; v < 2; ++v) for(index_t v = 0; v < 2; ++v)
...@@ -166,8 +164,8 @@ struct inner_product_with_conversion ...@@ -166,8 +164,8 @@ struct inner_product_with_conversion
__device__ T operator()(half4_t a, half4_t b) const __device__ T operator()(half4_t a, half4_t b) const
{ {
const half* p_a_half = reinterpret_cast<const half*>(&a); const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half* p_b_half = reinterpret_cast<const half*>(&b); const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0; T acc = 0;
for(index_t v = 0; v < 4; ++v) for(index_t v = 0; v < 4; ++v)
......
...@@ -65,9 +65,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -65,9 +65,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -104,9 +104,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -104,9 +104,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -143,9 +143,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -143,9 +143,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -182,9 +182,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -182,9 +182,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -222,9 +222,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -222,9 +222,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -261,9 +261,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -261,9 +261,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -300,9 +300,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -300,9 +300,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -339,9 +339,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -339,9 +339,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -378,9 +378,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -378,9 +378,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -417,9 +417,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -417,9 +417,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -456,9 +456,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -456,9 +456,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -495,9 +495,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -495,9 +495,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -534,9 +534,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -534,9 +534,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -573,9 +573,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -573,9 +573,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -612,9 +612,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -612,9 +612,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -651,9 +651,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -651,9 +651,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -690,9 +690,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -690,9 +690,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -729,9 +729,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -729,9 +729,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNRepeat = 2; constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 2; constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -761,7 +761,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -761,7 +761,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
#endif #endif
constexpr index_t N1 = GemmNRepeat; constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC; constexpr index_t N2 = GemmNPerThread;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2); constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
...@@ -788,13 +788,13 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -788,13 +788,13 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
KPerBlock, KPerBlock,
EPerBlock, EPerBlock,
GemmNRepeat, GemmNRepeat,
GemmMPerThreadSubC, GemmMPerThread,
GemmNPerThreadSubC, GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB, GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2, InBlockCopySubLengths_E_N1_B_N2,
......
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"
// Host-side driver for the (deprecated) v4r1 implicit-GEMM forward convolution:
// NCHW input, KCYX weight, NKHW output, using the LDS double-buffer gridwise kernel.
//
// What it does (all visible below):
//   1. Reads N/K/Ho/Wo from the output descriptor.
//   2. Allocates device buffers sized from the tensor descriptors and copies
//      input, weight, and output host tensors to the device.
//   3. Selects one hard-coded tuning configuration via #if/#elif.
//   4. Instantiates the gridwise convolution template with that configuration,
//      launches it `nrepeat` times, and prints the elapsed time / TFlop/s.
//   5. Copies the output back to the host tensor.
//
// Template parameters:
//   T                          - element type of all three tensors.
//   InDesc/WeiDesc/OutDesc     - compile-time 4-D tensor descriptors (passed by value
//                                only to deduce the types; the objects are empty tags).
//   ConvStrides/ConvDilations  - compile-time convolution stride/dilation sequences,
//                                forwarded to the gridwise kernel.
//
// NOTE(review): the *SubC / *Loop parameter names are the old spelling kept on
// purpose here — they must match the parameter list of the *_deprecated gridwise
// header included above, which was not renamed when the live kernels were.
template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations>
void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
                                                                     const Tensor<T>& in_nchw,
                                                                     WeiDesc,
                                                                     const Tensor<T>& wei_kcyx,
                                                                     OutDesc,
                                                                     Tensor<T>& out_nkhw,
                                                                     ConvStrides,
                                                                     ConvDilations,
                                                                     ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices for querying the 4-D descriptors.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    // Output tensor extents: batch, output channels, output height/width.
    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    // Allocate device memory sized by each tensor's element *space* (which may
    // exceed the element count for padded/strided layouts) and upload all three
    // tensors. The output is uploaded too, so untouched elements round-trip intact.
    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    // ---------------------------------------------------------------------
    // Hard-coded tuning configurations. Exactly one branch is enabled at a
    // time (the `#elif 1` branch below). Each branch defines:
    //   - block/grid tile sizes (BPerBlock, KPerBlock, EPerBlock),
    //   - the thread-level GEMM decomposition (per-thread sub-tile sizes and
    //     the two-level thread-cluster layout),
    //   - vector widths for global/LDS reads and writes,
    //   - the thread-cluster copy layouts for input (E,N1,B,N2) and weight (E,K).
    // The Sub-Lengths x Cluster-Lengths products must tile the block slice, and
    // the access-order sequences are permutations of the copy dimensions — these
    // invariants are presumably checked by static_asserts in the gridwise header
    // (TODO confirm; not visible here).
    // ---------------------------------------------------------------------
#if 0
    // BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // BlockSize = 256, EPerBlock = 16, each thread hold 64 data
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 16;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<4, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
    constexpr index_t BlockSize = 64;

    constexpr index_t BPerBlock = 8;
    constexpr index_t KPerBlock = 64;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<4, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThreadSubC = 2;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 2;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<2, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 2;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
    // Active configuration: BlockSize = 64, block tile B x K = 16 x 32, EPerBlock = 4.
    constexpr index_t BlockSize = 64;

    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 32;
    constexpr index_t EPerBlock = 4;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 1;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<1, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 16>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#endif

    // The GEMM N dimension is folded as N1 x B x N2: each thread repeats
    // GemmNRepeat (= N1) sub-tiles of width GemmNPerThreadSubC (= N2), so the
    // flattened output pixel count N*Ho*Wo is divided by N1*N2 to get B.
    // NOTE(review): integer division — assumes N*Ho*Wo is divisible by N1*N2;
    // presumably enforced inside the gridwise kernel (TODO confirm).
    constexpr index_t N1 = GemmNRepeat;
    constexpr index_t N2 = GemmNPerThreadSubC;

    constexpr index_t B = (N * Ho * Wo) / (N1 * N2);

    // One workgroup per (BPerBlock x KPerBlock) output tile, rounded up.
    constexpr index_t GridSize =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    // Instantiate the gridwise kernel functor. The template-argument order must
    // match the *_deprecated gridwise header exactly (note GemmKPerThreadLoop
    // comes after the cluster sizes in this deprecated parameter list).
    constexpr auto gridwise_conv =
        GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated<
            GridSize,
            BlockSize,
            T,
            T,
            decltype(in_nchw_desc),
            decltype(wei_kcyx_desc),
            decltype(out_nkhw_desc),
            ConvStrides,
            ConvDilations,
            ConvolutionDirection::Forward,
            BPerBlock,
            KPerBlock,
            EPerBlock,
            GemmNRepeat,
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB,
            InBlockCopySubLengths_E_N1_B_N2,
            InBlockCopyClusterLengths_E_N1_B_N2,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2,
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>{};

    // Launch the kernel `nrepeat` times, reporting elapsed time per launch.
    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time =
            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        // flops / 1e9 / time_ms == TFlop/s (flops / (time_ms/1000 s) / 1e12).
        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        // Idle between launches (capped at 10 ms) — presumably to let the GPU
        // cool/settle so repeated timings are comparable.
        usleep(std::min(time * 1000, float(10000)));
    }

    // Copy the result back into the caller's host tensor.
    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
...@@ -62,9 +62,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -62,9 +62,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -95,9 +95,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -95,9 +95,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -128,9 +128,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -128,9 +128,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -162,9 +162,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -162,9 +162,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -195,9 +195,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -195,9 +195,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -230,9 +230,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -230,9 +230,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -265,9 +265,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -265,9 +265,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -298,9 +298,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -298,9 +298,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -333,9 +333,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -333,9 +333,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -366,9 +366,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -366,9 +366,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -401,9 +401,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -401,9 +401,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -434,9 +434,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -434,9 +434,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -467,9 +467,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -467,9 +467,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -502,9 +502,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -502,9 +502,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -535,9 +535,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -535,9 +535,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -570,9 +570,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -570,9 +570,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -603,9 +603,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -603,9 +603,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -636,9 +636,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -636,9 +636,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -669,9 +669,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -669,9 +669,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -702,9 +702,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -702,9 +702,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 2; constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -735,9 +735,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -735,9 +735,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -768,9 +768,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -768,9 +768,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 3; constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -801,9 +801,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -801,9 +801,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 3; constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -834,9 +834,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -834,9 +834,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 3; constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 2;
...@@ -867,9 +867,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -867,9 +867,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -900,9 +900,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -900,9 +900,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 2; constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -933,9 +933,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -933,9 +933,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 2; constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4;
...@@ -983,9 +983,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -983,9 +983,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerThreadSubC, GemmMPerThread,
GemmNPerThreadSubC, GemmNPerThread,
GemmKPerThreadLoop, GemmKPerThread,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
......
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"
using namespace ck;
// Host-side driver for the deprecated v4r4 implicit-GEMM convolution
// (input NCHW, weights KCYX, output NKHW) using the LDS double-buffer
// gridwise kernel. Copies the host tensors to the device, instantiates the
// gridwise kernel with hard-coded compile-time tuning parameters, launches
// and times it `nrepeat` times, then copies the result back to the host.
//
// Template parameters:
//   T                           - tensor element type
//   InDesc / WeiDesc / OutDesc  - compile-time tensor descriptors; the
//                                 corresponding function arguments are empty
//                                 tag objects passed only to carry the type
//   ConvStrides / ConvDilations - compile-time convolution strides/dilations
//
// Runtime parameters:
//   in_nchw  - host input tensor
//   wei_kcyx - host weight tensor
//   out_nkhw - host output tensor; overwritten with the device result
//   nrepeat  - number of timed kernel launches
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations>
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
ck::index_t nrepeat)
{
// Compile-time index constants used to query descriptor lengths.
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
// Output tensor lengths: batch N, output channels K, output spatial Ho x Wo.
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
// Allocate device buffers sized by each tensor's element space and upload
// all three host tensors (the output is uploaded too, though it is
// overwritten by the kernel and read back at the end).
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
// Hard-coded tuning configurations. Only the first branch (#if 1) is
// compiled; the later branches are kept as reference configurations for
// particular filter/image shapes (see their leading comments).
#if 1
constexpr index_t BlockSize = 256;
// Block-level tile: B (= N*Ho*Wo) x K output tile, E reduction chunk.
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
// Per-thread GEMM sub-tile and thread-cluster shape for the block GEMM.
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
// Input block-copy layout: per-thread sub-lengths x thread-cluster lengths
// over the [E, B] dimensions, plus access ordering and vector widths.
using InBlockCopySubLengths_E_B = Sequence<4, 1>;
using InBlockCopyClusterLengths_E_B = Sequence<2, 128>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 1;
// Weight block-copy layout over the [E, K] dimensions.
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 1;
#elif 1
// 1x1 filter, 8x8 image
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_B = Sequence<1, 4>;
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_B = Sequence<2, 2>;
using InBlockCopyClusterLengths_E_B = Sequence<4, 64>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
constexpr index_t InBlockCopyDataPerAccess_B = 2;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t OutThreadCopyDataPerAccess_B = 2;
#endif
// GEMM N dimension folds batch and output spatial positions together.
constexpr index_t B = N * Ho * Wo;
// One workgroup per BPerBlock x KPerBlock output tile (ceiling division
// in both dimensions).
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
// Select the gridwise implementation; the LDS double-buffer (deprecated)
// variant is the one currently enabled.
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
#else
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
EPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_B,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_B>{};
// Launch and time the kernel nrepeat times, reporting each run. `time` is
// in milliseconds, so flops / 1e9 / time(ms) yields TFlop/s.
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
// Pause between runs for the kernel's elapsed time (converted ms -> us),
// capped at 10 ms — presumably to let the device settle between timed
// launches; TODO(review) confirm intent.
usleep(std::min(time * 1000, float(10000)));
}
// Copy the computed output back into the caller's host tensor.
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
...@@ -13,14 +13,9 @@ ...@@ -13,14 +13,9 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv.hpp" #include "host_conv.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
//#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16.hpp"
//#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp" //#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"
...@@ -611,7 +606,7 @@ int main(int argc, char* argv[]) ...@@ -611,7 +606,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -623,6 +618,18 @@ int main(int argc, char* argv[]) ...@@ -623,6 +618,18 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 1 #elif 1
device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment