"...resnet50_tensorflow.git" did not exist on "9d5eb7985a2a78a7fb2b2eedd0f0d6b2f6ed2f94"
Commit 95710403 authored by Jing Zhang's avatar Jing Zhang
Browse files

add kpack with incorrect results

parent 44078dba
@@ -64,6 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto InRightPadH = in_right_pads[I0];
     const auto InRightPadW = in_right_pads[I1];

+    const auto GemmM = K;
+    const auto GemmN = N * Ho * Wo;
+    const auto GemmK = C * Y * X;
+
     // weight tensor
     const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
         make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
@@ -71,6 +75,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
         make_tuple(Sequence<0>{}, Sequence<1>{}),
         make_tuple(Sequence<1>{}, Sequence<0>{}));

+    const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
+        wei_gemmk_gemmm_global_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
+                   make_pass_through_transform(GemmM)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
     // input tensor
     const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
         in_n_c_hi_wi_global_desc,
@@ -97,6 +108,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
         make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));

+    const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
+        in_gemmk_gemmn_global_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
+                   make_pass_through_transform(GemmN)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
     // output tensor
     const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
         make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
@@ -104,11 +122,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
         make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));

-    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
+    assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
+    assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
+    const auto GemmK0 = wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0);

-    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
+    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);

     constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
@@ -129,22 +147,26 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
         make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));

-    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
-    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
-    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
-    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
-    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+    // hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
+    constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
+        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
+        make_tuple(
+            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
+    constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0>{};
+    // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
+    constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
+    constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

     // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
     // tensor hack for NKHW format
@@ -158,15 +180,15 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
                    Sequence<0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 2, 0, 0>{}));

-    return make_tuple(wei_gemmk_gemmm_global_desc,
-                      in_gemmk_gemmn_global_desc,
+    return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
+                      in_gemmk0_gemmn_gemmk1_global_desc,
                       out_m0_m1_m2_n_global_desc,
                       out_gemm_block_cluster_desc,
-                      wei_gemmk_gemmm_global_iterator_hacks,
-                      in_gemmk_gemmn_global_iterator_hacks,
+                      wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
+                      in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
                       out_m0_m1_m2_n_global_iterator_hacks,
-                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
-                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
+                      wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
+                      in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
 }
 } // namespace ck
...
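Note on the new descriptors: the added unmerge transform only re-indexes the existing GemmK dimension, it does not move any data. A coordinate (gemmk0, gemmm, gemmk1) of the new [GemmK / GemmKPack, GemmM, GemmKPack] view corresponds to (gemmk0 * GemmKPack + gemmk1, gemmm) in the old [GemmK, GemmM] view, and likewise for the input tensor with GemmN in place of GemmM. A minimal standalone sketch of that index mapping (plain C++ with illustrative names, not the library's descriptor machinery):

#include <cassert>
#include <utility>

// Illustrative only: the unmerge + pass-through transform maps a coordinate
// (k0, m, k1) of the new [GemmK / GemmKPack, GemmM, GemmKPack] view to the
// coordinate (k, m) of the existing [GemmK, GemmM] view, with
// k = k0 * GemmKPack + k1.
std::pair<int, int> k0_m_k1_to_k_m(int k0, int m, int k1, int GemmKPack)
{
    return {k0 * GemmKPack + k1, m};
}

int main()
{
    // Example sizes: C = 8, Y = X = 3 gives GemmK = C * Y * X = 72;
    // with GemmKPack = 4 the new view is [18, GemmM, 4].
    constexpr int GemmKPack = 4;
    assert((k0_m_k1_to_k_m(2, 5, 3, GemmKPack) == std::pair<int, int>{11, 5}));
    return 0;
}

The point of the split is that GemmKPack contiguous K-values become the innermost dimension of the block tile, which the blockwise copy and the xdlops GEMM below treat as their vectorizable axis.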
@@ -30,16 +30,16 @@ __global__ void
 kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
                               const FloatB* __restrict__ p_b_global,
                               FloatC* __restrict__ p_c_global,
-                              const AGlobalDesc a_k_m_global_desc,
-                              const BGlobalDesc b_k_n_global_desc,
+                              const AGlobalDesc a_k0_m_k1_global_desc,
+                              const BGlobalDesc b_k0_n_k1_global_desc,
                               const CGlobalDesc c_m0_m1_m2_n_global_desc,
                               const CBlockClusterDesc c_block_cluster_desc)
 {
     GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
+                      a_k0_m_k1_global_desc,
+                      b_k0_n_k1_global_desc,
                       c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
@@ -66,18 +66,18 @@ __global__ void
 kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
                               const FloatB* __restrict__ p_b_global,
                               FloatC* __restrict__ p_c_global,
-                              const void __CONSTANT__* p_a_k_m_global_desc,
-                              const void __CONSTANT__* p_b_k_n_global_desc,
+                              const void __CONSTANT__* p_a_k0_m_k1_global_desc,
+                              const void __CONSTANT__* p_b_k0_n_k1_global_desc,
                               const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
                               const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*
     // second cast void* to Desc*
     // the copy constructor of tensor descriptor doesn't take address_space(4)
-    const auto a_k_m_global_desc =
-        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
-    const auto b_k_n_global_desc =
-        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
+    const auto a_k0_m_k1_global_desc =
+        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
+    const auto b_k0_n_k1_global_desc =
+        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
     const auto c_m0_m1_m2_n_global_desc =
         *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
@@ -87,8 +87,8 @@ __global__ void
     GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
+                      a_k0_m_k1_global_desc,
+                      b_k0_n_k1_global_desc,
                       c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
@@ -113,21 +113,21 @@ template <index_t BlockSize,
           index_t KPack,
           index_t MRepeat,
           index_t NRepeat,
-          typename ABlockTransferThreadSliceLengths_K_M,
-          typename ABlockTransferThreadClusterLengths_K_M,
+          typename ABlockTransferThreadSliceLengths_K_M_KPack,
+          typename ABlockTransferThreadClusterLengths_K_M_KPack,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_M,
+          index_t ABlockTransferDstScalarPerVector_KPack,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K_N,
-          typename BBlockTransferThreadClusterLengths_K_N,
+          typename BBlockTransferThreadSliceLengths_K_N_KPack,
+          typename BBlockTransferThreadClusterLengths_K_N_KPack,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_N,
+          index_t BBlockTransferDstScalarPerVector_KPack,
           bool BThreadTransferSrcResetCoordinateAfterRun,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
@@ -141,25 +141,26 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
-                                                 Number<BBlockTransferDstScalarPerVector_N>{});
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<BBlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<KPack>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
+        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

         return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
     }
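For a quick sanity check of GetSharedMemoryNumberOfByte(), the double-buffered LDS usage can be computed by hand from the tile shape (KPerBlock, MPerBlock or NPerBlock, KPack). A rough host-side sketch, using the tuning values chosen in the device_... file below and assuming FloatAB is fp32 and no alignment padding (which holds for these sizes):

#include <cstdio>

// Rough, host-side estimate of the LDS usage computed by
// GetSharedMemoryNumberOfByte() above, ignoring alignment padding.
// Tuning values taken from the device_... file further down in this commit.
int main()
{
    const int KPerBlock = 4;   // GemmKPerBlock (in units of K0)
    const int MPerBlock = 128; // GemmMPerBlock
    const int NPerBlock = 128; // GemmNPerBlock
    const int KPack     = 4;   // GemmKPack
    const int elem_size = 4;   // sizeof(float), assuming fp32 FloatAB

    const int a_tile = KPerBlock * MPerBlock * KPack; // a_k0_m_k1_block_desc
    const int b_tile = KPerBlock * NPerBlock * KPack; // b_k0_n_k1_block_desc

    // Factor 2 for the even/odd double buffers.
    const int lds_bytes = 2 * (a_tile + b_tile) * elem_size;
    std::printf("LDS bytes: %d\n", lds_bytes); // 32768, i.e. 32 KiB
    return 0;
}

32 KiB stays under the 64 KiB LDS limit per workgroup, which is presumably what makes double buffering affordable at this tile size.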
@@ -168,8 +169,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
+                               const AGlobalDesc& a_k0_m_k1_global_desc,
+                               const BGlobalDesc& b_k0_n_k1_global_desc,
                                const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                FloatAB* __restrict__ p_shared_block,
@@ -182,15 +183,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         constexpr auto I3 = Number<3>{};

         const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+            p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+            p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
         auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());

-        const auto K = a_k_m_global_desc.GetLength(I0);
-        const auto M = a_k_m_global_desc.GetLength(I1);
-        const auto N = b_k_n_global_desc.GetLength(I1);
+        const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
+        const auto M = a_k0_m_k1_global_desc.GetLength(I1);
+        const auto N = b_k0_n_k1_global_desc.GetLength(I1);

         // divide block work by [M, N]
         const auto block_work_idx =
@@ -204,74 +205,73 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
-                                                 Number<BBlockTransferDstScalarPerVector_N>{});
-        //                                       Number<MPerThread>{},
-        //                                       Number<NPerThread>{});
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<BBlockTransferDstScalarPerVector_KPack>{},
+                                                 Number<KPack>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
+        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                 InMemoryDataOperation::Set,
-                Sequence<KPerBlock, MPerBlock>,
-                ABlockTransferThreadSliceLengths_K_M,
-                ABlockTransferThreadClusterLengths_K_M,
+                Sequence<KPerBlock, MPerBlock, KPack>,
+                ABlockTransferThreadSliceLengths_K_M_KPack,
+                ABlockTransferThreadClusterLengths_K_M_KPack,
                 ABlockTransferThreadClusterArrangeOrder,
                 FloatAB,
                 FloatAB,
-                decltype(a_k_m_global_desc),
-                decltype(a_k_m_block_desc),
+                decltype(a_k0_m_k1_global_desc),
+                decltype(a_k0_m_k1_block_desc),
                 ABlockTransferSrcAccessOrder,
-                Sequence<0, 1>,
-                ABlockTransferSrcVectorDim,
-                1,
-                ABlockTransferSrcScalarPerVector,
-                ABlockTransferDstScalarPerVector_M,
+                Sequence<0, 1, 2>,
+                2, //ABlockTransferSrcVectorDim,
+                2,
+                1, //ABlockTransferSrcScalarPerVector,
+                1, //ABlockTransferDstScalarPerVector_KPack,
                 1,
                 1,
                 AThreadTransferSrcResetCoordinateAfterRun,
                 true>(
-                a_k_m_global_desc,
-                make_multi_index(0, m_block_data_idx_on_global),
-                a_k_m_block_desc,
-                make_multi_index(0, 0));
+                a_k0_m_k1_global_desc,
+                make_multi_index(0, m_block_data_idx_on_global, 0),
+                a_k0_m_k1_block_desc,
+                make_multi_index(0, 0, 0));

         // B matrix blockwise copy
         auto b_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                 InMemoryDataOperation::Set,
-                Sequence<KPerBlock, NPerBlock>,
-                BBlockTransferThreadSliceLengths_K_N,
-                BBlockTransferThreadClusterLengths_K_N,
+                Sequence<KPerBlock, NPerBlock, KPack>,
+                BBlockTransferThreadSliceLengths_K_N_KPack,
+                BBlockTransferThreadClusterLengths_K_N_KPack,
                 BBlockTransferThreadClusterArrangeOrder,
                 FloatAB,
                 FloatAB,
-                decltype(b_k_n_global_desc),
-                decltype(b_k_n_block_desc),
+                decltype(b_k0_n_k1_global_desc),
+                decltype(b_k0_n_k1_block_desc),
                 BBlockTransferSrcAccessOrder,
-                Sequence<0, 1>,
-                BBlockTransferSrcVectorDim,
-                1,
-                BBlockTransferSrcScalarPerVector,
-                BBlockTransferDstScalarPerVector_N,
+                Sequence<0, 1, 2>,
+                1, //BBlockTransferSrcVectorDim,
+                2,
+                1, //BBlockTransferSrcScalarPerVector,
+                1, //BBlockTransferDstScalarPerVector_KPack,
                 1,
                 1,
                 BThreadTransferSrcResetCoordinateAfterRun,
                 true>(
-                b_k_n_global_desc,
-                make_multi_index(0, n_block_data_idx_on_global),
-                b_k_n_block_desc,
-                make_multi_index(0, 0));
+                b_k0_n_k1_global_desc,
+                make_multi_index(0, n_block_data_idx_on_global, 0),
+                b_k0_n_k1_block_desc,
+                make_multi_index(0, 0, 0));

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
@@ -285,25 +285,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                           NPerBlock % (NPerWave * NRepeat) == 0,
                       "wrong!");

-        static_assert(KPerBlock % KPack == 0, "KPerBlock is wrong!");
-
-        constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
-            a_k_m_block_desc,
-            make_tuple(
-                make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
-                make_unmerge_transform(
-                    make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
+        constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
+            a_k0_m_k1_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
+                       make_pass_through_transform(Number<KPack>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

-        constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
-            b_k_n_block_desc,
-            make_tuple(
-                make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
-                make_unmerge_transform(
-                    make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
+        constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
+            b_k0_n_k1_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
+                       make_pass_through_transform(Number<KPack>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

         const auto blockwise_gemm =
             BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
@@ -313,6 +311,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                                                  MPerWave,
                                                  NPerWave,
                                                  KPack>{};

         constexpr auto CLayout = blockwise_gemm.GetCLayout();
         constexpr index_t BlkSize = CLayout.GetBlkSize();
@@ -332,10 +331,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

         FloatAB* p_a_block_double = p_shared_block;
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
@@ -349,37 +348,39 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         //          Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
         //     .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
+        constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};

         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_k_m_global_move_slice_window_iterator_hack =
-            AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
-            BGlobalMoveSliceWindowIteratorHacks{};
+        constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
+            AGlobalMoveSliceWindowIteratorHacks{};
+        constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
+            BGlobalMoveSliceWindowIteratorHacks{};

         auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+            p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize());
         auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+            p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize());

         auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+            p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize());
         auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
-            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
+            p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
         }
         if constexpr(HasMainKBlockLoop)
@@ -391,77 +392,83 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(
+                    a_k0_m_k1_global_desc,
+                    a_block_slice_copy_step,
+                    a_k0_m_k1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(
+                    b_k0_n_k1_global_desc,
+                    b_block_slice_copy_step,
+                    b_k0_n_k1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+                a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);

                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
-                                                    a_block_slice_copy_step,
-                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(
+                    a_k0_m_k1_global_desc,
+                    a_block_slice_copy_step,
+                    a_k0_m_k1_global_move_slice_window_iterator_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(
+                    b_k0_n_k1_global_desc,
+                    b_block_slice_copy_step,
+                    b_k0_n_k1_global_move_slice_window_iterator_hack);

                 __syncthreads();

                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
+                a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
-            } while(k_block_data_begin < K - 2 * KPerBlock);
+            } while(k_block_data_begin < K0 - 2 * KPerBlock);
         }

         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
+            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
                                                 a_block_slice_copy_step,
-                                                a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+                                                a_k0_m_k1_global_move_slice_window_iterator_hack);
+            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
                                                 b_block_slice_copy_step,
-                                                b_k_n_global_move_slice_window_iterator_hack);
+                                                b_k0_n_k1_global_move_slice_window_iterator_hack);

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(
+                a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
+            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);

             __syncthreads();
@@ -507,10 +514,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                     blockwise_gemm.CalculateCThreadOriginDataIndex(
                         mr_i, nr_i, xdlops_i, blk_i);

-                const index_t k_thread_data_on_global =
+                const index_t m_thread_data_on_global =
                     m_block_data_idx_on_global + c_thread_mtx_on_block[I0];

-                const index_t b_thread_data_on_global =
+                const index_t n_thread_data_on_global =
                     n_block_data_idx_on_global + c_thread_mtx_on_block[I1];

                 constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
@@ -528,10 +535,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                     CGlobalMemoryDataOperation,
                     1,
                     true>{c_m0_m1_m2_n_global_desc,
-                          make_multi_index(k_thread_data_on_global / (M2 * M1),
-                                           k_thread_data_on_global % (M2 * M1) / M2,
-                                           k_thread_data_on_global % M2,
-                                           b_thread_data_on_global)}
+                          make_multi_index(m_thread_data_on_global / (M2 * M1),
+                                           m_thread_data_on_global % (M2 * M1) / M2,
+                                           m_thread_data_on_global % M2,
+                                           n_thread_data_on_global)}
                     .Run(c_m0_m1_m2_n_thread_desc,
                          make_tuple(I0, I0, I0, I0),
                          c_blk_buf_,
@@ -549,8 +556,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
+                               const AGlobalDesc& a_k0_m_k1_global_desc,
+                               const BGlobalDesc& b_k0_n_k1_global_desc,
                                const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
@@ -563,8 +570,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         Run(p_a_global,
             p_b_global,
            p_c_global,
-            a_k_m_global_desc,
-            b_k_n_global_desc,
+            a_k0_m_k1_global_desc,
+            b_k0_n_k1_global_desc,
             c_m0_m1_m2_n_global_desc,
             c_block_cluster_desc,
             p_shared_block,
...
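The restructured main loop keeps the same ping-pong (even/odd LDS buffer) control flow as before; only the descriptors and the slice-window steps gained the K1 dimension. Below is a stripped-down sketch of that control flow, with the copies and the xdlops GEMM reduced to prints. The flag derivation shown is a simplification: the real HasMainKBlockLoop / HasDoubleTailKBlockLoop values are supplied by the caller, which is outside this diff.

#include <cstdio>

// Simplified model of the double-buffered K loop above (names illustrative).
// Each iteration consumes KPerBlock of the K0 dimension; even/odd LDS buffers
// alternate so the next load can overlap the GEMM on the tile loaded before.
void run_gemm_k_loop(int K0, int KPerBlock)
{
    const int  num_k_blocks        = K0 / KPerBlock;
    const bool has_main_loop       = num_k_blocks > 2;      // simplification
    const bool has_double_tail     = num_k_blocks % 2 == 0; // simplification

    int k = 0; // block 0 is preloaded into the "even" buffer before the loop
    if(has_main_loop)
    {
        do
        {
            std::printf("gemm on even buffer, prefetch into odd buffer (k0 = %d)\n", k);
            std::printf("gemm on odd buffer, prefetch into even buffer (k0 = %d)\n", k + KPerBlock);
            k += 2 * KPerBlock;
        } while(k < K0 - 2 * KPerBlock);
    }
    if(has_double_tail)
        std::printf("tail: gemm on even buffer, then gemm on odd buffer\n");
    else
        std::printf("tail: gemm on even buffer only\n");
}

int main()
{
    run_gemm_k_loop(/*K0=*/16, /*KPerBlock=*/4); // e.g. GemmK0 = 16, GemmKPerBlock = 4
    return 0;
}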
@@ -73,64 +73,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
     const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
         sequence_to_tuple_of_number(OutDesc::GetLengths()));

     const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
     const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
     const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif

-#if 0
-    constexpr index_t BlockSize = 64;
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerWave = 64;
-    constexpr index_t GemmNPerWave = 64;
-    constexpr index_t GemmKPack = 1;
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#else
     constexpr index_t BlockSize = 256;
     constexpr index_t GemmMPerBlock = 128;
     constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-    constexpr index_t GemmMPerWave = 32;
-    constexpr index_t GemmNPerWave = 32;
+    constexpr index_t GemmKPerBlock = 4;
+    constexpr index_t GemmMPerWave = 64;
+    constexpr index_t GemmNPerWave = 64;
     constexpr index_t GemmKPack = 4;
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 2;
+    constexpr index_t MRepeat = 1;
+    constexpr index_t NRepeat = 1;

-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, GemmKPack>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 1;
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, GemmKPack>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 1;
     constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#endif

     const auto descs =
         transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<TInWei,
@@ -167,21 +141,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
             GemmKPack,
             MRepeat,
             NRepeat,
-            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-            Sequence<1, 0>,
-            Sequence<1, 0>,
-            0,
+            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
+            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
+            Sequence<1, 0, 2>,
+            Sequence<1, 0, 2>,
+            2,
             GemmABlockTransferSrcScalarPerVector_GemmK,
-            GemmABlockTransferDstScalarPerVector_GemmM,
+            GemmABlockTransferDstScalarPerVector_KPack,
             false, // don't move back src coordinate after threadwise copy
-            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-            Sequence<0, 1>,
-            Sequence<0, 1>,
+            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
+            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
+            Sequence<0, 2, 1>,
+            Sequence<0, 2, 1>,
             1,
             GemmBBlockTransferSrcScalarPerVector_GemmN,
-            GemmBBlockTransferDstScalarPerVector_GemmN,
+            GemmBBlockTransferDstScalarPerVector_KPack,
             false, // don't move back src coordinate after threadwise copy, which will be fused with
                    // MoveSrcSliceWindow() to save addr computation
             Sequence<2, 3, 0, 1>,
...
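To see what the new tuning parameters imply for a real layer, it helps to plug a concrete NCHW shape into the GemmM / GemmN / GemmK formulas introduced in the transform file. The layer below is only an example; it is not taken from this commit.

#include <cassert>
#include <cstdio>

// Example only: derive the implicit-GEMM problem sizes for one NCHW conv layer
// and check them against the block tiling chosen above. The layer shape is
// hypothetical; the formulas follow the transform file:
//   GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X, GemmK0 = GemmK / GemmKPack.
int main()
{
    const int N = 128, C = 64;      // batch, input channels
    const int K = 128, Y = 3, X = 3; // output channels, filter size
    const int Ho = 56, Wo = 56;      // output spatial size (e.g. 3x3, stride 1, pad 1 on 56x56)

    const int GemmM = K;
    const int GemmN = N * Ho * Wo;
    const int GemmK = C * Y * X;

    const int GemmKPack = 4;
    const int GemmK0 = GemmK / GemmKPack;

    // Divisibility requirements asserted in the transform file.
    const int GemmMPerBlock = 128, GemmNPerBlock = 128, GemmKPerBlock = 4;
    assert(GemmM % GemmMPerBlock == 0);
    assert(GemmN % GemmNPerBlock == 0);
    assert(GemmK % GemmKPack == 0 && GemmK0 % GemmKPerBlock == 0);

    std::printf("GemmM=%d GemmN=%d GemmK=%d GemmK0=%d, grid=%d x %d blocks\n",
                GemmM, GemmN, GemmK, GemmK0,
                GemmM / GemmMPerBlock, GemmN / GemmNPerBlock);
    return 0;
}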