Commit 7d2fa996 authored by root
Browse files

Restrict 4gemm to PassThrough + bug fix

parent bda26547
...@@ -42,48 +42,54 @@ namespace ck { ...@@ -42,48 +42,54 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ALayout, template <
typename BLayout, typename ALayout,
typename CLayout, typename BLayout,
typename ADataType, typename CLayout,
typename BDataType, typename ADataType,
typename CDataType, typename BDataType,
typename GemmAccDataType, typename CDataType,
typename CShuffleDataType, typename GemmAccDataType,
typename AElementwiseOperation, typename CShuffleDataType,
typename BElementwiseOperation, typename AElementwiseOperation,
typename CElementwiseOperation, typename BElementwiseOperation,
GemmSpecialization GemmSpec, typename CElementwiseOperation,
index_t NumGemmKPrefetchStage, GemmSpecialization GemmSpec,
index_t BlockSize, index_t NumGemmKPrefetchStage,
index_t MPerBlock, index_t BlockSize,
index_t NPerBlock, index_t MPerBlock,
index_t KPerBlock, index_t NPerBlock,
index_t AK1, index_t KPerBlock,
index_t BK1, index_t AK1,
index_t MPerXDL, index_t BK1,
index_t NPerXDL, index_t MPerXDL,
index_t MXdlPerWave, index_t NPerXDL,
index_t NXdlPerWave, index_t MXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, index_t NXdlPerWave,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferThreadClusterArrangeOrder,
index_t ABlockTransferSrcVectorDim, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferDstScalarPerVector_AK1, index_t ABlockTransferSrcScalarPerVector,
bool ABlockLdsExtraM, index_t ABlockTransferDstScalarPerVector_AK1,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1, bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferThreadClusterArrangeOrder,
index_t BBlockTransferSrcVectorDim, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferSrcScalarPerVector,
bool BBlockLdsExtraN, index_t BBlockTransferDstScalarPerVector_BK1,
index_t CShuffleMXdlPerWavePerShuffle, bool BBlockLdsExtraN,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, index_t CShuffleNXdlPerWavePerShuffle,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceCGemm_4Gemm_Xdl_CShuffle struct DeviceCGemm_4Gemm_Xdl_CShuffle
: public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
......
...@@ -60,8 +60,8 @@ template < ...@@ -60,8 +60,8 @@ template <
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
enable_if_t< enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>, is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false> bool> = false>
struct DeviceGemmDl struct DeviceGemmDl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment