"vscode:/vscode.git/clone" did not exist on "ce9cf353274d4fd7865ec6a1063afba96de63f64"
Unverified Commit a0058be6 authored by Haocong WANG's avatar Haocong WANG Committed by GitHub
Browse files

Disable SkipLDS & Align AIT api (#3)

parent cad3212d
...@@ -35,23 +35,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -35,23 +35,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp, BElementOp,
CElementOp, CElementOp,
GemmDefault, GemmDefault,
256, // BlockSize 1, // Prefetch stage
128, // BlockSize
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 64, // KPerBlock
8, // K1 8, // K1
16, // MPerWmma 16, // MPerWmma
16, // NPerWmma 16, // NPerWmma
1, // M-Repeat // MPerBlock / (M-Repeat * MPerWmma) = M-Waves 8, // M-Repeat // MPerBlock / (M-Repeat * MPerWmma) = M-Waves
8, // N-Repeat // NPerBlock / (N-Repeat * NPerWmma) = N-Waves 2, // N-Repeat // NPerBlock / (N-Repeat * NPerWmma) = N-Waves
S<4, 64, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
S<4, 64, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
...@@ -59,8 +60,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -59,8 +60,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8, 8,
true, true,
1, // C shuffle (M Repeat) Per store 1, // C shuffle (M Repeat) Per store
4, // C shuffle (N Repeat) Per store 1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 8>, S<1, 16, 1, 8>,
8>; 8>;
// clang-format on // clang-format on
......
...@@ -72,41 +72,42 @@ using DeviceOpInstance = ...@@ -72,41 +72,42 @@ using DeviceOpInstance =
ELayout, ELayout,
ADataType, ADataType,
BDataType, BDataType,
ck::Tuple<DDataType>,
EDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp, CDEElementOp,
GemmSpec, GemmSpec,
256, 1,
128, 64,
128,
32, 32,
64,
64,
8, 8,
16, 16,
16, 16,
4,
2, 2,
S<4, 64, 1>, 2,
S<4, 16, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 4,
8, 4,
true, true,
S<4, 64, 1>, S<4, 16, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 4,
8, 4,
true, true,
1, 1,
1, 1,
S<1, 32, 1, 8>, S<1, 2, 1, 32>,
8>; 1>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -264,7 +265,7 @@ int main(int argc, char* argv[]) ...@@ -264,7 +265,7 @@ int main(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << device_op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
......
...@@ -56,10 +56,10 @@ using DeviceOpInstanceKKNN = ...@@ -56,10 +56,10 @@ using DeviceOpInstanceKKNN =
NumDimK, NumDimK,
ADataType, ADataType,
BDataType, BDataType,
DsDataType,
EDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
DsDataType,
EDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp, CDEElementOp,
...@@ -67,6 +67,7 @@ using DeviceOpInstanceKKNN = ...@@ -67,6 +67,7 @@ using DeviceOpInstanceKKNN =
ASpec, ASpec,
BSpec, BSpec,
DESpec, DESpec,
1,
256, 256,
128, 128,
128, 128,
......
...@@ -39,7 +39,7 @@ using S = ck::Sequence<Is...>; ...@@ -39,7 +39,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec = static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
......
...@@ -42,15 +42,16 @@ using DeviceConvFwdInstance = ...@@ -42,15 +42,16 @@ using DeviceConvFwdInstance =
OutputLayout<NDimSpatial>, OutputLayout<NDimSpatial>,
InKernelDataType, InKernelDataType,
WeiKernelDataType, WeiKernelDataType,
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutKernelDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutKernelDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
ConvSpec, // ConvForwardSpecialization ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization GemmSpec, // GemmSpecialization
1, // Prefetch stage
256, // BlockSize 256, // BlockSize
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
...@@ -60,19 +61,19 @@ using DeviceConvFwdInstance = ...@@ -60,19 +61,19 @@ using DeviceConvFwdInstance =
16, // NPerWMMA 16, // NPerWMMA
4, // MRepeat 4, // MRepeat
2, // NRepeat 2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<4, 8, 8>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<4, 8, 8>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN true, // BBlockLdsExtraN
4, 4,
2, 2,
......
...@@ -62,10 +62,10 @@ template <index_t NumDimG, ...@@ -62,10 +62,10 @@ template <index_t NumDimG,
index_t NumDimK, index_t NumDimK,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -73,6 +73,7 @@ template <index_t NumDimG, ...@@ -73,6 +73,7 @@ template <index_t NumDimG,
TensorSpecialization ASpec, TensorSpecialization ASpec,
TensorSpecialization BSpec, TensorSpecialization BSpec,
TensorSpecialization DESpec, TensorSpecialization DESpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -100,7 +101,6 @@ template <index_t NumDimG, ...@@ -100,7 +101,6 @@ template <index_t NumDimG,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(), ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
...@@ -132,8 +132,16 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -132,8 +132,16 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true; static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
// Known bug: BlockSize 128, tile 128x128x64 (MxNxK), repeat 8x2 (MxN) fails when B bypasses LDS (MWaves == 1), so B is forced through LDS here.
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......
...@@ -143,9 +143,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -143,9 +143,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto AEnableLds = LWaves == 1 ? false : true; static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true;
static constexpr auto B0EnableLds = MWaves == 1 ? false : true; static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto B1EnableLds = MWaves == 1 ? false : true; static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
static constexpr auto AEnableLds_manu = true;
static constexpr auto B0EnableLds_manu = true;
static constexpr auto B1EnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu;
static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu;
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>, Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
......
...@@ -28,14 +28,15 @@ template <typename ALayout, ...@@ -28,14 +28,15 @@ template <typename ALayout,
typename ELayout, typename ELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -63,7 +64,6 @@ template <typename ALayout, ...@@ -63,7 +64,6 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(), ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...@@ -94,12 +94,17 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -94,12 +94,17 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16; static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true; static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true; static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
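// Known bug: BlockSize 128, tile 128x128x64 (MxNxK), repeat 8x2 (MxN) fails when B bypasses LDS (MWaves == 1).
// Known bug: gemm.TileDesc(64, 32, 64, 64, 8, 0, 16, 16, 2, 2) also fails; B is therefore forced through LDS below.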
static constexpr auto BEnableLds_manu = true;
// Force enable LDS if uncommented following static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
// AEnableLds = true; static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
// BEnableLds = true;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
...@@ -744,7 +749,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -744,7 +749,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
<< MRepeat << ", " << MRepeat << ", "
<< NRepeat << NRepeat
<< ">" << ">"
<< " NumPrefetch: " << " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
<< "LoopScheduler: " << "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
......
...@@ -34,6 +34,7 @@ template <typename ALayout, ...@@ -34,6 +34,7 @@ template <typename ALayout,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -61,7 +62,6 @@ template <typename ALayout, ...@@ -61,7 +62,6 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(), ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...@@ -92,7 +92,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -92,7 +92,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// If true, LDS is used unconditionally // If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false; // Bug: blocksize 128, Tile 128x128x64, Repeat 8x2 Failure
static constexpr auto BEnableLds_manu = true;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu; static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu; static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
......
...@@ -100,15 +100,16 @@ template <index_t NDimSpatial, ...@@ -100,15 +100,16 @@ template <index_t NDimSpatial,
typename ELayout, typename ELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -136,7 +137,6 @@ template <index_t NDimSpatial, ...@@ -136,7 +137,6 @@ template <index_t NDimSpatial,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment