Commit 09ec28be authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into gelu

parents b9d3d277 85fc91c3
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_f64_f64_f64_km_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
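For readers unfamiliar with the naming convention, the suffix km_nk_mn encodes how A, B and C are indexed: a[k, m] is a column-major M x K A matrix and b[n, k] is a column-major K x N B matrix, which matches the Col/Col/Row layout arguments above. A minimal host-side sketch of this indexing (illustrative only; the function name and raw-pointer arguments are assumptions, not part of the library):

// Sketch only: plain C++ reference for c[m, n] = sum_k a[k, m] * b[n, k].
// lda, ldb, ldc are assumed leading dimensions of the flattened buffers.
void reference_gemm_km_nk_mn(const double* a, const double* b, double* c,
                             int M, int N, int K, int lda, int ldb, int ldc)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            double acc = 0.0; // accumulate in double, matching AccDataType = F64 above
            for(int k = 0; k < K; ++k)
                acc += a[k * lda + m] * b[n * ldb + k];
            c[m * ldc + n] = acc;
        }
}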
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>,
DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[k, m] * b[k, n]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple<
// clang-format off
-//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
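The change above, repeated in the three instance files that follow, swaps the reduction output element-wise operations from Identity to Div (UnaryIdentic with its divide flag set), so the fused reduction emits a mean and a mean of squares rather than raw sums. A self-contained sketch of that post-processing, assuming a row-wise reduction of length N (the actual divisor is supplied at runtime, as the profiler change further below shows; names here are illustrative):

#include <vector>

// Sketch only: per-row sum and sum-of-squares followed by a divide,
// mirroring DxsInElementOps = {Identity, Square} and DxsOutElementOps = {Div, Div}.
void row_mean_and_mean_square(const std::vector<float>& c, int M, int N,
                              std::vector<float>& d0, std::vector<float>& d1)
{
    for(int m = 0; m < M; ++m)
    {
        float acc0 = 0.f, acc1 = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float v = c[m * N + n];
            acc0 += v;     // Identity, then Add
            acc1 += v * v; // Square, then Add
        }
        // Div: scale the accumulated values by the reduction length.
        d0[m] = acc0 / N;
        d1[m] = acc1 / N;
    }
}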
@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[k, m] * b[n, k]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple<
// clang-format off
-//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple<
// clang-format off
-//###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
@@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
+using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple<
// clang-format off
-//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
+//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
...
@@ -98,6 +98,7 @@ namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
+typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -511,8 +512,14 @@ void profile_gemm_impl(int do_verification,
bf16_to_f32_(b_k_n, b_f32_k_n);
bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance =
+    ck::tensor_operation::host::ReferenceGemm<float,
+                                              float,
+                                              float,
+                                              float,
+                                              AElementOp,
+                                              BElementOp,
+                                              CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -544,6 +551,7 @@ void profile_gemm_impl(int do_verification,
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
+AccDataType,
AElementOp,
BElementOp,
CElementOp>;
...
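The AccDataType parameter added to profile_gemm_impl and threaded into ReferenceGemm lets host verification accumulate in a wider type than the inputs and output, e.g. int32 accumulation for int8 GEMM or fp32 accumulation for fp16/bf16 GEMM. A standalone sketch of a reference GEMM with a separate accumulation type (not the library's ReferenceGemm itself; the function name and row-major layout assumption are illustrative):

#include <cstdint>
#include <vector>

// Sketch only: row-major A (M x K), row-major B (K x N), row-major C (M x N).
// AccDataType widens the per-element products before the final cast to CDataType.
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
void naive_gemm_mk_kn_mn(const std::vector<ADataType>& a,
                         const std::vector<BDataType>& b,
                         std::vector<CDataType>& c,
                         int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            AccDataType acc{0};
            for(int k = 0; k < K; ++k)
                acc += static_cast<AccDataType>(a[m * K + k]) *
                       static_cast<AccDataType>(b[k * N + n]);
            c[m * N + n] = static_cast<CDataType>(acc);
        }
}

// Example: int8 inputs accumulated in int32, as in the int8 profiler paths below.
// naive_gemm_mk_kn_mn<int8_t, int8_t, int8_t, int32_t>(a, b, c, M, N, K);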
@@ -19,10 +19,11 @@ namespace device_gemm_instance {
using F32 = float;
using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
+using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using DOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
DPtrsGlobal,
@@ -127,25 +128,32 @@ bool profile_gemm_reduce_impl(int do_verification,
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
+using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>;
using UnaryIdenticElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
using UnarySquareElementOp =
    ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
+using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
-const auto dxs_in_element_op = DxsInElementOps{};
-const auto dxs_out_element_op = DxsOutElementOps{};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};
+auto dxs_in_element_op = DxsInElementOps{};
+auto dxs_out_element_op = DxsOutElementOps{M, M};
if(do_verification)
{
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        DDataType,
+                                                                        AElementOp,
+                                                                        BElementOp,
+                                                                        CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -162,14 +170,18 @@ bool profile_gemm_reduce_impl(int do_verification,
for(int n = 0; n < N; ++n)
{
-float d0_val = ck::type_convert<float>(c_m_n_host_result(m, n));
-float d1_val;
-UnarySquareElementOp{}(d1_val, d0_val);
+float c_val = ck::type_convert<float>(c_m_n_host_result(m, n));
+float d0_val = 0;
+float d1_val = 0;
+dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
+dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
}
+dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
+dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
}
...
@@ -43,6 +43,7 @@ namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
+typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -271,6 +272,7 @@ void profile_grouped_gemm_impl(int do_verification,
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
+AccDataType,
AElementOp,
BElementOp,
CElementOp>;
...
@@ -68,6 +68,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -88,6 +89,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -108,6 +110,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
+float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -128,6 +131,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
+float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
+float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
+float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
+float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
+float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
@@ -228,6 +236,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
+int32_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -248,6 +257,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
+int32_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -268,6 +278,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
+int32_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -288,6 +299,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
+int32_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -308,6 +320,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -328,6 +341,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -348,6 +362,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
+float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -368,6 +383,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
+float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
...
@@ -79,6 +79,7 @@ int profile_grouped_gemm(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
+ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
@@ -97,6 +98,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
+ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
@@ -115,6 +117,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
+ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
@@ -133,6 +136,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
+ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
...
@@ -46,6 +46,7 @@ int main()
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
+using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -63,6 +64,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
@@ -81,6 +83,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
@@ -99,6 +102,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
RowMajor,
RowMajor,
@@ -117,6 +121,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
...
@@ -46,6 +46,7 @@ int main()
using ADataType = float;
using BDataType = float;
using CDataType = float;
+using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -61,6 +62,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
@@ -79,6 +81,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
@@ -97,6 +100,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
RowMajor,
RowMajor,
@@ -115,6 +119,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
...
@@ -46,6 +46,7 @@ int main()
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
+using AccDataType = int;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -61,6 +62,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
@@ -79,6 +81,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
@@ -97,6 +100,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
RowMajor,
RowMajor,
@@ -115,6 +119,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
...
@@ -111,6 +111,7 @@ template <typename DeviceGemmPtr_,
typename ADataType,
typename BDataType,
typename CDataType,
+typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
@@ -186,6 +187,7 @@ struct TestGemm
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
+AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
@@ -215,6 +217,11 @@ struct TestGemm
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
+else if(std::is_same<CDataType, double>::value)
+{
+    res = ck::utils::check_err(c_device.mData, c_host.mData);
+    std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+}
return res;
}
@@ -311,6 +318,7 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance =
    ck::tensor_operation::host::ReferenceGemm<float,
+float,
float,
float,
AElementwiseOperation,
...
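The new double branch in TestGemm above reuses ck::utils::check_err for FP64 results. Such a check typically reduces to an element-wise comparison with combined relative and absolute tolerances; a standalone sketch, not the check_err implementation itself (the tolerance values are assumptions):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only: true when every element of `out` matches `ref` within rtol/atol.
bool all_close(const std::vector<double>& out, const std::vector<double>& ref,
               double rtol = 1e-10, double atol = 1e-12)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::abs(out[i] - ref[i]);
        if(err > atol + rtol * std::abs(ref[i]))
            return false;
    }
    return true;
}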
@@ -55,6 +55,7 @@ int main()
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
+using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -74,6 +75,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
@@ -96,6 +98,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
@@ -118,6 +121,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
RowMajor,
RowMajor,
@@ -142,6 +146,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
...
@@ -56,6 +56,7 @@ int main()
using ADataType = float;
using BDataType = float;
using CDataType = float;
+using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -75,6 +76,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
@@ -97,6 +99,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
@@ -119,6 +122,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
RowMajor,
RowMajor,
@@ -141,6 +145,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
...
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
inline std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string name(props.gcnArchName);
return name;
}
int main()
{
if(get_device_name().find("gfx90a") == std::string::npos)
{
std::cout << "TestGemm ..... SUCCESS" << std::endl;
return 0;
}
using ADataType = double;
using BDataType = double;
using CDataType = double;
using AccDataType = double;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
@@ -45,6 +45,7 @@ int main()
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
+using AccDataType = int32_t;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -61,6 +62,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
@@ -79,6 +81,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
@@ -97,6 +100,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
RowMajor,
RowMajor,
@@ -115,6 +119,7 @@ int main()
ADataType,
BDataType,
CDataType,
+AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
...