Commit 2159921e authored by ltqin's avatar ltqin
Browse files

fix NXdlPerWave

parent 4e816177
...@@ -208,7 +208,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -208,7 +208,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t KPerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
...@@ -383,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -383,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
__device__ static auto GetWaveKNIdx(const index_t thread_id) __device__ static auto GetWaveKNIdx(const index_t thread_id)
{ {
constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor( constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KPerThread, NPerXDL))), make_tuple(make_merge_transform(make_tuple(K0PerThread, NPerXDL))),
make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
...@@ -559,9 +559,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -559,9 +559,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// B matrix blockwise copy // B matrix blockwise copy
constexpr auto b_thread_desc_k0_k0b_n0_n1_n2_n3_k1 = constexpr auto b_thread_desc_k0_k0b_n0_n1_n2_n3_k1 =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{}, // KPerThread Number<K0PerThread>{}, // K0PerThread
I1, // NBlockId I1, // NBlockId
Number<MXdlPerWave>{}, // repeat Number<NXdlPerWave>{}, // repeat
I1, // waves I1, // waves
I1, // NPerXdlops I1, // NPerXdlops
Number<K1>{})); Number<K1>{}));
...@@ -597,7 +597,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -597,7 +597,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
FloatAB, FloatAB,
decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1), decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1), decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1),
Sequence<I1, Number<KPerThread>{}, I1, Number<MXdlPerWave>{}, I1, I1, Number<K1>{}>, Sequence<I1, Number<K0PerThread>{}, I1, Number<NXdlPerWave>{}, I1, I1, Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>, Sequence<0, 1, 2, 3, 4, 5, 6>,
6, 6,
1, 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment