Commit 673b30cf authored by ltqin
Browse files

add K0PerBlock dim

parent 7d42a6d4
...@@ -351,19 +351,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -351,19 +351,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{ {
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor( const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0), make_tuple(make_unmerge_transform(make_tuple(K0 / K0PerBlock, K0PerBlock)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)), N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4, 5>{}, Sequence<6>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1; return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
} }
...@@ -554,8 +554,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -554,8 +554,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy // B matrix blockwise copy
constexpr auto b_thread_copy_desc_k0_n0_n1_n2_n3_k1 = constexpr auto b_thread_copy_desc_k0_k0b_n0_n1_n2_n3_k1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<KPerThread>{}, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{}, // KPerThread
I1, // NBlockId I1, // NBlockId
Number<MXdlPerWave>{}, // repeat Number<MXdlPerWave>{}, // repeat
I1, // waves I1, // waves
...@@ -563,11 +564,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -563,11 +564,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
Number<K1>{})); Number<K1>{}));
ignore = StaticBuffer<AddressSpaceEnum::Vgpr, ignore = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB, FloatAB,
b_thread_copy_desc_k0_n0_n1_n2_n3_k1.GetElementSpaceSize(), b_thread_copy_desc_k0_k0b_n0_n1_n2_n3_k1.GetElementSpaceSize(),
true>{}; true>{};
auto b_grid_desc_k0_n0_n1_n2_n3_k1 = auto b_grid_desc_k0_k0b_n0_n1_n2_n3_k1 =
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1); MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1);
const auto wave_id = GetWaveIdx(); const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]); const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
...@@ -591,17 +592,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -591,17 +592,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ignore = ThreadwiseTensorSliceTransfer_v2< ignore = ThreadwiseTensorSliceTransfer_v2<
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_grid_desc_k0_n0_n1_n2_n3_k1), decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_copy_desc_k0_n0_n1_n2_n3_k1), decltype(b_thread_copy_desc_k0_k0b_n0_n1_n2_n3_k1),
Sequence<Number<KPerThread>{}, I1, Number<MXdlPerWave>{}, I1, I1, Number<K1>{}>, Sequence<I1, Number<KPerThread>{}, I1, Number<MXdlPerWave>{}, I1, I1, Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>( true>(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_grid_desc_k0_n0_n1_n2_n3_k1, make_multi_index(
make_multi_index( 0, wave_k_n_id[I0], block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
wave_k_n_id[I0], n_block_data_idx_on_grid, 0, wave_id[I1], wave_k_n_id[I1], 0));
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment