Commit ec0ac21d authored by ltqin's avatar ltqin
Browse files

change lib function name

parent 48df79f2
...@@ -11,34 +11,34 @@ ...@@ -11,34 +11,34 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
extern template void void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< std::vector<std::unique_ptr<
2, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
1, 1,
F16, F16,
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<F16>, ck::Tuple<F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void add_device_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -59,27 +59,28 @@ extern template void add_device_instances( ...@@ -59,27 +59,28 @@ extern template void add_device_instances(
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
extern template void void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< std::vector<std::unique_ptr<
2, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
1, 1,
BF16, BF16,
BF16, BF16,
BF16, BF16,
BF16, BF16,
ck::Tuple<BF16>, ck::Tuple<BF16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void add_device_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -100,27 +101,28 @@ extern template void add_device_instances( ...@@ -100,27 +101,28 @@ extern template void add_device_instances(
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
extern template void void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< std::vector<std::unique_ptr<
2, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
1, 1,
F16, F16,
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<>, ck::Tuple<>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void add_device_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -141,27 +143,28 @@ extern template void add_device_instances( ...@@ -141,27 +143,28 @@ extern template void add_device_instances(
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
extern template void void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< std::vector<std::unique_ptr<
2, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
1, 1,
BF16, BF16,
BF16, BF16,
BF16, BF16,
BF16, BF16,
ck::Tuple<>, ck::Tuple<>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void add_device_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -237,7 +240,8 @@ struct DeviceOperationInstanceFactory< ...@@ -237,7 +240,8 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
add_device_instances(op_ptrs); add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
op_ptrs);
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
namespace half_data {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename DataType,
typename AccDataType,
typename D0DataTypes,
typename AD0ElementwiseOp,
MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| ype| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
#if CK_WORKAROUND_SWDEV_388832
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
#endif
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DataType, DataType, DataType, DataType, D0DataTypes, ck::Tuple<>, AccDataType, DataType, PassThrough, PassThrough, AD0ElementwiseOp, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;
} // namespace half_data
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec,
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::half_t>::value ||
is_same<remove_cvref_t<ADataType>, ck::bhalf_t>::value,
bool>::type = false>
auto create_device_instances()
{
return half_data::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
F32,
Acc0BiasDataType,
C0DEElementwiseOperation,
MaskingSpec>{};
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -26,26 +26,31 @@ using S = ck::Sequence<Is...>; ...@@ -26,26 +26,31 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
2, std::vector<std::unique_ptr<
1, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
BF16, 1,
BF16, BF16,
BF16, BF16,
BF16, BF16,
ck::Tuple<BF16>, BF16,
ck::Tuple<>, ck::Tuple<BF16>,
PassThrough, ck::Tuple<>,
PassThrough, PassThrough,
ScaleAdd, PassThrough,
PassThrough, ScaleAdd,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
}
template void add_device_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -64,7 +69,10 @@ template void add_device_instances( ...@@ -64,7 +69,10 @@ template void add_device_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances)
{
add_device_instances(instances);
}
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -26,26 +26,31 @@ using S = ck::Sequence<Is...>; ...@@ -26,26 +26,31 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
2, std::vector<std::unique_ptr<
1, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
F16, 1,
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<F16>, F16,
ck::Tuple<>, ck::Tuple<F16>,
PassThrough, ck::Tuple<>,
PassThrough, PassThrough,
ScaleAdd, PassThrough,
PassThrough, ScaleAdd,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
}
template void add_device_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -64,7 +69,10 @@ template void add_device_instances( ...@@ -64,7 +69,10 @@ template void add_device_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances)
{
add_device_instances(instances);
}
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -26,26 +26,31 @@ using S = ck::Sequence<Is...>; ...@@ -26,26 +26,31 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
2, std::vector<std::unique_ptr<
1, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
BF16, 1,
BF16, BF16,
BF16, BF16,
BF16, BF16,
ck::Tuple<>, BF16,
ck::Tuple<>, ck::Tuple<>,
PassThrough, ck::Tuple<>,
PassThrough, PassThrough,
Scale, PassThrough,
PassThrough, Scale,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
}
template void add_device_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -64,7 +69,10 @@ template void add_device_instances( ...@@ -64,7 +69,10 @@ template void add_device_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances)
{
add_device_instances(instances);
}
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -26,26 +26,30 @@ using S = ck::Sequence<Is...>; ...@@ -26,26 +26,30 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute< void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
2, std::vector<std::unique_ptr<
1, DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
1, 1,
1, 1,
F16, 1,
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<>, F16,
ck::Tuple<>, ck::Tuple<>,
PassThrough, ck::Tuple<>,
PassThrough, PassThrough,
Scale, PassThrough,
PassThrough, Scale,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& instances); PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
template void add_device_instances( instances)
{
add_device_instances(instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -64,7 +68,10 @@ template void add_device_instances( ...@@ -64,7 +68,10 @@ template void add_device_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances)
{
add_device_instances(instances);
}
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment