Commit f1d8217d authored by Jing Zhang
Browse files

debugging

parent fe728dc5
...@@ -19,9 +19,11 @@ using AElementOp = PassThrough; ...@@ -19,9 +19,11 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout, // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
...@@ -33,34 +35,35 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -33,34 +35,35 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp, BElementOp,
CElementOp, CElementOp,
GemmDefault, GemmDefault,
1, 1, // Prefetch stage
32, 128, // BlockSize
16, 64, // MPerBlock
32, 128, // NPerBlock
64, 64, // KPerBlock
8, 8, // K1
16, 16, // MPerWmma
16, 16, // NPerWmma
1, 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
2, 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<2, 16, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
S<2, 16, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
1, 1, // C shuffle (M Repeat) Per store
1, 1, // C shuffle (N Repeat) Per store
S<1, 16, 1, 2>, S<1, 32, 1, 4>,
8>; 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
...@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN = ...@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
2, 2,
4, 4,
4, 4,
true, false,
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
4, 4,
4, 4,
true, false,
1, 1,
1, 1,
S<1, 64, 1, 2>, S<1, 64, 1, 2>,
......
...@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = true; static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = true; static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment