Unverified Commit a0058be6 authored by Haocong WANG's avatar Haocong WANG Committed by GitHub
Browse files

Disable SkipLDS & Align AIT api (#3)

parent cad3212d
......@@ -34,24 +34,25 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256, // BlockSize
GemmDefault,
1, // Prefetch stage
128, // BlockSize
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // M-Repeat // MWaves = MPerBlock / (MRepeat * MPerWmma)
8, // N-Repeat // NWaves = NPerBlock / (NRepeat * NPerWmma)
S<4, 64, 1>,
8, // M-Repeat // MWaves = MPerBlock / (MRepeat * MPerWmma)
2, // N-Repeat // NWaves = NPerBlock / (NRepeat * NPerWmma)
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
......@@ -59,8 +60,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8,
true,
1, // C shuffle (M Repeat) Per store
4, // C shuffle (N Repeat) Per store
S<1, 32, 1, 8>,
1, // C shuffle (N Repeat) Per store
S<1, 16, 1, 8>,
8>;
// clang-format on
......
......@@ -72,41 +72,42 @@ using DeviceOpInstance =
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
128,
128,
1,
64,
32,
64,
64,
8,
16,
16,
4,
2,
S<4, 64, 1>,
2,
S<4, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
4,
4,
true,
S<4, 64, 1>,
S<4, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
4,
4,
true,
1,
1,
S<1, 32, 1, 8>,
8>;
S<1, 2, 1, 32>,
1>;
int main(int argc, char* argv[])
{
......@@ -264,7 +265,7 @@ int main(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
<< device_op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
......
......@@ -56,10 +56,10 @@ using DeviceOpInstanceKKNN =
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
......@@ -67,6 +67,7 @@ using DeviceOpInstanceKKNN =
ASpec,
BSpec,
DESpec,
1,
256,
128,
128,
......
......@@ -39,7 +39,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
......
......@@ -42,15 +42,16 @@ using DeviceConvFwdInstance =
OutputLayout<NDimSpatial>,
InKernelDataType,
WeiKernelDataType,
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutKernelDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutKernelDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // Prefetch stage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
......@@ -60,19 +61,19 @@ using DeviceConvFwdInstance =
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<4, 8, 8>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<4, 8, 8>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
4,
2,
......
......@@ -62,10 +62,10 @@ template <index_t NumDimG,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
......@@ -73,6 +73,7 @@ template <index_t NumDimG,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization DESpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -100,7 +101,6 @@ template <index_t NumDimG,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
......@@ -132,8 +132,16 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
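// Known bug: BlockSize 128 with tile 128x128x64 (M x N x K) and repeat 8x2 (M x N) fails,
// presumably when B bypasses LDS; B is therefore forced through LDS below.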
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......
......@@ -143,9 +143,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto AEnableLds = LWaves == 1 ? false : true;
static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true;
static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_manu = true;
static constexpr auto B0EnableLds_manu = true;
static constexpr auto B1EnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu;
static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu;
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
......
......@@ -28,14 +28,15 @@ template <typename ALayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -63,7 +64,6 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
......@@ -94,12 +94,17 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
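// Known bug: BlockSize 128 with tile 128x128x64 (M x N x K) and repeat 8x2 (M x N) fails.
// Known bug: gemm.TileDesc(64, 32, 64, 64, 8, 0, 16, 16, 2, 2) also fails.
// Presumably both fail when B bypasses LDS; B is therefore forced through LDS below.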
static constexpr auto BEnableLds_manu = true;
// To force-enable LDS, uncomment the following lines:
// AEnableLds = true;
// BEnableLds = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......@@ -744,7 +749,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " NumPrefetch: "
<< " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
......
......@@ -34,6 +34,7 @@ template <typename ALayout,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -61,7 +62,6 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
......@@ -92,7 +92,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
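// Known bug: BlockSize 128 with tile 128x128x64 (M x N x K) and repeat 8x2 (M x N) fails,
// presumably when B bypasses LDS; B is therefore forced through LDS below.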
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
......
......@@ -99,16 +99,17 @@ template <index_t NDimSpatial,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
......@@ -136,7 +137,6 @@ template <index_t NDimSpatial,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment