Commit bdd0f64e authored by aska-0096's avatar aska-0096
Browse files

Fix a bug

parent a045e0be
...@@ -26,7 +26,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -26,7 +26,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
{ {
split_k = 1; split_k = 1;
} }
const auto in_g_n_c_wis_desc = const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
InputLayout<NDimSpatial>>(conv_param); InputLayout<NDimSpatial>>(conv_param);
......
...@@ -62,20 +62,6 @@ struct BlockwiseGemmWMMA ...@@ -62,20 +62,6 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4); static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4); static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
// FIX it, workaround
using ABlockDesc_temp = decltype(
make_naive_tensor_descriptor(make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
make_tuple(A_temp1* A_temp2* A_temp3* A_temp4,
A_temp2* A_temp3* A_temp4,
A_temp3* A_temp4,
A_temp4,
I1)));
static constexpr auto wmma_gemm = static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{}; WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
...@@ -210,9 +196,6 @@ struct BlockwiseGemmWMMA ...@@ -210,9 +196,6 @@ struct BlockwiseGemmWMMA
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
// constexpr auto NSubGroup =
// c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; constexpr auto MThreadPerSubGroup
// = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed( return make_naive_tensor_descriptor_packed(
...@@ -302,7 +285,7 @@ struct BlockwiseGemmWMMA ...@@ -302,7 +285,7 @@ struct BlockwiseGemmWMMA
// Describe how data allocated in thread copy src buffer // Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1; static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
......
...@@ -249,20 +249,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -249,20 +249,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5); constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// Err: merge transform cause non-constexpr issue
// return transform_tensor_descriptor(
// ABlockDesc_{},
// make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
// make_pass_through_transform(Number<MRepeat>{}),
// make_pass_through_transform(I1),
// make_pass_through_transform(I1),
// make_pass_through_transform(Number<A_K1>{})),
// make_tuple(Sequence<0, 3>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
// Sequence<4>{}));
// Workaround, Freeze transform
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)), make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}), make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0, 3>{}, make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}),
make_tuple( make_tuple(Sequence<>{},
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -455,14 +480,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -455,14 +480,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static constexpr auto a_block_space_size_aligned = static constexpr auto a_block_space_size_aligned =
AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
max_lds_align) * max_lds_align)
sizeof(FloatA)
: 0; : 0;
static constexpr auto b_block_space_size_aligned = static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple( BEnableLds ? math::integer_least_multiple(
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(), GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
max_lds_align) * max_lds_align)
sizeof(FloatB)
: 0; : 0;
static constexpr auto a_block_space_offset = 0; static constexpr auto a_block_space_offset = 0;
...@@ -471,13 +494,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -471,13 +494,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// LDS allocation for C shuffle in LDS // LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size = static constexpr auto c_shuffle_block_space_size =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
.GetElementSpaceSize() * .GetElementSpaceSize();
sizeof(FloatCShuffle);
static constexpr auto c_shuffle_block_space_offset = 0; static constexpr auto c_shuffle_block_space_offset = 0;
static constexpr auto lds_size = math::max( static constexpr auto lds_size =
c_shuffle_block_space_size, (a_block_space_size_aligned + b_block_space_size_aligned)); math::max(c_shuffle_block_space_size * sizeof(FloatCShuffle),
a_block_space_size_aligned * sizeof(FloatA) +
b_block_space_size_aligned * sizeof(FloatB));
}; };
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap> template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
...@@ -528,8 +552,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -528,8 +552,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
} }
}(); }();
// printf("---------------K = %d\n", K);
constexpr auto a_block_desc = MakeABlockDescriptor(); constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
...@@ -540,7 +562,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -540,7 +562,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto K0PerBlock = KPerBlock/ K1; constexpr auto K0PerBlock = KPerBlock/ K1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatA*>(p_shared), static_cast<FloatA*>(p_shared),
a_block_desc.GetElementSpaceSize()); SharedMemTrait::a_block_space_size_aligned);
auto a_blockwise_copy = auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
...@@ -615,8 +637,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -615,8 +637,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{ {
constexpr auto K0PerBlock = KPerBlock/ K1; constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned, static_cast<FloatB*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc.GetElementSpaceSize()); SharedMemTrait::b_block_space_size_aligned);
auto b_blockwise_copy = auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
...@@ -703,7 +725,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -703,7 +725,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/ /*******************************************************************************/
// Shift Per SUB_K // Shift Per SUB_K
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
// printf("a_block_slice_copy_step FirstKdim = %d\n", a_block_slice_copy_step[I0]);
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline // gridwise GEMM pipeline
...@@ -726,13 +747,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -726,13 +747,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/ /*******************************************************************************/
// write out to C, implement shuffle // write out to C, implement shuffle
{ {
#if 0
static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) {
printf("tid: %03d, c_thread_buf[%02d] val: %08x\n", get_thread_local_1d_id(), i.value,
*(reinterpret_cast<const uint32_t*>(&(c_thread_buf[i]))));
// c_thread_buf(i) = 32;
});
#endif
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
...@@ -751,7 +765,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -751,7 +765,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared), SharedMemTrait::c_shuffle_block_space_size); static_cast<FloatCShuffle*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
SharedMemTrait::c_shuffle_block_space_size);
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment