Commit 95710403 authored by Jing Zhang

add kpack with incorrect results

parent 44078dba
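
The diff below splits the GEMM reduction dimension GemmK into GemmK0 = GemmK / GemmKPack and a fastest-varying GemmK1 = GemmKPack, so the global/LDS descriptors, the blockwise copies, and the xdlops blockwise GEMM all operate on [K0, M-or-N, K1] tiles instead of [K, M-or-N]. A minimal standalone sketch of that index split (illustrative only, not library code; the example sizes are mine):

#include <cassert>
#include <cstdio>

int main()
{
    const int GemmK     = 576; // e.g. C * Y * X = 64 * 3 * 3
    const int GemmKPack = 4;
    assert(GemmK % GemmKPack == 0);

    const int GemmK0 = GemmK / GemmKPack;

    // the unmerge transform maps (k0, k1) to the flat index k = k0 * GemmKPack + k1
    for(int k = 0; k < GemmK; ++k)
    {
        const int k0 = k / GemmKPack;
        const int k1 = k % GemmKPack;
        assert(k0 < GemmK0 && k0 * GemmKPack + k1 == k);
    }

    std::printf("GemmK = %d -> GemmK0 = %d, GemmK1 = %d\n", GemmK, GemmK0, GemmKPack);
}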
......@@ -64,6 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
......@@ -71,6 +75,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
......@@ -97,6 +108,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
......@@ -104,11 +122,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
const auto GemmK0 = wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0);
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
......@@ -129,22 +147,26 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
// hack to control index calculation when iterating over the out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor (hack for NKHW format)
......@@ -158,15 +180,15 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc,
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
}
} // namespace ck
......
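
For concreteness, a small host-side check of the GEMM sizes produced by the transform above and of the divisibility requirement, which is now stated on GemmK0 rather than GemmK (the problem sizes here are hypothetical, not taken from this commit):

#include <cassert>

int main()
{
    // hypothetical NCHW / KCYX problem: N=128, K=256, C=64, Y=X=3, Ho=Wo=28
    const int N = 128, K = 256, C = 64, Y = 3, X = 3, Ho = 28, Wo = 28;
    const int GemmKPack = 4, GemmMPerBlock = 128, GemmNPerBlock = 128, GemmKPerBlock = 4;

    const int GemmM  = K;                  // 256
    const int GemmN  = N * Ho * Wo;        // 100352
    const int GemmK  = C * Y * X;          // 576
    const int GemmK0 = GemmK / GemmKPack;  // 144

    assert(GemmM % GemmMPerBlock == 0);    // per-block tiling of M
    assert(GemmN % GemmNPerBlock == 0);    // per-block tiling of N
    assert(GemmK0 % GemmKPerBlock == 0);   // K0 (not K) must now be divisible by GemmKPerBlock
    return 0;
}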
......@@ -30,16 +30,16 @@ __global__ void
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc,
const AGlobalDesc a_k0_m_k1_global_desc,
const BGlobalDesc b_k0_n_k1_global_desc,
const CGlobalDesc c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
......@@ -66,18 +66,18 @@ __global__ void
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_a_k0_m_k1_global_desc,
const void __CONSTANT__* p_b_k0_n_k1_global_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast the 'const void __CONSTANT__*' pointer to a plain 'const void*'
// then cast that 'const void*' to the concrete descriptor type
// (the tensor descriptor's copy constructor doesn't accept an address_space(4) pointer)
const auto a_k_m_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
const auto b_k_n_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
const auto a_k0_m_k1_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
const auto b_k0_n_k1_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
const auto c_m0_m1_m2_n_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
......@@ -87,8 +87,8 @@ __global__ void
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
......@@ -113,21 +113,21 @@ template <index_t BlockSize,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadSliceLengths_K_M_KPack,
typename ABlockTransferThreadClusterLengths_K_M_KPack,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
index_t ABlockTransferDstScalarPerVector_KPack,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadSliceLengths_K_N_KPack,
typename BBlockTransferThreadClusterLengths_K_N_KPack,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
index_t BBlockTransferDstScalarPerVector_KPack,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
......@@ -141,25 +141,26 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{});
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
Number<BBlockTransferDstScalarPerVector_KPack>{},
Number<KPack>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
......@@ -168,8 +169,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
......@@ -182,15 +183,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_global, a_k_m_global_desc.GetElementSpaceSize());
p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_global, b_k_n_global_desc.GetElementSpaceSize());
p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
const auto M = a_k0_m_k1_global_desc.GetLength(I1);
const auto N = b_k0_n_k1_global_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
......@@ -204,74 +205,73 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{});
// Number<MPerThread>{},
// Number<NPerThread>{});
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
Number<BBlockTransferDstScalarPerVector_KPack>{},
Number<KPack>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
Sequence<KPerBlock, MPerBlock, KPack>,
ABlockTransferThreadSliceLengths_K_M_KPack,
ABlockTransferThreadClusterLengths_K_M_KPack,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k0_m_k1_global_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
Sequence<0, 1, 2>,
2, //ABlockTransferSrcVectorDim,
2,
1, //ABlockTransferSrcScalarPerVector,
1, //ABlockTransferDstScalarPerVector_KPack,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_idx_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
a_k0_m_k1_global_desc,
make_multi_index(0, m_block_data_idx_on_global, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
Sequence<KPerBlock, NPerBlock, KPack>,
BBlockTransferThreadSliceLengths_K_N_KPack,
BBlockTransferThreadClusterLengths_K_N_KPack,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k0_n_k1_global_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
Sequence<0, 1, 2>,
1, //BBlockTransferSrcVectorDim,
2,
1, //BBlockTransferSrcScalarPerVector,
1, //BBlockTransferDstScalarPerVector_KPack,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_idx_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
b_k0_n_k1_global_desc,
make_multi_index(0, n_block_data_idx_on_global, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -285,25 +285,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
static_assert(KPerBlock % KPack == 0, "KPerBlock is wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc,
make_tuple(
make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc,
make_tuple(
make_unmerge_transform(make_tuple(Number<KPerBlock / KPack>{}, Number<KPack>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
......@@ -313,6 +311,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
MPerWave,
NPerWave,
KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
......@@ -332,10 +331,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
......@@ -349,37 +348,39 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
......@@ -391,77 +392,83 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
} while(k_block_data_begin < K0 - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
__syncthreads();
......@@ -507,10 +514,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t k_thread_data_on_global =
const index_t m_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
const index_t b_thread_data_on_global =
const index_t n_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
......@@ -528,10 +535,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(k_thread_data_on_global / (M2 * M1),
k_thread_data_on_global % (M2 * M1) / M2,
k_thread_data_on_global % M2,
b_thread_data_on_global)}
make_multi_index(m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_blk_buf_,
......@@ -549,8 +556,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
......@@ -563,8 +570,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
p_shared_block,
......
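
The double-buffered main loop above advances the A/B source windows by KPerBlock along K0 and alternates between the even and odd LDS buffers; the tail (HasDoubleTailKBlockLoop) consumes the final pair of slices. A rough sketch of the iteration arithmetic (illustrative sizes, not library code):

#include <cassert>
#include <cstdio>

int main()
{
    const int K0 = 144, KPerBlock = 4; // example sizes; K0 = GemmK / GemmKPack
    assert(K0 % KPerBlock == 0);

    int k0 = 0, main_iters = 0;
    do
    {
        // each main-loop iteration runs GEMM on the even buffer and then on the odd
        // buffer, while prefetching the next two KPerBlock slices from global memory
        ++main_iters;
        k0 += 2 * KPerBlock;
    } while(k0 < K0 - 2 * KPerBlock);

    std::printf("%d K0-slices: %d consumed in the main loop, 2 in the tail\n",
                K0 / KPerBlock, 2 * main_iters);
}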
......@@ -73,64 +73,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(OutDesc::GetLengths()));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
#if 0
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, GemmKPack>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, GemmKPack>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<TInWei,
......@@ -167,21 +141,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmKPack,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
......
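
With the tuning configuration chosen in this driver (GemmKPerBlock = 4, GemmMPerBlock = GemmNPerBlock = 128, GemmKPack = 4), the LDS footprint follows the GetSharedMemoryNumberOfByte formula above: two double-buffered tiles of KPerBlock x {M,N}PerBlock x KPack elements. A back-of-the-envelope sketch (it assumes no extra alignment padding and a 4-byte FloatAB, neither of which is fixed by this diff):

#include <cstdio>

int main()
{
    const int KPerBlock = 4, MPerBlock = 128, NPerBlock = 128, KPack = 4;
    const int a_block_elems = KPerBlock * MPerBlock * KPack;        // 2048
    const int b_block_elems = KPerBlock * NPerBlock * KPack;        // 2048
    const int lds_bytes = 2 * (a_block_elems + b_block_elems) * 4;  // double buffer, 4 B/element
    std::printf("approx. LDS per workgroup: %d bytes\n", lds_bytes); // 32768
}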