Unverified Commit 6b1490c9 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into aosewski/gemm_tile_loop

parents 271269a5 a3c80265
...@@ -9,9 +9,11 @@ ...@@ -9,9 +9,11 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -19,19 +21,37 @@ namespace tensor_operation { ...@@ -19,19 +21,37 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F64 = double;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// k/n/n/n are the fast changing dimension for A/B/D/E // k/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = std::tuple<
device_contraction_f64_kn_instance<F64, // clang-format off
F64, //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
F64, //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
F64, //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
Empty_Tuple, //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
F64, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
F64, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
PassThrough, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>,
PassThrough, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>,
Scale>; DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -43,8 +63,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc ...@@ -43,8 +63,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64, F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale>>>& instances)
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{});
......
...@@ -9,9 +9,11 @@ ...@@ -9,9 +9,11 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -19,19 +21,37 @@ namespace tensor_operation { ...@@ -19,19 +21,37 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F64 = double;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/k/n/n are the fast changing dimension for A/B/D/E // m/k/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = std::tuple<
device_contraction_f64_mk_instance<F64, // clang-format off
F64, //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
F64, //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
F64, //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
Empty_Tuple, //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
F64, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
F64, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
PassThrough, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>,
PassThrough, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>,
Scale>; DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -43,8 +63,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc ...@@ -43,8 +63,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64, F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale>>>& instances)
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{});
......
...@@ -9,9 +9,11 @@ ...@@ -9,9 +9,11 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -19,19 +21,37 @@ namespace tensor_operation { ...@@ -19,19 +21,37 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F64 = double;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/n/n/n are the fast changing dimension for A/B/D/E // m/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = std::tuple<
device_contraction_f64_mn_instance<F64, // clang-format off
F64, //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
F64, //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
F64, //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
Empty_Tuple, //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
F64, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
F64, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
PassThrough, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>,
PassThrough, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>,
Scale>; DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -43,8 +63,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc ...@@ -43,8 +63,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64, F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale>>>& instances)
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{});
......
...@@ -50,23 +50,21 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ...@@ -50,23 +50,21 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
## Profile contraction kernels ## Profile contraction kernels
```bash ```bash
#arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear) #arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear)
#arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16) #arg2: data type (0: fp32; 1: f64)\n"
#arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16) #arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
#arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]) # 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
#arg5: verification (0: no; 1: yes) #arg4: verification (0: no; 1: yes)
#arg6: initialization (0: no init; 1: integer value; 2: decimal value) #arg5: initialization (0: no init; 1: integer value; 2: decimal value)
#arg7: print tensor value (0: no; 1: yes) #arg6: print tensor value (0: no; 1: yes)
#arg8: time kernel (0: no, 1: yes) #arg7: time kernel (0: no, 1: yes)
#arg9: alpha #arg8 and arg9: alpha and beta
#arg10: beta #arg10 to 15: M0, M1, N0, N1, K0, K1
#arg11 to 16: M0, M1, N0, N1, K0, K1 #arg16 to 31: Strides for A, B, D and E (skip for default)
#arg17 to 32: Strides for A, B, D and E (skip for default)
################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
################ op datatype compute_datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 ./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
./bin/ckProfiler contraction_bilinear 0 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
``` ```
Result (MI100) Result (MI100)
......
...@@ -31,14 +31,10 @@ namespace profiler { ...@@ -31,14 +31,10 @@ namespace profiler {
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using F32 = float;
using F64 = double;
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CDELayout, typename CDELayout,
typename DataType, typename DataType,
typename ComputeDataType,
typename DTupleDataType, typename DTupleDataType,
typename CDElementOp> typename CDElementOp>
int profile_contraction_impl(ck::index_t do_verification, int profile_contraction_impl(ck::index_t do_verification,
...@@ -49,10 +45,10 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -49,10 +45,10 @@ int profile_contraction_impl(ck::index_t do_verification,
const std::vector<ck::index_t>& M, const std::vector<ck::index_t>& M,
const std::vector<ck::index_t>& N, const std::vector<ck::index_t>& N,
const std::vector<ck::index_t>& K, const std::vector<ck::index_t>& K,
const std::vector<ck::index_t>& StridesA, // [M0, M1, K0, K1] const std::vector<ck::index_t>& StridesA,
const std::vector<ck::index_t>& StridesB, // [N0, N1, K0, K1] const std::vector<ck::index_t>& StridesB,
const std::vector<ck::index_t>& StridesE, // [M0, M1, N0, N1] const std::vector<ck::index_t>& StridesE,
const std::vector<ck::index_t>& StridesD) // [M0, M1, N0, N1] const std::vector<ck::index_t>& StridesD)
{ {
bool pass = true; bool pass = true;
...@@ -67,13 +63,13 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -67,13 +63,13 @@ int profile_contraction_impl(ck::index_t do_verification,
}; };
Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA)); Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
Tensor<DataType> b_n_k(f_host_tensor_descriptor(N, K, StridesB)); Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD)); Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_n_k: " << b_n_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
...@@ -82,12 +78,12 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -82,12 +78,12 @@ int profile_contraction_impl(ck::index_t do_verification,
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
b_n_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
b_n_k.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
} }
...@@ -95,12 +91,12 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -95,12 +91,12 @@ int profile_contraction_impl(ck::index_t do_verification,
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(DataType) * b_n_k.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_n_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
e_device_buf.SetZero(); e_device_buf.SetZero();
d_device_buf.ToDevice(d_m_n.mData.data()); d_device_buf.ToDevice(d_m_n.mData.data());
...@@ -122,8 +118,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -122,8 +118,7 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType, DataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDElementOp, CDElementOp>;
ComputeDataType>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -131,9 +126,6 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -131,9 +126,6 @@ int profile_contraction_impl(ck::index_t do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
using AccDataType =
typename std::conditional<std::is_same<ComputeDataType, F64>::value, F64, F32>::type;
// Run reference op // Run reference op
if(do_verification) if(do_verification)
{ {
...@@ -144,8 +136,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -144,8 +136,7 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType, DataType,
DataType, DataType,
DataType, DataType,
AccDataType, DataType,
ComputeDataType,
AElementOp, AElementOp,
BElementOp>; BElementOp>;
...@@ -155,7 +146,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -155,7 +146,7 @@ int profile_contraction_impl(ck::index_t do_verification,
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
auto ref_argument = auto ref_argument =
ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
...@@ -281,29 +272,8 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -281,29 +272,8 @@ int profile_contraction_impl(ck::index_t do_verification,
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
// Both the kernel and the reference use `AccDataType`, so an absolute error of both float threshold =
// of them is bounded by `nelems_k * std::numeric_limits<AccDataType>::epsilon()`. static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
// Comparing one to another can result in an absolute error as high as twice that
// value.
double threshold = 2 * nelems_k * std::numeric_limits<AccDataType>::epsilon();
// Handle the possible casting error of either AccDataType -> DataType or
// DataType -> ComputeDataType.
// TODO: Add a generic solution for calculating thresholds in CK.
if constexpr(ck::is_same_v<DataType, ck::bhalf_t> ||
ck::is_same_v<ComputeDataType, ck::bhalf_t>)
{
const double epsilon = std::pow(2, -7);
// Maximum relative casting error when rounding to zero.
threshold += epsilon * 2;
}
else if constexpr(ck::is_same_v<DataType, ck::half_t> ||
ck::is_same_v<ComputeDataType, ck::half_t>)
{
const double epsilon = std::pow(2, -10);
// Maximum relative casting error when rounding to zero.
threshold += epsilon * 2;
}
pass = pass & ck::utils::check_err(e_m_n_device_result, pass = pass & ck::utils::check_err(e_m_n_device_result,
e_m_n_host_result, e_m_n_host_result,
"Error: incorrect results!", "Error: incorrect results!",
...@@ -313,7 +283,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -313,7 +283,7 @@ int profile_contraction_impl(ck::index_t do_verification,
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_n_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",") LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",") LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
......
...@@ -23,18 +23,8 @@ enum struct ContractionMatrixLayout ...@@ -23,18 +23,8 @@ enum struct ContractionMatrixLayout
enum struct ContractionDataType enum struct ContractionDataType
{ {
F32_F32_F32_F32, // 0 F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1 F64_F64_F64_F64, // 1
F16_F16_F16_F16, // 2
BF16_BF16_BF16_BF16, // 3
};
enum struct ContractionComputeDataType
{
F32 = 0,
F64,
F16,
BF16,
}; };
inline void collect_index_params(char* argv[], inline void collect_index_params(char* argv[],
......
...@@ -17,9 +17,8 @@ ...@@ -17,9 +17,8 @@
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" << "arg2: data type (0: fp32; 1: f64)\n"
<< "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
...@@ -27,42 +26,40 @@ static void print_helper_msg() ...@@ -27,42 +26,40 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg5: verification (0: no; 1: yes)\n" << "arg4: verification (0: no; 1: yes)\n"
<< "arg6: initialization (0: no init; 1: integer value; 2: decimal " << "arg5: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n" << "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n" << "arg7: time kernel (0: no, 1: yes)\n"
<< "arg9: alpha\n" << "arg8 and arg9: alpha and beta\n"
<< "arg10: beta\n" << "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
<< "arg11 to 16: M0, M1, N0, N1, K0, K1\n" << "arg16 to 31: Strides for A, B, D and E (skip for default)\n"
<< "arg17 to 32: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
int profile_contraction_bilinear(int argc, char* argv[]) int profile_contraction_bilinear(int argc, char* argv[])
{ {
const bool default_strides = argc == 17; const bool default_strides = argc == 16;
if(argc != 33 && argc != 17) if(argc != 32 && argc != 16)
{ {
print_helper_msg(); print_helper_msg();
exit(1); exit(1);
} }
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3])); const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4])); const bool do_verification = std::stoi(argv[4]);
const bool do_verification = std::stoi(argv[5]); const ck::index_t init_method = std::stoi(argv[5]);
const ck::index_t init_method = std::stoi(argv[6]); const bool do_log = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]); const bool time_kernel = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]); const float alpha = std::stof(argv[8]);
const float alpha = std::stof(argv[9]); const float beta = std::stof(argv[9]);
const float beta = std::stof(argv[10]);
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 11; const ck::index_t dims_arg_num = 10;
collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2); collect_index_params(argv, K, dims_arg_num + 4, 2);
...@@ -79,130 +76,90 @@ int profile_contraction_bilinear(int argc, char* argv[]) ...@@ -79,130 +76,90 @@ int profile_contraction_bilinear(int argc, char* argv[])
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F16 = ck::half_t; using F32 = float;
using BF16 = ck::bhalf_t; using F64 = double;
using F32 = float;
using F64 = double; auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
using ALayout = decltype(a_layout);
auto profile = using BLayout = decltype(b_layout);
[&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) { using CDELayout = decltype(cde_layout);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); using DataType = decltype(type);
using CDELayout = decltype(cde_layout);
if(default_strides)
using DataType = decltype(type);
using ComputeDataType = decltype(compute_type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
if(layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, type, compute_type);
}
else if(layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, type, compute_type);
}
else if(layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, type, compute_type);
}
else if(layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
return profile(Col{}, Col{}, Row{}, type, compute_type); assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
return false; bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
}; };
if(data_type == ContractionDataType::F32_F32_F32_F32) if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F32) return profile(Row{}, Row{}, Row{}, F32{});
{
return run_profile_for_datatype(F32{}, F32{});
}
else if(compute_data_type == ContractionComputeDataType::F16)
{
return run_profile_for_datatype(F32{}, F16{});
}
else if(compute_data_type == ContractionComputeDataType::BF16)
{
return run_profile_for_datatype(F32{}, BF16{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64) else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F64) return profile(Row{}, Col{}, Row{}, F32{});
{
return run_profile_for_datatype(F64{}, F64{});
}
else if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F64{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F16_F16_F16_F16) else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F32) return profile(Col{}, Row{}, Row{}, F32{});
{
return run_profile_for_datatype(F16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16) else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F32) return profile(Col{}, Col{}, Row{}, F32{});
{ }
return run_profile_for_datatype(BF16{}, F32{}); else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
} layout == ContractionMatrixLayout::MK_KN_MN_MN)
else {
{ return profile(Row{}, Row{}, Row{}, F64{});
std::cout << "Incorrect combination of data type and compute data type." << std::endl; }
return 1; else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
} layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{
return profile(Col{}, Col{}, Row{}, F64{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
} }
return 1;
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);
...@@ -17,9 +17,8 @@ ...@@ -17,9 +17,8 @@
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" << "arg2: data type (0: fp32; 1: f64)\n"
<< "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
...@@ -27,40 +26,39 @@ static void print_helper_msg() ...@@ -27,40 +26,39 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg5: verification (0: no; 1: yes)\n" << "arg4: verification (0: no; 1: yes)\n"
<< "arg6: initialization (0: no init; 1: integer value; 2: decimal " << "arg5: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n" << "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n" << "arg7: time kernel (0: no, 1: yes)\n"
<< "arg9: alpha\n" << "arg8: alpha\n"
<< "arg10 to 15: M0, M1, N0, N1, K0, K1\n" << "arg9 to 14: M0, M1, N0, N1, K0, K1\n"
<< "arg16 to 31: Strides for A, B, D and E (skip for default)\n" << "arg15 to 30: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
int profile_contraction_scale(int argc, char* argv[]) int profile_contraction_scale(int argc, char* argv[])
{ {
const bool default_strides = argc == 16; const bool default_strides = argc == 15;
if(argc != 32 && argc != 16) if(argc != 31 && argc != 15)
{ {
print_helper_msg(); print_helper_msg();
exit(1); exit(1);
} }
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3])); const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4])); const bool do_verification = std::stoi(argv[4]);
const bool do_verification = std::stoi(argv[5]); const ck::index_t init_method = std::stoi(argv[5]);
const ck::index_t init_method = std::stoi(argv[6]); const bool do_log = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]); const bool time_kernel = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]); const float alpha = std::stof(argv[8]);
const float alpha = std::stof(argv[9]);
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 10; const ck::index_t dims_arg_num = 9;
collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2); collect_index_params(argv, K, dims_arg_num + 4, 2);
...@@ -77,131 +75,88 @@ int profile_contraction_scale(int argc, char* argv[]) ...@@ -77,131 +75,88 @@ int profile_contraction_scale(int argc, char* argv[])
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F16 = ck::half_t; using F32 = float;
using BF16 = ck::bhalf_t; using F64 = double;
using F32 = float;
using F64 = double; auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
using ALayout = decltype(a_layout);
auto profile = using BLayout = decltype(b_layout);
[&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) { using CDELayout = decltype(cde_layout);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); using DataType = decltype(type);
using CDELayout = decltype(cde_layout);
if(default_strides)
using DataType = decltype(type);
using ComputeDataType = decltype(compute_type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<>,
Scale>(do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
if(layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, type, compute_type);
}
else if(layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, type, compute_type);
}
else if(layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, type, compute_type);
}
else if(layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
return profile(Col{}, Col{}, Row{}, type, compute_type); assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
return false;
bool pass = ck::profiler::
profile_contraction_impl<ALayout, BLayout, CDELayout, DataType, ck::Tuple<>, Scale>(
do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
}; };
if(data_type == ContractionDataType::F32_F32_F32_F32) if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F32) return profile(Row{}, Row{}, Row{}, F32{});
{
return run_profile_for_datatype(F32{}, F32{});
}
else if(compute_data_type == ContractionComputeDataType::F16)
{
return run_profile_for_datatype(F32{}, F16{});
}
else if(compute_data_type == ContractionComputeDataType::BF16)
{
return run_profile_for_datatype(F32{}, BF16{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64) else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F64) return profile(Row{}, Col{}, Row{}, F32{});
{
return run_profile_for_datatype(F64{}, F64{});
}
else if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F64{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F16_F16_F16_F16) else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F32) return profile(Col{}, Row{}, Row{}, F32{});
{
return run_profile_for_datatype(F16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16) else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
if(compute_data_type == ContractionComputeDataType::F32) return profile(Col{}, Col{}, Row{}, F32{});
{ }
return run_profile_for_datatype(BF16{}, F32{}); else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
} layout == ContractionMatrixLayout::MK_KN_MN_MN)
else {
{ return profile(Row{}, Row{}, Row{}, F64{});
std::cout << "Incorrect combination of data type and compute data type." << std::endl; }
return 1; else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
} layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{
return profile(Col{}, Col{}, Row{}, F64{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
} }
return 1;
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale);
...@@ -32,7 +32,7 @@ function(add_test_executable TEST_NAME) ...@@ -32,7 +32,7 @@ function(add_test_executable TEST_NAME)
set(test 0) set(test 0)
break() break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1)) NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal #if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1) set(test 1)
...@@ -61,7 +61,7 @@ function(add_test_executable TEST_NAME) ...@@ -61,7 +61,7 @@ function(add_test_executable TEST_NAME)
set(result 0) set(result 0)
endif() endif()
#message("add_test returns ${result}") #message("add_test returns ${result}")
return(PROPAGATE result) set(result ${result} PARENT_SCOPE)
endfunction(add_test_executable TEST_NAME) endfunction(add_test_executable TEST_NAME)
include(GoogleTest) include(GoogleTest)
...@@ -91,7 +91,7 @@ function(add_gtest_executable TEST_NAME) ...@@ -91,7 +91,7 @@ function(add_gtest_executable TEST_NAME)
set(test 0) set(test 0)
break() break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1)) NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal #if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1) set(test 1)
...@@ -123,7 +123,7 @@ function(add_gtest_executable TEST_NAME) ...@@ -123,7 +123,7 @@ function(add_gtest_executable TEST_NAME)
set(result 0) set(result 0)
endif() endif()
#message("add_gtest returns ${result}") #message("add_gtest returns ${result}")
return(PROPAGATE result) set(result ${result} PARENT_SCOPE)
endfunction(add_gtest_executable TEST_NAME) endfunction(add_gtest_executable TEST_NAME)
add_subdirectory(magic_number_division) add_subdirectory(magic_number_division)
......
...@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) add_gtest_executable(test_batched_gemm test_batched_gemm.cpp)
if(result EQUAL 0) target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 512;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = float;
using BDataType = float;
using CDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
struct GemmParams
{
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t BatchCount;
};
class TestBatchedGemm : public ::testing::Test
{
protected:
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
std::vector<GemmParams> params;
template <typename DataType>
void Run()
{
using namespace ck::tensor_operation::device;
bool pass = true;
for(auto& param : params)
{
const auto M = param.M;
const auto N = param.N;
const auto K = param.K;
const auto BatchCount = param.BatchCount;
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
}
EXPECT_TRUE(pass);
}
};
#ifdef CK_ENABLE_INT8
TEST_F(TestBatchedGemm, i8)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<int8_t>();
}
#endif
#ifdef CK_ENABLE_BF16
TEST_F(TestBatchedGemm, bf16)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<ck::bhalf_t>();
}
#endif
#ifdef CK_ENABLE_FP16
TEST_F(TestBatchedGemm, fp16)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<ck::half_t>();
}
#endif
#ifdef CK_ENABLE_FP32
TEST_F(TestBatchedGemm, fp32)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<float>();
}
#endif
...@@ -10,12 +10,9 @@ ...@@ -10,12 +10,9 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/profile_contraction_impl.hpp" #include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
using F16 = ck::half_t; using F32 = float;
using BF16 = ck::bhalf_t; using F64 = double;
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -23,49 +20,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -23,49 +20,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
struct Dimensions struct MemoryParams
{ {
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
}; };
template <typename Tuple> template <typename Tuple>
class TestContraction : public ::testing::Test class TestContraction : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CDLayout = std::tuple_element_t<2, Tuple>; using CDLayout = std::tuple_element_t<2, Tuple>;
using DataType = std::tuple_element_t<3, Tuple>; using DataType = std::tuple_element_t<3, Tuple>;
using DTupleDataType = std::tuple_element_t<4, Tuple>; using DTupleDataType = std::tuple_element_t<4, Tuple>;
using ComputeDataType = std::tuple_element_t<5, Tuple>; using CDElementOp = std::tuple_element_t<5, Tuple>;
using CDElementOp = std::tuple_element_t<6, Tuple>;
std::vector<MemoryParams> list_of_memory_params = {{{32, 32},
std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}}, {32, 32},
{{16, 16}, {32, 32}, {16, 16}}}; {32, 32},
{32768, 1024, 32, 1},
std::vector<ck::index_t> init_methods = {1, 2}; {32768, 1024, 32, 1},
{32768, 1024, 32, 1},
{32768, 1024, 32, 1}},
{{16, 16},
{32, 32},
{16, 16},
{4096, 256, 16, 1},
{16, 1, 8192, 256},
{16384, 1024, 32, 1},
{16384, 1024, 32, 1}}};
std::vector<ck::index_t> init_methods = {0, 1, 2};
std::unique_ptr<CDElementOp> p_cd_element_op; std::unique_ptr<CDElementOp> p_cd_element_op;
void Run() void Run()
{ {
for(auto& dimension_params : dimension_list) for(auto& memory_params : list_of_memory_params)
{ {
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
for(const ck::index_t init_method : init_methods) for(const ck::index_t init_method : init_methods)
{ {
bool pass = bool pass =
...@@ -73,20 +70,19 @@ class TestContraction : public ::testing::Test ...@@ -73,20 +70,19 @@ class TestContraction : public ::testing::Test
BLayout, BLayout,
CDLayout, CDLayout,
DataType, DataType,
ComputeDataType,
DTupleDataType, DTupleDataType,
CDElementOp>(true /*do_verification*/, CDElementOp>(true /*do_verification*/,
init_method, init_method,
false /*do_logs*/, false /*do_logs*/,
false /*time_kernel*/, false /*time_kernel*/,
*p_cd_element_op, *p_cd_element_op,
dimension_params.M, memory_params.M,
dimension_params.N, memory_params.N,
dimension_params.K, memory_params.K,
StridesA, memory_params.StridesA,
StridesB, memory_params.StridesB,
StridesC, memory_params.StridesC,
StridesD); memory_params.StridesD);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
} }
...@@ -103,18 +99,24 @@ class TestContractionBilinear : public TestContraction<Tuple> ...@@ -103,18 +99,24 @@ class TestContractionBilinear : public TestContraction<Tuple>
{ {
}; };
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>
using BilinearKernelTypes = using BilinearKernelTypes =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear), ::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>; std::tuple<Row, Col, Row, F32, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale), std::tuple<Col, Col, Row, F32, ck::Tuple<F32>, Bilinear>,
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>; std::tuple<Row, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Row, Col, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Col, Row, F64, ck::Tuple<F32>, Bilinear>>;
using ScaleKernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F64, ck::Tuple<>, Scale>>;
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
...@@ -134,46 +136,3 @@ TYPED_TEST(TestContractionScale, scale) ...@@ -134,46 +136,3 @@ TYPED_TEST(TestContractionScale, scale)
this->p_cd_element_op = std::make_unique<Scale>(0.5f); this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run(); this->Run();
} }
template <typename Tuple>
class TestContractionScaleMixedPrecision : public TestContraction<Tuple>
{
};
template <typename Tuple>
class TestContractionBilinearMixedPrecision : public TestContraction<Tuple>
{
};
using BilinearKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, BF16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<F16>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<BF16>, F32, Bilinear)>;
using ScaleKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F16, Scale),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>;
TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision);
TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision);
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
{
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run();
}
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
{
this->p_cd_element_op = std::make_unique<Scale>(1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run();
}
...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper ...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper
static constexpr ck::index_t NumDim = 2; static constexpr ck::index_t NumDim = 2;
// clang-format off // clang-format off
using ContractionDeviceInstance = ck::tensor_operation::device:: using ContractionDeviceInstance = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
// clang-format on // clang-format on
bool isSupported(std::vector<ck::index_t>& ADims, bool isSupported(std::vector<ck::index_t>& ADims,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment