Commit 6eca23d7 authored by Yanxing-Shi's avatar Yanxing-Shi
Browse files

adapt ait api

parent bef0cb20
...@@ -62,10 +62,10 @@ template <index_t NumDimG, ...@@ -62,10 +62,10 @@ template <index_t NumDimG,
index_t NumDimK, index_t NumDimK,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -73,6 +73,7 @@ template <index_t NumDimG, ...@@ -73,6 +73,7 @@ template <index_t NumDimG,
TensorSpecialization ASpec, TensorSpecialization ASpec,
TensorSpecialization BSpec, TensorSpecialization BSpec,
TensorSpecialization DESpec, TensorSpecialization DESpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -100,7 +101,6 @@ template <index_t NumDimG, ...@@ -100,7 +101,6 @@ template <index_t NumDimG,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(), ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
......
...@@ -27,14 +27,15 @@ template <typename ALayout, ...@@ -27,14 +27,15 @@ template <typename ALayout,
typename ELayout, typename ELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -62,7 +63,6 @@ template <typename ALayout, ...@@ -62,7 +63,6 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(), ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
......
...@@ -32,6 +32,7 @@ template <typename ALayout, ...@@ -32,6 +32,7 @@ template <typename ALayout,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
...@@ -59,7 +60,6 @@ template <typename ALayout, ...@@ -59,7 +60,6 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(), ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1> ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
......
...@@ -64,6 +64,7 @@ template <typename ALayout, ...@@ -64,6 +64,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment