Commit 2a0592a9 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed gfx12

parent 7cb8a89f
......@@ -19,10 +19,9 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<
ALayout,
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout,
BLayout,
CLayout,
ADataType,
......@@ -34,33 +33,33 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp,
CElementOp,
GemmDefault,
2, // Prefetch stage
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 64, 1>,
1,
32,
16,
32,
64,
8,
16,
16,
1,
2,
S<2, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<2, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 8>,
1,
1,
S<1, 16, 1, 2>,
8>;
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -73,6 +73,7 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_KRow = 1;
static constexpr index_t B_KRow = 1;
#endif
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
......
......@@ -136,6 +136,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
......@@ -565,14 +566,20 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
// return GetLaneIdUnderSubGroup();
#ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
#endif
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
// return GetLaneIdUnderSubGroup();
#ifdef __gfx12__
return GetLaneIdUnderSubGroup();
#else
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
#endif
}
__device__ static CIndex GetBeginOfThreadBlk()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment