Unverified Commit 5c7cec11 authored by Chao Liu, committed by GitHub
Browse files

Code clean up (#20)



* tuning parameters

* testing on v100

* add fp16

* remove deprecated tensor descriptor

* sync with miopen

* update build script
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent 7d09790a
...@@ -8,53 +8,9 @@ ...@@ -8,53 +8,9 @@
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
#include "convolution_common.hpp"
namespace ck { namespace ck {
// Functor that builds the [E, K] weight-tensor descriptor used as the GEMM "A" matrix,
// specialized per convolution direction. Primary template is declared only; each
// supported direction provides its own specialization.
template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1;
// Forward convolution: the weight tensor arrives as a 4-D K-C-Y-X descriptor.
template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::Forward>
{
// WeiDesc: compile-time K-C-Y-X weight descriptor type (passed by value only to
// deduce the type; the descriptor itself is default-constructed below).
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
// Fold dimensions 1..3 (C, Y, X) into a single dimension, yielding a 2-D
// [K, C*Y*X] descriptor, then swap the two dimensions (Sequence<1, 0>) so the
// result is [E = C*Y*X, K]. NOTE(review): this relies on C, Y, X being
// contiguous/unfoldable in WeiDesc — presumably guaranteed for the native
// K-C-Y-X layout; confirm against the caller.
return reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{});
}
};
// Backward-weight convolution: builds the same logical [E = C*Y*X, K] descriptor as the
// Forward specialization, but via an explicit Merge/PassThrough transform instead of a
// plain unfold-and-reorder (the strides after a backward-weight mapping are not simply
// foldable across all three of C, Y, X).
template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::BackwardWeight>
{
// WeiDesc: compile-time K-C-Y-X weight descriptor type (value parameter used only
// for type deduction).
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};
// Compile-time extents of the four weight dimensions.
constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
// Step 1: fold only Y and X (dims 2..3) into one dimension -> [K, C, Y*X].
// Step 2: merge C with Y*X into output dim 0 (E = C*Y*X) and pass K through as
// output dim 1, giving the final [E, K] descriptor.
return transform_tensor_descriptor(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
};
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
typename Float, typename Float,
...@@ -66,18 +22,17 @@ template <index_t GridSize, ...@@ -66,18 +22,17 @@ template <index_t GridSize,
typename ConvDilations, typename ConvDilations,
typename LeftPads, typename LeftPads,
typename RightPads, typename RightPads,
ConvolutionDirection ConvDirection,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
index_t GemmNRepeat, index_t GemmNRepeat,
index_t GemmMPerThreadSubC, index_t GemmMPerThread,
index_t GemmNPerThreadSubC, index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster, index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster, index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA, index_t GemmDataPerReadA,
index_t GemmDataPerReadB, index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_N1_B_N2, typename InBlockCopySubLengths_E_N1_B_N2,
...@@ -107,18 +62,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -107,18 +62,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
"wrong! this kernel only support convolution forward and backward-weight");
// this is a mess // this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters // TODO: find more elegent way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat; constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC; constexpr index_t N2 = GemmNPerThread;
static_assert((N1 * N2 * BPerBlock) % static_assert(
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == (N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
0,
"wrong!"); "wrong!");
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
...@@ -240,7 +190,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -240,7 +190,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// It is constructed differently, depending on whether forward or backward weight // It is constructed differently, depending on whether forward or backward weight
// convolution // convolution
constexpr auto wei_e_k_global_desc = constexpr auto wei_e_k_global_desc =
make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc); transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// block tensor in LDS memory, dst of blockwise copy // block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -290,30 +243,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -290,30 +243,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
in_e_n1_b_n2_block_desc.GetStride(I0)); in_e_n1_b_n2_block_desc.GetStride(I0));
// sanity check // sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
0,
"wrong!"); "wrong!");
constexpr index_t GemmMRepeat = constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx // TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{}); Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
decltype(a_e_k_block_mtx_desc), decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k1_n1n2_thread_mtx_desc), decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThreadSubC, GemmMPerThread,
GemmNPerThreadSubC, GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster, GemmMLevel0Cluster,
GemmNLevel0Cluster, GemmNLevel0Cluster,
GemmMLevel1Cluster, GemmMLevel1Cluster,
GemmNLevel1Cluster, GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB>{}; GemmDataPerReadB>{};
...@@ -432,13 +384,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -432,13 +384,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// copy output: register to global memory // copy output: register to global memory
{ {
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t K0 = K / K1; constexpr index_t K0 = K / K1;
// define output tensor descriptor for threadwise copy // define output tensor descriptor for threadwise copy
// thread output tensor, src of threadwise copy // thread output tensor, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed( constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{}); Sequence<GemmMRepeat, GemmMPerThread, N1, 1, N2>{});
// global output tensor // global output tensor
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor( constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
......
...@@ -18,11 +18,11 @@ template <index_t BlockSize, ...@@ -18,11 +18,11 @@ template <index_t BlockSize,
typename ThreadMatrixC, typename ThreadMatrixC,
index_t MPerThreadSubC, index_t MPerThreadSubC,
index_t NPerThreadSubC, index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster, index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster, index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster, index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster, index_t NLevel1ThreadCluster,
index_t KPerThreadLoop,
index_t ThreadGemmADataPerRead_M, index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N> index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment