"vscode:/vscode.git/clone" did not exist on "d38c804320192c3844ff0bc7deed83e8b8cb7856"
Commit 7d42a6d4 authored by ltqin's avatar ltqin
Browse files

add thread copy desc and register buffer

parent c08dcaad
...@@ -51,8 +51,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds ...@@ -51,8 +51,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
...@@ -64,13 +64,13 @@ int main(int argc, char* argv[]) ...@@ -64,13 +64,13 @@ int main(int argc, char* argv[])
int nrepeat = 5; int nrepeat = 5;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 64;
ck::index_t N = 4096; ck::index_t N = 128;
ck::index_t K = 4096; ck::index_t K = 64;
ck::index_t StrideA = 4096; ck::index_t StrideA = 64;
ck::index_t StrideB = 4096; ck::index_t StrideB = 64;
ck::index_t StrideC = 4096; ck::index_t StrideC = 128;
if(argc == 4) if(argc == 4)
{ {
......
...@@ -202,6 +202,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -202,6 +202,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t KPerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
...@@ -349,7 +356,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -349,7 +356,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL);
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor( const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
...@@ -360,6 +366,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -360,6 +366,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
make_tuple(Sequence<0>{}, Sequence<1, 2, 3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2, 3, 4>{}, Sequence<5>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1; return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
} }
// Converts the calling thread's block-local 1-D id (get_thread_local_1d_id())
// into a three-component index by running it through a single merge transform
// whose lengths are (MWaves, NWaves, WaveSize).
//
// Returns the adaptor's bottom index; it has one component per merged length.
// NOTE(review): presumably the components are (wave position along M, wave
// position along N, lane within the wave) -- confirm against the merge
// transform's bottom-index ordering.
__device__ static auto GetWaveIdx()
{
    // Lengths folded together by the merge transform.
    constexpr auto wave_dims = make_tuple(MWaves, NWaves, WaveSize);

    // One-stage adaptor: 1-D top index <-> 3-D bottom index.
    constexpr auto flat_tid_to_wave_coord = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(wave_dims)),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));

    const auto flat_tid = make_multi_index(get_thread_local_1d_id());

    return flat_tid_to_wave_coord.CalculateBottomIndex(flat_tid);
}
// Converts a 1-D id into a two-component index by running it through a single
// merge transform whose lengths are (KPerThread, NPerXDL).
//
// thread_id: the 1-D index to decompose.
// NOTE(review): presumably this is the thread's lane id inside its wave rather
// than the block-local thread id -- verify against the caller.
//
// Returns the adaptor's bottom index; it has one component per merged length.
__device__ static auto GetWaveKNIdx(const index_t thread_id)
{
    // Lengths folded together by the merge transform.
    constexpr auto kn_lengths = make_tuple(KPerThread, NPerXDL);

    // One-stage adaptor: 1-D top index <-> 2-D bottom index.
    constexpr auto flat_id_to_kn_coord = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(kn_lengths)),
        make_tuple(Sequence<0, 1>{}),
        make_tuple(Sequence<0>{}));

    const auto flat_id = make_multi_index(thread_id);

    return flat_id_to_kn_coord.CalculateBottomIndex(flat_id);
}
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
...@@ -525,26 +554,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -525,26 +554,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy // B matrix blockwise copy
constexpr auto b_thread_copy_desc_k0_n0_n1_n2_n3_k1 =
/*static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t KPerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
constexpr auto b_k0_n0_n1_n2_n3_k1_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<KPerThread>{}, make_naive_tensor_descriptor_packed(make_tuple(Number<KPerThread>{},
I1, // NBlockId
Number<MXdlPerWave>{}, // repeat Number<MXdlPerWave>{}, // repeat
I1, // waves I1, // waves
I1, // NPerXdlops I1, // NPerXdlops
Number<K1>{})); Number<K1>{}));
StaticBuffer<AddressSpaceEnum::Vgpr, ignore = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB, FloatAB,
b_k0_n0_n1_n2_n3_k1_thread_copy_desc.GetElementSpaceSize(), b_thread_copy_desc_k0_n0_n1_n2_n3_k1.GetElementSpaceSize(),
true> true>{};
b_thread_buf;
*/ auto b_grid_desc_k0_n0_n1_n2_n3_k1 =
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1);
MakeBGridDescriptor_K0_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1);
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
#if 0
const index_t block_id = get_block_1d_id();
const index_t thread_id = get_thread_local_1d_id();
printf("block id: %d m blockid: %d n block id: %d ,thread id: %d, wave id :{%d %d %d} "
"kn id: {%d %d}\n",
block_id,
block_work_idx[I0],
block_work_idx[I1],
thread_id,
wave_id[I0],
wave_id[I1],
wave_id[I2],
wave_k_n_id[I0],
wave_k_n_id[I1]);
#endif
ignore = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n0_n1_n2_n3_k1),
decltype(b_thread_copy_desc_k0_n0_n1_n2_n3_k1),
Sequence<Number<KPerThread>{}, I1, Number<MXdlPerWave>{}, I1, I1, Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_n0_n1_n2_n3_k1,
make_multi_index(
wave_k_n_id[I0], n_block_data_idx_on_grid, 0, wave_id[I1], wave_k_n_id[I1], 0));
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment