Unverified Commit f0748506 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski Committed by GitHub
Browse files

Add support for mixed precision in contraction scale and bilinear (#936)

* Extract common functionality to separate files

* Reference contraction: Remove incorrect consts from type_converts

* Reference contraction: Add missing type_convert for dst value

* Reference contraction: Fix incorrect order of B matrix dimensions

* Add support for mixed precision in contraction scale and bilinear

* Move using statements from instances to a common file

* Move using statements from examples to a common file

* Fix the order of B matrix dimensions across examples and profiler

* Fix the computation of error threshold

* Make ComputeDataType an optional argument

* Include possible DataType -> ComputeDataType casting error in the threshold

* Remove commented code
parent cb538740
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance =
device_contraction_f64_mn_instance<F64,
F64,
F32,
F64,
Empty_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Scale>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale,
F32>>>& instances)
{
add_device_operation_instances(
instances,
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -9,11 +9,9 @@ ...@@ -9,11 +9,9 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -21,37 +19,19 @@ namespace tensor_operation { ...@@ -21,37 +19,19 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F64 = double;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// k/k/n/n are the fast changing dimension for A/B/D/E // k/k/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance = std::tuple< using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance =
// clang-format off device_contraction_f64_kk_instance<F64,
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| F64,
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| F64,
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| F64,
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Empty_Tuple,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, Scale>;
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>
// clang-format on
>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc ...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
F64, F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{});
......
...@@ -9,11 +9,9 @@ ...@@ -9,11 +9,9 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -21,37 +19,19 @@ namespace tensor_operation { ...@@ -21,37 +19,19 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F64 = double;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// k/n/n/n are the fast changing dimension for A/B/D/E // k/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = std::tuple< using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance =
// clang-format off device_contraction_f64_kn_instance<F64,
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| F64,
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| F64,
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| F64,
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Empty_Tuple,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, Scale>;
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc ...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64, F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{});
......
...@@ -9,11 +9,9 @@ ...@@ -9,11 +9,9 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -21,37 +19,19 @@ namespace tensor_operation { ...@@ -21,37 +19,19 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F64 = double;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/k/n/n are the fast changing dimension for A/B/D/E // m/k/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = std::tuple< using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance =
// clang-format off device_contraction_f64_mk_instance<F64,
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| F64,
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| F64,
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| F64,
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Empty_Tuple,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, Scale>;
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc ...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64, F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{});
......
...@@ -9,11 +9,9 @@ ...@@ -9,11 +9,9 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -21,37 +19,19 @@ namespace tensor_operation { ...@@ -21,37 +19,19 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F64 = double;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/n/n/n are the fast changing dimension for A/B/D/E // m/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = std::tuple< using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance =
// clang-format off device_contraction_f64_mn_instance<F64,
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| F64,
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| F64,
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| F64,
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Empty_Tuple,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, F64,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, PassThrough,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, Scale>;
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>,
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
// clang-format on
>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc ...@@ -63,7 +43,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64, F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{});
......
...@@ -50,21 +50,23 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ...@@ -50,21 +50,23 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
## Profile contraction kernels ## Profile contraction kernels
```bash ```bash
#arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear) #arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear)
#arg2: data type (0: fp32; 1: f64)\n" #arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)
#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; #arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)
#arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
# 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]) # 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
#arg4: verification (0: no; 1: yes) #arg5: verification (0: no; 1: yes)
#arg5: initialization (0: no init; 1: integer value; 2: decimal value) #arg6: initialization (0: no init; 1: integer value; 2: decimal value)
#arg6: print tensor value (0: no; 1: yes) #arg7: print tensor value (0: no; 1: yes)
#arg7: time kernel (0: no, 1: yes) #arg8: time kernel (0: no, 1: yes)
#arg8 and arg9: alpha and beta #arg9: alpha
#arg10 to 15: M0, M1, N0, N1, K0, K1 #arg10: beta
#arg16 to 31: Strides for A, B, D and E (skip for default) #arg11 to 16: M0, M1, N0, N1, K0, K1
#arg17 to 32: Strides for A, B, D and E (skip for default)
################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ################ op datatype compute_datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
./bin/ckProfiler contraction_bilinear 0 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
``` ```
Result (MI100) Result (MI100)
......
...@@ -31,10 +31,14 @@ namespace profiler { ...@@ -31,10 +31,14 @@ namespace profiler {
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using F32 = float;
using F64 = double;
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CDELayout, typename CDELayout,
typename DataType, typename DataType,
typename ComputeDataType,
typename DTupleDataType, typename DTupleDataType,
typename CDElementOp> typename CDElementOp>
int profile_contraction_impl(ck::index_t do_verification, int profile_contraction_impl(ck::index_t do_verification,
...@@ -45,10 +49,10 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -45,10 +49,10 @@ int profile_contraction_impl(ck::index_t do_verification,
const std::vector<ck::index_t>& M, const std::vector<ck::index_t>& M,
const std::vector<ck::index_t>& N, const std::vector<ck::index_t>& N,
const std::vector<ck::index_t>& K, const std::vector<ck::index_t>& K,
const std::vector<ck::index_t>& StridesA, const std::vector<ck::index_t>& StridesA, // [M0, M1, K0, K1]
const std::vector<ck::index_t>& StridesB, const std::vector<ck::index_t>& StridesB, // [N0, N1, K0, K1]
const std::vector<ck::index_t>& StridesE, const std::vector<ck::index_t>& StridesE, // [M0, M1, N0, N1]
const std::vector<ck::index_t>& StridesD) const std::vector<ck::index_t>& StridesD) // [M0, M1, N0, N1]
{ {
bool pass = true; bool pass = true;
...@@ -63,13 +67,13 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -63,13 +67,13 @@ int profile_contraction_impl(ck::index_t do_verification,
}; };
Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA)); Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB)); Tensor<DataType> b_n_k(f_host_tensor_descriptor(N, K, StridesB));
Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD)); Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_n_k: " << b_n_k.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
...@@ -78,12 +82,12 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -78,12 +82,12 @@ int profile_contraction_impl(ck::index_t do_verification,
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); b_n_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); b_n_k.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
} }
...@@ -91,12 +95,12 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -91,12 +95,12 @@ int profile_contraction_impl(ck::index_t do_verification,
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(DataType) * b_n_k.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_n_k.mData.data());
e_device_buf.SetZero(); e_device_buf.SetZero();
d_device_buf.ToDevice(d_m_n.mData.data()); d_device_buf.ToDevice(d_m_n.mData.data());
...@@ -118,7 +122,8 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -118,7 +122,8 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType, DataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDElementOp>; CDElementOp,
ComputeDataType>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -126,6 +131,9 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -126,6 +131,9 @@ int profile_contraction_impl(ck::index_t do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
using AccDataType =
typename std::conditional<std::is_same<ComputeDataType, F64>::value, F64, F32>::type;
// Run reference op // Run reference op
if(do_verification) if(do_verification)
{ {
...@@ -136,7 +144,8 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -136,7 +144,8 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType, DataType,
DataType, DataType,
DataType, DataType,
DataType, AccDataType,
ComputeDataType,
AElementOp, AElementOp,
BElementOp>; BElementOp>;
...@@ -146,7 +155,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -146,7 +155,7 @@ int profile_contraction_impl(ck::index_t do_verification,
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
auto ref_argument = auto ref_argument =
ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op); ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
...@@ -272,8 +281,29 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -272,8 +281,29 @@ int profile_contraction_impl(ck::index_t do_verification,
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
float threshold = // Both the kernel and the reference use `AccDataType`, so an absolute error of both
static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon(); // of them is bounded by `nelems_k * std::numeric_limits<AccDataType>::epsilon()`.
// Comparing one to another can result in an absolute error as high as twice that
// value.
double threshold = 2 * nelems_k * std::numeric_limits<AccDataType>::epsilon();
// Handle the possible casting error of either AccDataType -> DataType or
// DataType -> ComputeDataType.
// TODO: Add a generic solution for calculating thresholds in CK.
if constexpr(ck::is_same_v<DataType, ck::bhalf_t> ||
ck::is_same_v<ComputeDataType, ck::bhalf_t>)
{
const double epsilon = std::pow(2, -7);
// Maximum relative casting error when rounding to zero.
threshold += epsilon * 2;
}
else if constexpr(ck::is_same_v<DataType, ck::half_t> ||
ck::is_same_v<ComputeDataType, ck::half_t>)
{
const double epsilon = std::pow(2, -10);
// Maximum relative casting error when rounding to zero.
threshold += epsilon * 2;
}
pass = pass & ck::utils::check_err(e_m_n_device_result, pass = pass & ck::utils::check_err(e_m_n_device_result,
e_m_n_host_result, e_m_n_host_result,
"Error: incorrect results!", "Error: incorrect results!",
...@@ -283,7 +313,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -283,7 +313,7 @@ int profile_contraction_impl(ck::index_t do_verification,
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "b: ", b_n_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",") LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",") LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
......
...@@ -23,8 +23,18 @@ enum struct ContractionMatrixLayout ...@@ -23,8 +23,18 @@ enum struct ContractionMatrixLayout
enum struct ContractionDataType enum struct ContractionDataType
{ {
F32_F32_F32_F32, // 0 F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1 F64_F64_F64_F64, // 1
F16_F16_F16_F16, // 2
BF16_BF16_BF16_BF16, // 3
};
enum struct ContractionComputeDataType
{
F32 = 0,
F64,
F16,
BF16,
}; };
inline void collect_index_params(char* argv[], inline void collect_index_params(char* argv[],
......
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n" << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
...@@ -26,40 +27,42 @@ static void print_helper_msg() ...@@ -26,40 +27,42 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n" << "arg5: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "arg6: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg8: time kernel (0: no, 1: yes)\n"
<< "arg8 and arg9: alpha and beta\n" << "arg9: alpha\n"
<< "arg10 to 15: M0, M1, N0, N1, K0, K1\n" << "arg10: beta\n"
<< "arg16 to 31: Strides for A, B, D and E (skip for default)\n" << "arg11 to 16: M0, M1, N0, N1, K0, K1\n"
<< "arg17 to 32: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
int profile_contraction_bilinear(int argc, char* argv[]) int profile_contraction_bilinear(int argc, char* argv[])
{ {
const bool default_strides = argc == 16; const bool default_strides = argc == 17;
if(argc != 32 && argc != 16) if(argc != 33 && argc != 17)
{ {
print_helper_msg(); print_helper_msg();
exit(1); exit(1);
} }
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3])); const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4]));
const ck::index_t init_method = std::stoi(argv[5]); const bool do_verification = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const ck::index_t init_method = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]); const bool do_log = std::stoi(argv[7]);
const float alpha = std::stof(argv[8]); const bool time_kernel = std::stoi(argv[8]);
const float beta = std::stof(argv[9]); const float alpha = std::stof(argv[9]);
const float beta = std::stof(argv[10]);
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 10; const ck::index_t dims_arg_num = 11;
collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2); collect_index_params(argv, K, dims_arg_num + 4, 2);
...@@ -76,90 +79,130 @@ int profile_contraction_bilinear(int argc, char* argv[]) ...@@ -76,90 +79,130 @@ int profile_contraction_bilinear(int argc, char* argv[])
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F32 = float; using F16 = ck::half_t;
using F64 = double; using BF16 = ck::bhalf_t;
using F32 = float;
auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { using F64 = double;
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); auto profile =
using CDELayout = decltype(cde_layout); [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) {
using ALayout = decltype(a_layout);
using DataType = decltype(type); using BLayout = decltype(b_layout);
using CDELayout = decltype(cde_layout);
if(default_strides)
using DataType = decltype(type);
using ComputeDataType = decltype(compute_type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
if(layout == ContractionMatrixLayout::MK_KN_MN_MN)
{ {
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); return profile(Row{}, Row{}, Row{}, type, compute_type);
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
bool pass = ck::profiler::profile_contraction_impl<ALayout, else if(layout == ContractionMatrixLayout::MK_NK_MN_MN)
BLayout, {
CDELayout, return profile(Row{}, Col{}, Row{}, type, compute_type);
DataType, }
ck::Tuple<DataType>, else if(layout == ContractionMatrixLayout::KM_KN_MN_MN)
Bilinear>(do_verification, {
init_method, return profile(Col{}, Row{}, Row{}, type, compute_type);
do_log, }
time_kernel, else if(layout == ContractionMatrixLayout::KM_NK_MN_MN)
Bilinear{alpha, beta}, {
M, return profile(Col{}, Col{}, Row{}, type, compute_type);
N, }
K, return false;
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
}; };
if(data_type == ContractionDataType::F32_F32_F32_F32 && if(data_type == ContractionDataType::F32_F32_F32_F32)
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{ {
return profile(Row{}, Col{}, Row{}, F32{}); if(compute_data_type == ContractionComputeDataType::F32)
} {
else if(data_type == ContractionDataType::F32_F32_F32_F32 && return run_profile_for_datatype(F32{}, F32{});
layout == ContractionMatrixLayout::KM_KN_MN_MN) }
{ else if(compute_data_type == ContractionComputeDataType::F16)
return profile(Col{}, Row{}, Row{}, F32{}); {
} return run_profile_for_datatype(F32{}, F16{});
else if(data_type == ContractionDataType::F32_F32_F32_F32 && }
layout == ContractionMatrixLayout::KM_NK_MN_MN) else if(compute_data_type == ContractionComputeDataType::BF16)
{ {
return profile(Col{}, Col{}, Row{}, F32{}); return run_profile_for_datatype(F32{}, BF16{});
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else
layout == ContractionMatrixLayout::MK_KN_MN_MN) {
{ std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return profile(Row{}, Row{}, Row{}, F64{}); return 1;
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F64{});
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F64_F64_F64_F64)
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
return profile(Col{}, Row{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F64)
{
return run_profile_for_datatype(F64{}, F64{});
}
else if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F64{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F16_F16_F16_F16)
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
return profile(Col{}, Col{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16)
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; if(compute_data_type == ContractionComputeDataType::F32)
{
return 1; return run_profile_for_datatype(BF16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
return 1;
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
static void print_helper_msg() static void print_helper_msg()
{ {
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n" << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
...@@ -26,39 +27,40 @@ static void print_helper_msg() ...@@ -26,39 +27,40 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n" << "arg5: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "arg6: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n" << "value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg8: time kernel (0: no, 1: yes)\n"
<< "arg8: alpha\n" << "arg9: alpha\n"
<< "arg9 to 14: M0, M1, N0, N1, K0, K1\n" << "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
<< "arg15 to 30: Strides for A, B, D and E (skip for default)\n" << "arg16 to 31: Strides for A, B, D and E (skip for default)\n"
<< std::endl; << std::endl;
} }
int profile_contraction_scale(int argc, char* argv[]) int profile_contraction_scale(int argc, char* argv[])
{ {
const bool default_strides = argc == 15; const bool default_strides = argc == 16;
if(argc != 31 && argc != 15) if(argc != 32 && argc != 16)
{ {
print_helper_msg(); print_helper_msg();
exit(1); exit(1);
} }
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3])); const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4]));
const ck::index_t init_method = std::stoi(argv[5]); const bool do_verification = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const ck::index_t init_method = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]); const bool do_log = std::stoi(argv[7]);
const float alpha = std::stof(argv[8]); const bool time_kernel = std::stoi(argv[8]);
const float alpha = std::stof(argv[9]);
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 9; const ck::index_t dims_arg_num = 10;
collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2); collect_index_params(argv, K, dims_arg_num + 4, 2);
...@@ -75,88 +77,131 @@ int profile_contraction_scale(int argc, char* argv[]) ...@@ -75,88 +77,131 @@ int profile_contraction_scale(int argc, char* argv[])
collect_index_params(argv, StridesD, dims_arg_num + 18, 4); collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
} }
using F32 = float; using F16 = ck::half_t;
using F64 = double; using BF16 = ck::bhalf_t;
using F32 = float;
auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { using F64 = double;
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout); auto profile =
using CDELayout = decltype(cde_layout); [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) {
using ALayout = decltype(a_layout);
using DataType = decltype(type); using BLayout = decltype(b_layout);
using CDELayout = decltype(cde_layout);
if(default_strides)
using DataType = decltype(type);
using ComputeDataType = decltype(compute_type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<>,
Scale>(do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
if(layout == ContractionMatrixLayout::MK_KN_MN_MN)
{ {
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); return profile(Row{}, Row{}, Row{}, type, compute_type);
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
else if(layout == ContractionMatrixLayout::MK_NK_MN_MN)
bool pass = ck::profiler:: {
profile_contraction_impl<ALayout, BLayout, CDELayout, DataType, ck::Tuple<>, Scale>( return profile(Row{}, Col{}, Row{}, type, compute_type);
do_verification, }
init_method, else if(layout == ContractionMatrixLayout::KM_KN_MN_MN)
do_log, {
time_kernel, return profile(Col{}, Row{}, Row{}, type, compute_type);
Scale{alpha}, }
M, else if(layout == ContractionMatrixLayout::KM_NK_MN_MN)
N, {
K, return profile(Col{}, Col{}, Row{}, type, compute_type);
StridesA, }
StridesB, return false;
StridesE,
StridesD);
return pass;
}; };
if(data_type == ContractionDataType::F32_F32_F32_F32 && if(data_type == ContractionDataType::F32_F32_F32_F32)
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
return profile(Col{}, Row{}, Row{}, F32{}); if(compute_data_type == ContractionComputeDataType::F32)
} {
else if(data_type == ContractionDataType::F32_F32_F32_F32 && return run_profile_for_datatype(F32{}, F32{});
layout == ContractionMatrixLayout::KM_NK_MN_MN) }
{ else if(compute_data_type == ContractionComputeDataType::F16)
return profile(Col{}, Col{}, Row{}, F32{}); {
} return run_profile_for_datatype(F32{}, F16{});
else if(data_type == ContractionDataType::F64_F64_F64_F64 && }
layout == ContractionMatrixLayout::MK_KN_MN_MN) else if(compute_data_type == ContractionComputeDataType::BF16)
{ {
return profile(Row{}, Row{}, Row{}, F64{}); return run_profile_for_datatype(F32{}, BF16{});
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else
layout == ContractionMatrixLayout::MK_NK_MN_MN) {
{ std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return profile(Row{}, Col{}, Row{}, F64{}); return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F64_F64_F64_F64)
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{ {
return profile(Col{}, Row{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F64)
{
return run_profile_for_datatype(F64{}, F64{});
}
else if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F64{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else if(data_type == ContractionDataType::F64_F64_F64_F64 && else if(data_type == ContractionDataType::F16_F16_F16_F16)
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{ {
return profile(Col{}, Col{}, Row{}, F64{}); if(compute_data_type == ContractionComputeDataType::F32)
{
return run_profile_for_datatype(F16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
else else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16)
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; if(compute_data_type == ContractionComputeDataType::F32)
{
return 1; return run_profile_for_datatype(BF16{}, F32{});
}
else
{
std::cout << "Incorrect combination of data type and compute data type." << std::endl;
return 1;
}
} }
return 1;
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale);
...@@ -10,9 +10,12 @@ ...@@ -10,9 +10,12 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/profile_contraction_impl.hpp" #include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
using F32 = float; using F16 = ck::half_t;
using F64 = double; using BF16 = ck::bhalf_t;
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -20,49 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -20,49 +23,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
struct MemoryParams struct Dimensions
{ {
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
}; };
template <typename Tuple> template <typename Tuple>
class TestContraction : public ::testing::Test class TestContraction : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CDLayout = std::tuple_element_t<2, Tuple>; using CDLayout = std::tuple_element_t<2, Tuple>;
using DataType = std::tuple_element_t<3, Tuple>; using DataType = std::tuple_element_t<3, Tuple>;
using DTupleDataType = std::tuple_element_t<4, Tuple>; using DTupleDataType = std::tuple_element_t<4, Tuple>;
using CDElementOp = std::tuple_element_t<5, Tuple>; using ComputeDataType = std::tuple_element_t<5, Tuple>;
using CDElementOp = std::tuple_element_t<6, Tuple>;
std::vector<MemoryParams> list_of_memory_params = {{{32, 32},
{32, 32}, std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
{32, 32}, {{16, 16}, {32, 32}, {16, 16}}};
{32768, 1024, 32, 1},
{32768, 1024, 32, 1}, std::vector<ck::index_t> init_methods = {1, 2};
{32768, 1024, 32, 1},
{32768, 1024, 32, 1}},
{{16, 16},
{32, 32},
{16, 16},
{4096, 256, 16, 1},
{16, 1, 8192, 256},
{16384, 1024, 32, 1},
{16384, 1024, 32, 1}}};
std::vector<ck::index_t> init_methods = {0, 1, 2};
std::unique_ptr<CDElementOp> p_cd_element_op; std::unique_ptr<CDElementOp> p_cd_element_op;
void Run() void Run()
{ {
for(auto& memory_params : list_of_memory_params) for(auto& dimension_params : dimension_list)
{ {
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
for(const ck::index_t init_method : init_methods) for(const ck::index_t init_method : init_methods)
{ {
bool pass = bool pass =
...@@ -70,19 +73,20 @@ class TestContraction : public ::testing::Test ...@@ -70,19 +73,20 @@ class TestContraction : public ::testing::Test
BLayout, BLayout,
CDLayout, CDLayout,
DataType, DataType,
ComputeDataType,
DTupleDataType, DTupleDataType,
CDElementOp>(true /*do_verification*/, CDElementOp>(true /*do_verification*/,
init_method, init_method,
false /*do_logs*/, false /*do_logs*/,
false /*time_kernel*/, false /*time_kernel*/,
*p_cd_element_op, *p_cd_element_op,
memory_params.M, dimension_params.M,
memory_params.N, dimension_params.N,
memory_params.K, dimension_params.K,
memory_params.StridesA, StridesA,
memory_params.StridesB, StridesB,
memory_params.StridesC, StridesC,
memory_params.StridesD); StridesD);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
} }
...@@ -99,24 +103,18 @@ class TestContractionBilinear : public TestContraction<Tuple> ...@@ -99,24 +103,18 @@ class TestContractionBilinear : public TestContraction<Tuple>
{ {
}; };
#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>, \
std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>
using BilinearKernelTypes = using BilinearKernelTypes =
::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<F32>, Bilinear>, ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
std::tuple<Row, Col, Row, F32, ck::Tuple<F32>, Bilinear>, ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;
std::tuple<Col, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Col, Row, F32, ck::Tuple<F32>, Bilinear>, using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale),
std::tuple<Row, Row, Row, F64, ck::Tuple<F32>, Bilinear>, ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>;
std::tuple<Row, Col, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
std::tuple<Col, Col, Row, F64, ck::Tuple<F32>, Bilinear>>;
using ScaleKernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F32, ck::Tuple<>, Scale>,
std::tuple<Row, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Row, Col, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Row, Row, F64, ck::Tuple<>, Scale>,
std::tuple<Col, Col, Row, F64, ck::Tuple<>, Scale>>;
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
...@@ -136,3 +134,46 @@ TYPED_TEST(TestContractionScale, scale) ...@@ -136,3 +134,46 @@ TYPED_TEST(TestContractionScale, scale)
this->p_cd_element_op = std::make_unique<Scale>(0.5f); this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run(); this->Run();
} }
template <typename Tuple>
class TestContractionScaleMixedPrecision : public TestContraction<Tuple>
{
};
template <typename Tuple>
class TestContractionBilinearMixedPrecision : public TestContraction<Tuple>
{
};
using BilinearKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, BF16, Bilinear),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<F16>, F32, Bilinear),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<BF16>, F32, Bilinear)>;
using ScaleKernelTypesMixedPrecision =
::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F16, Scale),
ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale),
ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale),
ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>;
TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision);
TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision);
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
{
this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
this->Run();
}
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
{
this->p_cd_element_op = std::make_unique<Scale>(1.f);
this->Run();
this->p_cd_element_op = std::make_unique<Scale>(0.5f);
this->Run();
}
...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper ...@@ -34,11 +34,11 @@ class ContractionInstanceWrapper
static constexpr ck::index_t NumDim = 2; static constexpr ck::index_t NumDim = 2;
// clang-format off // clang-format off
using ContractionDeviceInstance = ck::tensor_operation::device:: using ContractionDeviceInstance = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
// clang-format on // clang-format on
bool isSupported(std::vector<ck::index_t>& ADims, bool isSupported(std::vector<ck::index_t>& ADims,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment