Unverified commit ec7c2e91 authored by Chao Liu, committed by GitHub

Code refactor (#175)

* format

* improving pipeline

* fix typo

* format

* adding thread group

* adding thread group

* adding thread group

* adding gemm pipeline

* tweak

* refactor

* refactor

* add missing type convert

* refactor

* refactor

* refactor

* clean

* fix build

* refactor

* format

* clean up

* use remove_cvref_t

* clean

* clean up

* clean up

* clean up
parent a3c910ac
@@ -11,8 +11,7 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
...@@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; ...@@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
ADataType, // ADataType <ALayout, // typename ALayout
BDataType, // BDataType BLayout, // typename BLayout
CDataType, // CDataType CLayout, // typename CLayout
AccDataType, // AccDataType ADataType, // typename ADataType
CDataType, // CShuffleDataType BDataType, // typename BDataType
ALayout, // ALayout CDataType, // typename CDataType
BLayout, // BLayout AccDataType, // typename GemmAccDataType
CLayout, // CLayout CDataType, // typename CShuffleDataType
PassThrough, // AElementwiseOperation PassThrough, // typename AElementwiseOperation
PassThrough, // BElementwiseOperation PassThrough, // typename BElementwiseOperation
PassThrough, // CElementwiseOperation PassThrough, // typename CElementwiseOperation
256, // BlockSize GemmDefault, // GemmSpecialization GemmSpec
256, // MPerBlock 1, // index_t NumGemmKPrefetchStage
128, // NPerBlock 256, // index_t BlockSize
32, // KPerBlock 256, // index_t MPerBlock
8, // AK1 128, // index_t NPerBlock
8, // BK1 32, // index_t KPerBlock
32, // MPerXDL 8, // index_t AK1
32, // NPerXDL 8, // index_t BK1
4, // MXdlPerWave 32, // index_t MPerXDL
2, // NXdlPerWave 32, // index_t NPerXDL
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // index_t MXdlPerWave
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder 2, // index_t NXdlPerWave
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
8, // ABlockTransferSrcScalarPerVector S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
8, // ABlockTransferDstScalarPerVector_K1 2, // index_t ABlockTransferSrcVectorDim
true, // ABlockLdsAddExtraM 8, // index_t ABlockTransferSrcScalarPerVector
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 8, // index_t ABlockTransferDstScalarPerVector_AK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder 1, // index_t ABlockLdsExtraM
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
8, // BBlockTransferSrcScalarPerVector S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
8, // BBlockTransferDstScalarPerVector_K1 2, // index_t BBlockTransferSrcVectorDim
true, // BBlockLdsAddExtraN 8, // index_t BBlockTransferSrcScalarPerVector
1, // CShuffleMXdlPerWavePerShuffle 8, // index_t BBlockTransferDstScalarPerVector_BK1
1, // CShuffleNXdlPerWavePerShuffle 1, // index_t BBlockLdsExtraN
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 1, // index_t CShuffleMXdlPerWavePerShuffle
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1, // index_t CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
@@ -4,7 +4,6 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include <half.hpp>
-
 #include "check_err.hpp"
 #include "config.hpp"
 #include "device.hpp"
@@ -12,7 +11,6 @@
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
 #include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
 #include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
...@@ -46,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -46,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
@@ -11,8 +11,7 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
...@@ -20,64 +19,63 @@ ...@@ -20,64 +19,63 @@
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using CDataType = int32_t; using CDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int32_t; using CShuffleDataType = int8_t;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ADataType, // ADataType ALayout, // typename ALayout
BDataType, // BDataType BLayout, // typename BLayout
CDataType, // CDataType CLayout, // typename CLayout
AccDataType, // AccDataType ADataType, // typename ADataType
CShuffleDataType, // CShuffleDataType BDataType, // typename BDataType
ALayout, // ALayout CDataType, // typename CDataType
BLayout, // BLayout AccDataType, // typename GemmAccDataType
CLayout, // CLayout CShuffleDataType, // typename CShuffleDataType
PassThrough, // AElementwiseOperation PassThrough, // typename AElementwiseOperation
PassThrough, // BElementwiseOperation PassThrough, // typename BElementwiseOperation
PassThrough, // CElementwiseOperation PassThrough, // typename CElementwiseOperation
256, // BlockSize GemmDefault, // GemmSpecialization GemmSpec
256, // MPerBlock 1, // index_t NumGemmKPrefetchStage
128, // NPerBlock 256, // index_t BlockSize
64, // KPerBlock 256, // index_t MPerBlock
16, // AK1 128, // index_t NPerBlock
16, // BK1 64, // index_t KPerBlock
32, // MPerXDL 16, // index_t AK1
32, // NPerXDL 16, // index_t BK1
4, // MXdlPerWave 32, // index_t MPerXDL
2, // NXdlPerWave 32, // index_t NPerXDL
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // index_t MXdlPerWave
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder 2, // index_t NXdlPerWave
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
16, // ABlockTransferSrcScalarPerVector S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
16, // ABlockTransferDstScalarPerVector_K1 2, // index_t ABlockTransferSrcVectorDim
true, // ABlockLdsAddExtraM 16, // index_t ABlockTransferSrcScalarPerVector
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 16, // index_t ABlockTransferDstScalarPerVector_AK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder 1, // index_t ABlockLdsExtraM
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
16, // BBlockTransferSrcScalarPerVector S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
16, // BBlockTransferDstScalarPerVector_K1 2, // index_t BBlockTransferSrcVectorDim
true, // BBlockLdsAddExtraN 8, // index_t BBlockTransferSrcScalarPerVector
1, // CShuffleMXdlPerWavePerShuffle 8, // index_t BBlockTransferDstScalarPerVector_BK1
1, // CShuffleNXdlPerWavePerShuffle 1, // index_t BBlockLdsExtraN
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 1, // index_t CShuffleMXdlPerWavePerShuffle
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1, // index_t CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -13,74 +13,91 @@ ...@@ -13,74 +13,91 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
template <ck::index_t... Is> struct RequantReluRequant
using S = ck::Sequence<Is...>; {
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
using F32 = float; __host__ __device__ constexpr void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant;
}
using Row = ck::tensor_layout::gemm::RowMajor; float scaleGemm_;
using Col = ck::tensor_layout::gemm::ColumnMajor; float scaleRelu_;
};
using PassThrough = ck::tensor_operation::element_wise::PassThrough; template <ck::index_t... Is>
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = int8_t; using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using CDataType = int8_t; using CDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CShuffleDataType = int32_t; using CShuffleDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ADataType, // ADataType ALayout, // typename ALayout,
BDataType, // BDataType BLayout, // typename BLayout,
CDataType, // CDataType CLayout, // typename CLayout,
AccDataType, // AccDataType ADataType, // typename ADataType,
CShuffleDataType, // CShuffleDataType BDataType, // typename BDataType,
ALayout, // ALayout CDataType, // typename CDataType,
BLayout, // BLayout AccDataType, // typename GemmAccDataType,
CLayout, // CLayout CShuffleDataType, // typename CShuffleDataType,
PassThrough, // AElementwiseOperation PassThrough, // typename AElementwiseOperation,
PassThrough, // BElementwiseOperation PassThrough, // typename BElementwiseOperation,
RequantReluRequant, // CElementwiseOperation RequantReluRequant, // typename CElementwiseOperation,
256, // BlockSize GemmDefault, // GemmSpecialization GemmSpec,
256, // MPerBlock 1, // index_t NumGemmKPrefetchStage,
128, // NPerBlock 256, // index_t BlockSize,
64, // KPerBlock 256, // index_t MPerBlock,
16, // AK1 128, // index_t NPerBlock,
16, // BK1 64, // index_t KPerBlock,
32, // MPerXDL 16, // index_t AK1,
32, // NPerXDL 16, // index_t BK1,
4, // MXdlPerWave 32, // index_t MPerXDL,
2, // NXdlPerWave 32, // index_t NPerXDL,
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 4, // index_t MXdlPerWave,
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder 2, // index_t NXdlPerWave,
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
2, // ABlockTransferSrcVectorDim S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder,
16, // ABlockTransferSrcScalarPerVector S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder,
16, // ABlockTransferDstScalarPerVector_K1 2, // index_t ABlockTransferSrcVectorDim,
true, // ABlockLdsAddExtraM 16, // index_t ABlockTransferSrcScalarPerVector,
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 16, // index_t ABlockTransferDstScalarPerVector_AK1,
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder 1, // bool ABlockLdsExtraM,
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
2, // BBlockTransferSrcVectorDim S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder,
16, // BBlockTransferSrcScalarPerVector S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder,
16, // BBlockTransferDstScalarPerVector_K1 2, // index_t BBlockTransferSrcVectorDim,
true, // BBlockLdsAddExtraN 8, // index_t BBlockTransferSrcScalarPerVector,
1, // CShuffleMXdlPerWavePerShuffle 8, // index_t BBlockTransferDstScalarPerVector_BK1,
1, // CShuffleNXdlPerWavePerShuffle 1, // bool BBlockLdsExtraN,
S<1, 1, 64, 1, 1, 4>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl 1, // index_t CShuffleMXdlPerWavePerShuffle,
16>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1, // index_t CShuffleNXdlPerWavePerShuffle,
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
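The RequantReluRequant functor introduced in the example above requantizes the float accumulator, applies ReLU, requantizes again, and saturates to the int8 range before the cast to CDataType. A standalone sketch of that arithmetic (the scale values below are made up purely for illustration):

    #include <cstdio>

    // Stand-in for the example's RequantReluRequant::operator(); scales are illustrative only.
    float requant_relu_requant(float x, float scale_gemm, float scale_relu)
    {
        float gemm_requant = scale_gemm * x;                      // first requantization
        float relu         = gemm_requant > 0 ? gemm_requant : 0; // ReLU
        float relu_requant = scale_relu * relu;                   // second requantization
        // saturate to [-128, 127] so the later cast to int8_t cannot overflow
        return relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant;
    }

    int main()
    {
        printf("%f\n", requant_relu_requant(300.0f, 0.5f, 2.0f)); // 127.000000 (saturated)
        printf("%f\n", requant_relu_requant(-40.0f, 0.5f, 2.0f)); // 0.000000 (clipped by ReLU)
    }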
@@ -26,17 +26,14 @@
 #endif
 #endif

-// buffer resourse, wave size
+// buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
-#define CK_GPU_WAVE_SIZE -1
 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
     defined(__gfx90a__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
-#define CK_GPU_WAVE_SIZE 64
 #elif defined(__gfx1030__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#define CK_GPU_WAVE_SIZE 32
 #endif

 // FMA instruction
......
-#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
-#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
+#pragma once

 #include "common_header.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "xdlops_gemm.hpp"
 #include "tensor_adaptor.hpp"
+#include "thread_group.hpp"

 namespace ck {

@@ -25,7 +24,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

-    static constexpr index_t WaveSize = 64;
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    static constexpr index_t WaveSize = get_warp_size();

     static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
     static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);

@@ -55,7 +56,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     __device__ static auto GetWaveIdx()
     {
-        const index_t thread_id = get_thread_local_1d_id();
+        const index_t thread_id = ThisThreadBlock::GetThreadId();

         constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),

@@ -122,8 +123,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                           BK0NK1BlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(BlockSize == MWaves * NWaves * WaveSize,
-                      "BlockSize != MWaves * NWaves * WaveSize\n");
+        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
+                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

         static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                       "wrong!");

@@ -339,4 +340,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 };

 } // namespace ck
-#endif
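The blockwise GEMM above now queries its thread through a ThisThreadBlock<BlockSize> alias and takes the wave size from get_warp_size() instead of a hard-coded 64. A hypothetical sketch of what such a thread-group wrapper could look like (the thread_group.hpp added by this PR may define it differently):

    // Illustration only; relies on ck::index_t and get_thread_local_1d_id() from "common_header.hpp".
    #include "common_header.hpp"

    namespace ck {

    template <index_t ThreadPerBlock>
    struct ThisThreadBlockSketch
    {
        // Compile-time size of the group and the runtime id of the calling thread.
        __device__ static constexpr index_t GetNumOfThread() { return ThreadPerBlock; }
        __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
    };

    } // namespace ck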
@@ -45,8 +45,8 @@ struct BlockwiseTensorSliceTransfer_v5r1
           src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
     {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
+        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
+                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                       nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                       nDim == ThreadClusterLengths::Size() &&
                       nDim == ThreadClusterArrangeOrder::Size() &&
......
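The static_asserts above (and in the transfer classes below) now spell the trait as remove_cvref_t instead of remove_reference_t<remove_cv_t<...>>. A minimal stand-in showing the intended equivalence (CK's actual definition may differ in detail):

    #include <type_traits>

    // Stand-in for ck::remove_cvref_t: strip the reference first, then cv-qualifiers,
    // mirroring std::remove_cvref_t from C++20.
    template <typename T>
    using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

    static_assert(std::is_same_v<remove_cvref_t<const int&>, int>, "same type either way");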
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP #pragma once
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
...@@ -13,7 +11,7 @@ namespace ck { ...@@ -13,7 +11,7 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <typename ThreadGroup,
typename SrcElementwiseOperation, typename SrcElementwiseOperation,
typename DstElementwiseOperation, typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
...@@ -35,7 +33,7 @@ template <index_t BlockSize, ...@@ -35,7 +33,7 @@ template <index_t BlockSize,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1> index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1 struct ThreadGroupTensorSliceTransfer_v4r1
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
...@@ -43,7 +41,7 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -43,7 +41,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1( __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index& src_block_slice_origin, const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op, const SrcElementwiseOperation& src_element_op,
...@@ -58,8 +56,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -58,8 +56,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
dst_element_op) dst_element_op)
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() && static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() && nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
...@@ -69,14 +67,14 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -69,14 +67,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{}, is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small"); "wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
...@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{}) Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
} }
...@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
DstBuffer& dst_buf, DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{}) Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
} }
...@@ -124,8 +122,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -124,8 +122,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
} }
...@@ -133,8 +131,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -133,8 +131,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
} }
...@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1 ...@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1
}; };
} // namespace ck } // namespace ck
#endif
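The renamed ThreadGroupTensorSliceTransfer_v4r1 takes a ThreadGroup type as its first template parameter instead of an index_t BlockSize, and every RunRead/RunWrite/MoveSliceWindow call is guarded by the same participation test. A minimal sketch of that test, extracted for clarity (not CK code; it assumes a ThreadGroup exposing GetNumOfThread() and GetThreadId() as in the diff):

    // HIP device-code sketch: a thread participates if the whole group is used,
    // or if its id falls inside the thread cluster mapped onto the slice.
    template <typename ThreadGroup>
    __device__ bool participates(ck::index_t thread_cluster_size)
    {
        return ThreadGroup::GetNumOfThread() == thread_cluster_size or
               ThreadGroup::GetThreadId() < thread_cluster_size;
    }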
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP #pragma once
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
...@@ -13,10 +11,10 @@ namespace ck { ...@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <typename ThreadGroup,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename SliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename SrcData, typename SrcData,
...@@ -28,19 +26,19 @@ template <index_t BlockSize, ...@@ -28,19 +26,19 @@ template <index_t BlockSize,
index_t ScalarPerVector, index_t ScalarPerVector,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1 struct ThreadGroupTensorSliceTransfer_v6r1
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin, const Index& src_block_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin, const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op) const ElementwiseOperation& element_op)
: threadwise_transfer_(src_desc, : threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(), make_zero_multi_index<nDim>(),
dst_desc, dst_desc,
...@@ -48,25 +46,25 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -48,25 +46,25 @@ struct BlockwiseTensorSliceTransfer_v6r1
element_op) element_op)
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() && static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() && nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(), nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent"); "wrong! nDim not consistent");
static_assert( static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{}, is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small"); "wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
...@@ -83,8 +81,8 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -83,8 +81,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf) DstBuffer& dst_buf)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
} }
...@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
} }
...@@ -101,8 +99,8 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -101,8 +99,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
} }
...@@ -130,4 +128,3 @@ struct BlockwiseTensorSliceTransfer_v6r1 ...@@ -130,4 +128,3 @@ struct BlockwiseTensorSliceTransfer_v6r1
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP #pragma once
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
...@@ -13,10 +11,10 @@ namespace ck { ...@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. It does not keep reference to tensor descriptor // 2. It does not keep reference to tensor descriptor
// 3. Run() does not construct new tensor coordinate // 3. Run() does not construct new tensor coordinate
template <index_t BlockSize, template <typename ThreadGroup,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename SliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename Src0Data, typename Src0Data,
...@@ -31,21 +29,21 @@ template <index_t BlockSize, ...@@ -31,21 +29,21 @@ template <index_t BlockSize,
bool ThreadTransferSrc0ResetCoordinateAfterRun, bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun, bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r2 struct ThreadGroupTensorSliceTransfer_v6r2
{ {
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin, const Index& src0_block_slice_origin,
const Src1Desc& src1_desc, const Src1Desc& src1_desc,
const Index& src1_block_slice_origin, const Index& src1_block_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin, const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op) const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc, : threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(), make_zero_multi_index<nDim>(),
src1_desc, src1_desc,
...@@ -55,26 +53,26 @@ struct BlockwiseTensorSliceTransfer_v6r2 ...@@ -55,26 +53,26 @@ struct BlockwiseTensorSliceTransfer_v6r2
element_op) element_op)
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() && static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() && nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() && nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(), nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent"); "wrong! nDim not consistent");
static_assert( static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{}, is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small"); "wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
...@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 ...@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf) DstBuffer& dst_buf)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
} }
...@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 ...@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
} }
...@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 ...@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
} }
...@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 ...@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
} }
...@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2 ...@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP #pragma once
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
...@@ -13,10 +11,10 @@ namespace ck { ...@@ -13,10 +11,10 @@ namespace ck {
// 1. Use StaticallyIndexedArray instead of C array for thread buffer // 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <typename ThreadGroup,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp, InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths, typename SliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename Src0Data, typename Src0Data,
...@@ -34,23 +32,23 @@ template <index_t BlockSize, ...@@ -34,23 +32,23 @@ template <index_t BlockSize,
bool ThreadTransferSrc1ResetCoordinateAfterRun, bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferSrc2ResetCoordinateAfterRun, bool ThreadTransferSrc2ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun> bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r3 struct ThreadGroupTensorSliceTransfer_v6r3
{ {
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin, const Index& src0_block_slice_origin,
const Src1Desc& src1_desc, const Src1Desc& src1_desc,
const Index& src1_block_slice_origin, const Index& src1_block_slice_origin,
const Src2Desc& src2_desc, const Src2Desc& src2_desc,
const Index& src2_block_slice_origin, const Index& src2_block_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin, const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op) const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc, : threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(), make_zero_multi_index<nDim>(),
src1_desc, src1_desc,
...@@ -62,24 +60,24 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -62,24 +60,24 @@ struct BlockwiseTensorSliceTransfer_v6r3
element_op) element_op)
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() && static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() && nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() && nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() && nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(), nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent"); "wrong! nDim not consistent");
static_assert( static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{}, is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small"); "wrong! ThreadGroup::GetNumOfThread() too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); make_multi_index(get_thread_local_1d_id()));
...@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf) DstBuffer& dst_buf)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.Run( threadwise_transfer_.Run(
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
...@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
} }
...@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
} }
...@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step) __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
} }
...@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{ {
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
} }
...@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3 ...@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3
}; };
} // namespace ck } // namespace ck
#endif
@@ -720,11 +720,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
             const index_t grid_size =
                 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;

-            const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
-            if(has_main_k0_block_loop)
+            const auto K =
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
                     GridwiseGemm,
......
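The hunk above, and the similar device_* hunks that follow, stop passing K0 to a CalculateHasMainK0BlockLoop helper and instead pass the full reduction length K = K0 * K1, read from lengths I0 and I2 of the K0_M_K1 descriptor. A small illustration of how K is assembled (the predicate below is a made-up stand-in; the real check is GridwiseGemm::CalculateHasMainKBlockLoop):

    #include <cstdio>

    // Made-up stand-in: assume the pipelined main loop is needed once K exceeds one K tile.
    bool has_main_k_block_loop_sketch(int K, int KPerBlock) { return K > KPerBlock; }

    int main()
    {
        const int K0 = 8, K1 = 8; // descriptor lengths at I0 and I2
        const int K  = K0 * K1;   // full K, as the refactored launch code computes it
        printf("K = %d, has_main_loop = %d\n", K, has_main_k_block_loop_sketch(K, /*KPerBlock=*/32));
    }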
@@ -455,13 +455,12 @@ struct DeviceBatchedGemmXdl
             const index_t grid_size =
                 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_batched_gemm_xdlops_v2r3<
                     GridwiseGemm,
......
@@ -582,11 +582,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                 const index_t grid_size =
                     GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);

-                const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
-
-                const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
-                if(has_main_k0_block_loop)
+                const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
+                               arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
+
+                if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
                 {
                     const auto kernel = kernel_gemm_xdlops_v2r3<
                         GridwiseGemm,
......
@@ -698,13 +698,12 @@ struct
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r3<
                     GridwiseGemm,
......
-#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
-#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
+#pragma once

 #include <iostream>
 #include <sstream>
 #include "device.hpp"

@@ -660,13 +658,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r2<
                     GridwiseGemm,

@@ -919,4 +916,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
@@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v3r1<
                     GridwiseGemm,
......
@@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v2r3<
                     GridwiseGemm,
......
@@ -1296,11 +1296,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
                 const index_t grid_size =
                     GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);

-                const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
-
-                const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
-                if(has_main_k0_block_loop)
+                const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
+                               arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
+
+                if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
                 {
                     const auto kernel = kernel_gemm_xdlops_v2r3<
                         GridwiseGemm,
......
@@ -775,13 +775,12 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            const auto K =
+                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 const auto kernel = kernel_gemm_xdlops_v2r3<
                     GridwiseGemm,
......