Commit b58b98ff authored by Chao Liu
Browse files

add ckProfiler

parent 3d005816
...@@ -136,7 +136,11 @@ struct TensorAdaptor ...@@ -136,7 +136,11 @@ struct TensorAdaptor
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>; using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public: public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorAdaptor() = default; __host__ __device__ constexpr TensorAdaptor() = default;
#else
// Default constructor: value-initializes transforms_ and element_size_
// explicitly instead of relying on `= default` (the defaulted form is
// disabled in the `#if 0` branch as a constexpr-related compiler workaround).
__host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {}
#endif
__host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
: transforms_{transforms}, element_size_{InitializeElementSize(transforms)} : transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
......
...@@ -111,7 +111,14 @@ struct TensorDescriptor ...@@ -111,7 +111,14 @@ struct TensorDescriptor
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>; using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public: public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorDescriptor() = default; __host__ __device__ constexpr TensorDescriptor() = default;
#else
// Default constructor: value-initializes transforms_, element_size_ and
// element_space_size_ explicitly instead of relying on `= default` (the
// defaulted form is disabled in the `#if 0` branch as a constexpr-related
// compiler workaround).
__host__ __device__ constexpr TensorDescriptor()
: transforms_{}, element_size_{}, element_space_size_{}
{
}
#endif
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size) ElementSpaceSize element_space_size)
......
...@@ -18,7 +18,7 @@ struct TupleElementKey ...@@ -18,7 +18,7 @@ struct TupleElementKey
template <typename Key, typename Data> template <typename Key, typename Data>
struct TupleElementKeyData struct TupleElementKeyData
{ {
#if 0 #if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default; __host__ __device__ constexpr TupleElementKeyData() = default;
#else #else
__host__ __device__ constexpr TupleElementKeyData() : mData{} {} __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
......
#pragma once #pragma once
#include <stdlib.h> #include <vector>
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -12,10 +12,10 @@ namespace device_gemm_instance { ...@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using F16_F16 = ck::Tuple<F16, F16> using F16_F16 = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -28,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -28,7 +28,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// e = elementwise((a * b), d) // e = elementwise((a * b), d)
// output: e[m, n] // output: e[m, n]
// input: a[k, m], b[k, n], d[m, n] // input: a[k, m], b[k, n], d[m, n]
using device_gemm_add_add_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off // clang-format off
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
......
...@@ -12,10 +12,10 @@ namespace device_gemm_instance { ...@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using F16_F16 = ck::Tuple<F16, F16> using F16_F16 = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -28,28 +28,28 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -28,28 +28,28 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// e = elementwise((a * b), d) // e = elementwise((a * b), d)
// output: e[m, n] // output: e[m, n]
// input: a[k, m], b[n, k], d[m, n] // input: a[k, m], b[n, k], d[m, n]
using device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off // clang-format off
//##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
...@@ -57,8 +57,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc ...@@ -57,8 +57,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
std::vector<DeviceGemmMultipleDPtr<2, PassThrough, PassThrough, AddAddFastGelu>>& instances) std::vector<DeviceGemmMultipleDPtr<2, PassThrough, PassThrough, AddAddFastGelu>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{});
device_gemm_gelu_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
#include <stdlib.h> #include <stdlib.h>
#include "config.hpp" #include "config.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_operation_instance.hpp" #include "device_operation_instance.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -12,10 +12,10 @@ namespace device_gemm_instance { ...@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using F16_F16 = ck::Tuple<F16, F16> using F16_F16 = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
...@@ -12,10 +12,10 @@ namespace device_gemm_instance { ...@@ -12,10 +12,10 @@ namespace device_gemm_instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using F16_F16 = ck::Tuple<F16, F16> using F16_F16 = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
...@@ -24,7 +24,7 @@ include_directories(BEFORE ...@@ -24,7 +24,7 @@ include_directories(BEFORE
# ck_profiler # ck_profiler
set(PROFILER_SOURCE set(PROFILER_SOURCE
src/profiler.cpp src/profiler.cpp
src/profile_gemm.cpp # src/profile_gemm.cpp
# src/profile_gemm_bias_2d.cpp # src/profile_gemm_bias_2d.cpp
# src/profile_gemm_bias_relu.cpp # src/profile_gemm_bias_relu.cpp
# src/profile_gemm_bias_relu_add.cpp # src/profile_gemm_bias_relu_add.cpp
...@@ -47,7 +47,7 @@ add_executable(ckProfiler ${PROFILER_SOURCE}) ...@@ -47,7 +47,7 @@ add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE conv_util) target_link_libraries(ckProfiler PRIVATE conv_util)
#target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "device_gemm_multiple_d.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -23,7 +23,7 @@ using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMult ...@@ -23,7 +23,7 @@ using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMult
2, 2,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::FastGelu>; ck::tensor_operation::element_wise::AddAddFastGelu>;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&); std::vector<DeviceGemmAddAddFastGeluPtr>&);
...@@ -44,6 +44,7 @@ namespace profiler { ...@@ -44,6 +44,7 @@ namespace profiler {
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename AccDataType,
typename D0DataType, typename D0DataType,
typename D1DataType, typename D1DataType,
typename EDataType, typename EDataType,
...@@ -54,7 +55,7 @@ template <typename ADataType, ...@@ -54,7 +55,7 @@ template <typename ADataType,
typename ELayout> typename ELayout>
int profile_gemm_add_add_fastgelu_impl(int do_verification, int profile_gemm_add_add_fastgelu_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool /*do_log*/,
bool time_kernel, bool time_kernel,
int M, int M,
int N, int N,
...@@ -131,28 +132,32 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -131,28 +132,32 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
is_same_v<ELayout, tensor_layout::gemm::RowMajor>) is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(device_op_ptrs); add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
device_op_ptrs);
} }
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> && else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> && is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>) is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(device_op_ptrs); add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
device_op_ptrs);
} }
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> && else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> && is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>) is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(device_op_ptrs); add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
device_op_ptrs);
} }
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> && else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> && is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>) is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(device_op_ptrs); add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
device_op_ptrs);
} }
} }
......
...@@ -22,16 +22,16 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -22,16 +22,16 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
enum struct MatrixDataType enum struct MatrixDataType
{ {
F32_F32_F32_F32_F32, // 0 F32_F32_F32_F32_F32, // 0
F16_F16_F16_F16_F16_F16_F16, // 1 F16_F16_F16_F16_F16, // 1
BF16_BF16_BF16_BF16_BF16, // 2 BF16_BF16_BF16_BF16_BF16, // 2
INT8_INT8_INT8_INT8_INT8, // 3 INT8_INT8_INT8_INT8_INT8, // 3
}; };
if(argc != 16) if(argc != 16)
{ {
// clang-format off // clang-format off
printf("arg1: tensor operation (gemm_gelu: GEMM+Add+Add+GeLU)\n"); printf("arg1: tensor operation (gemm_add_add_fastgelu: GEMM+Add+Add+GeLU)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n"); printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n");
printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n"); printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n");
...@@ -40,7 +40,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -40,7 +40,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
printf("arg4: verification (0: no; 1: yes)\n"); printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
// clang-format on // clang-format on
exit(1); exit(1);
...@@ -64,12 +64,14 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -64,12 +64,14 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
const int StrideE = std::stoi(argv[15]); const int StrideE = std::stoi(argv[15]);
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type, auto profile = [&](auto a_type,
auto b_type, auto b_type,
auto acc_type,
auto d0_type, auto d0_type,
auto d1_type, auto d1_type,
auto e_type, auto e_type,
...@@ -78,11 +80,12 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -78,11 +80,12 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
auto d0_layout, auto d0_layout,
auto d1_layout, auto d1_layout,
auto e_layout) { auto e_layout) {
using ADataType = decltype(a_type); using ADataType = decltype(a_type);
using BDataType = decltype(b_type); using BDataType = decltype(b_type);
using D0DataType = decltype(d0_type); using AccDataType = decltype(acc_type);
using D1DataType = decltype(d1_type); using D0DataType = decltype(d0_type);
using EDataType = decltype(e_type); using D1DataType = decltype(d1_type);
using EDataType = decltype(e_type);
using ALayout = decltype(a_layout); using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); using BLayout = decltype(b_layout);
...@@ -96,16 +99,17 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -96,16 +99,17 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M; const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M; const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
return ck::profiler::profile_gemm_add_add_gelu_impl<ADataType, return ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType,
BDataType, BDataType,
D0DataType, AccDataType,
D1DataType, D0DataType,
EDataType, D1DataType,
ALayout, EDataType,
BLayout, ALayout,
D0Layout, BLayout,
D1Layout, D0Layout,
ELayout>( D1Layout,
ELayout>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -122,22 +126,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) ...@@ -122,22 +126,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{});
} }
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::MK_NK_MN_MN_MN) layout == MatrixLayout::MK_NK_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
} }
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::KM_KN_MN_MN_MN) layout == MatrixLayout::KM_KN_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{});
} }
else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
layout == MatrixLayout::KM_NK_MN_MN_MN) layout == MatrixLayout::KM_NK_MN_MN_MN)
{ {
return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{});
} }
else else
{ {
......
...@@ -54,11 +54,11 @@ int main(int argc, char* argv[]) ...@@ -54,11 +54,11 @@ int main(int argc, char* argv[])
return 0; return 0;
} }
#if 0
if(strcmp(argv[1], "gemm") == 0) if(strcmp(argv[1], "gemm") == 0)
{ {
return profile_gemm(argc, argv); return profile_gemm(argc, argv);
} }
#if 0
else if(strcmp(argv[1], "gemm_bias_2d") == 0) else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{ {
return profile_gemm_bias_2d(argc, argv); return profile_gemm_bias_2d(argc, argv);
...@@ -124,7 +124,7 @@ int main(int argc, char* argv[]) ...@@ -124,7 +124,7 @@ int main(int argc, char* argv[])
return profile_conv_bwd_weight(argc, argv); return profile_conv_bwd_weight(argc, argv);
} }
#endif #endif
else if(strcmp(argv[1], "gemm_add_add_fastgelu") == 0) if(strcmp(argv[1], "gemm_add_add_fastgelu") == 0)
{ {
return profile_gemm_add_add_fastgelu(argc, argv); return profile_gemm_add_add_fastgelu(argc, argv);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment