Commit 31ca2f41 authored by aska-0096's avatar aska-0096
Browse files

update fmha config, no scratch generated

parent b8e153a4
...@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8, // K1 8, // K1
16, // MPerWmma 16, // MPerWmma
16, // NPerWmma 16, // NPerWmma
8, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 1, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
1, // N-Repeat // N-PerWmma / N-Repeat = N-Wave 8, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 64, 1>, S<4, 64, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -58,9 +58,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -58,9 +58,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8, 8,
8, 8,
true, true,
4, // C shuffle (M Repeat) Per store 1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store 4, // C shuffle (N Repeat) Per store
S<1, 16, 1, 16>, S<1, 32, 1, 8>,
8>; 8>;
// clang-format on // clang-format on
......
...@@ -96,7 +96,7 @@ using DeviceGemmInstance = ...@@ -96,7 +96,7 @@ using DeviceGemmInstance =
256, 256,
// Gemm 0 // Gemm 0
128, // MPerBlock 128, // MPerBlock
128, // LPerBlock 64, // LPerBlock
32, // KPerBlock 32, // KPerBlock
8, // K1 8, // K1
// Gemm 1 // Gemm 1
...@@ -108,7 +108,7 @@ using DeviceGemmInstance = ...@@ -108,7 +108,7 @@ using DeviceGemmInstance =
16, // NPerWMMA 16, // NPerWMMA
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, // MRepeat 1, // MRepeat
8, // LRepeat 4, // LRepeat
4, // NRepeat 4, // NRepeat
S<4, 64, 1>, // ABlockTransfer MK -> K0 M K1 S<4, 64, 1>, // ABlockTransfer MK -> K0 M K1
S<1, 0, 2>, S<1, 0, 2>,
...@@ -129,12 +129,12 @@ using DeviceGemmInstance = ...@@ -129,12 +129,12 @@ using DeviceGemmInstance =
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
8, 8,
1, // be eight? 1,
false, false,
1, // CShuffleMWmmaPerWavePerShuffle 1, // CShuffleMWmmaPerWavePerShuffle
2, // CShuffleNWmmaPerWavePerShuffle 2, // CShuffleNWmmaPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment