Commit 2a0592a9 authored by Jing Zhang
Browse files

fixed gfx12

parent 7cb8a89f
...@@ -19,10 +19,9 @@ using AElementOp = PassThrough; ...@@ -19,10 +19,9 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout,
ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
...@@ -34,33 +33,33 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -34,33 +33,33 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp, BElementOp,
CElementOp, CElementOp,
GemmDefault, GemmDefault,
2, // Prefetch stage 1,
256, // BlockSize 32,
128, // MPerBlock 16,
256, // NPerBlock 32,
64, // KPerBlock 64,
8, // K1 8,
16, // MPerWmma 16,
16, // NPerWmma 16,
4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 1,
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave 2,
S<4, 64, 1>, S<2, 16, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
S<4, 64, 1>, S<2, 16, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
1, // C shuffle (M Repeat) Per store 1,
1, // C shuffle (N Repeat) Per store 1,
S<1, 32, 1, 8>, S<1, 16, 1, 2>,
8>; 8>;
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -73,6 +73,7 @@ struct BlockwiseGemmWMMA ...@@ -73,6 +73,7 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_KRow = 1; static constexpr index_t A_KRow = 1;
static constexpr index_t B_KRow = 1; static constexpr index_t B_KRow = 1;
#endif #endif
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
......
...@@ -136,6 +136,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12, ...@@ -136,6 +136,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// static constexpr index_t src_b_data_size = 2; // static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4; // static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1; static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma; static constexpr index_t num_thread_per_subgroups = n_per_wmma;
...@@ -565,14 +566,20 @@ struct WmmaGemm ...@@ -565,14 +566,20 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex() __host__ __device__ static auto CalculateAThreadOriginDataIndex()
{ {
// return GetLaneIdUnderSubGroup(); #ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
#endif
} }
__host__ __device__ static auto CalculateBThreadOriginDataIndex() __host__ __device__ static auto CalculateBThreadOriginDataIndex()
{ {
// return GetLaneIdUnderSubGroup(); #ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
#endif
} }
__device__ static CIndex GetBeginOfThreadBlk() __device__ static CIndex GetBeginOfThreadBlk()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment