Commit 673b30cf authored by ltqin's avatar ltqin
Browse files

add K0PerBlock dim

parent 7d42a6d4
......@@ -351,19 +351,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0),
make_tuple(make_unmerge_transform(make_tuple(K0 / K0PerBlock, K0PerBlock)),
make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3, 4>{}, Sequence<5>{}));
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4, 5>{}, Sequence<6>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
}
......@@ -554,8 +554,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
constexpr auto b_thread_copy_desc_k0_n0_n1_n2_n3_k1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<KPerThread>{},
constexpr auto b_thread_copy_desc_k0_k0b_n0_n1_n2_n3_k1 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{}, // KPerThread
I1, // NBlockId
Number<MXdlPerWave>{}, // repeat
I1, // waves
......@@ -563,11 +564,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
Number<K1>{}));
ignore = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_thread_copy_desc_k0_n0_n1_n2_n3_k1.GetElementSpaceSize(),
b_thread_copy_desc_k0_k0b_n0_n1_n2_n3_k1.GetElementSpaceSize(),
true>{};
auto b_grid_desc_k0_n0_n1_n2_n3_k1 =
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1);
auto b_grid_desc_k0_k0b_n0_n1_n2_n3_k1 =
MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1);
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
......@@ -591,17 +592,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ignore = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n0_n1_n2_n3_k1),
decltype(b_thread_copy_desc_k0_n0_n1_n2_n3_k1),
Sequence<Number<KPerThread>{}, I1, Number<MXdlPerWave>{}, I1, I1, Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_copy_desc_k0_k0b_n0_n1_n2_n3_k1),
Sequence<I1, Number<KPerThread>{}, I1, Number<MXdlPerWave>{}, I1, I1, Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_n0_n1_n2_n3_k1,
true>(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
make_multi_index(
wave_k_n_id[I0], n_block_data_idx_on_grid, 0, wave_id[I1], wave_k_n_id[I1], 0));
0, wave_k_n_id[I0], block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment