"docs/reference/wrapper.rst" did not exist on "efaf31061a00a9c17a888ddbf2e273aafe977d5e"
Commit d8ab41d5 authored by root's avatar root
Browse files

add instances

parent bdf6cddb
...@@ -72,25 +72,6 @@ struct MultiplyAddFastGelu ...@@ -72,25 +72,6 @@ struct MultiplyAddFastGelu
} }
}; };
struct PassThroughPack2
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::int8x2_t& x) const
{
y = ck::bit_cast<ck::bhalf2_t>(static_cast<int32_t>(ck::bit_cast<int16_t>(x)));
}
template <>
__host__ __device__ void operator()<ck::bhalf_t, int8_t>(ck::bhalf_t& y, const int8_t& x) const
{
y = ck::type_convert<ck::bhalf_t>(x);
}
constexpr const static bool is_pack2_invocable = true;
};
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough; using AElementOp = PassThrough;
......
...@@ -177,6 +177,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -177,6 +177,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
const auto kernel = const auto kernel =
...@@ -187,6 +188,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -187,6 +188,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
Run(kernel); Run(kernel);
} }
else else
#endif
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...@@ -199,6 +201,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -199,6 +201,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be One to Seven // Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
...@@ -312,6 +315,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -312,6 +315,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
} }
} }
else else
#endif
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{ {
......
...@@ -311,6 +311,26 @@ struct AddAddFastGelu ...@@ -311,6 +311,26 @@ struct AddAddFastGelu
} }
}; };
struct MultiplyAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
{
const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = ck::type_convert<ck::bhalf_t>(x1_f);
}
};
// E = Relu(alpha1 * C + alpha2 * D0 + D1) // E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu struct ScaleAddScaleAddRelu
{ {
......
...@@ -37,6 +37,7 @@ using I32_F32_Tuple = ck::Tuple<I32, F32>; ...@@ -37,6 +37,7 @@ using I32_F32_Tuple = ck::Tuple<I32, F32>;
using I8_Tuple = ck::Tuple<I8>; using I8_Tuple = ck::Tuple<I8>;
using F32_F32_Tuple = ck::Tuple<F32, F32>; using F32_F32_Tuple = ck::Tuple<F32, F32>;
using BF16_BF16_Tuple = ck::Tuple<BF16, BF16>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -97,6 +98,7 @@ using TanH = ck::tensor_operation::element_wise::TanH; ...@@ -97,6 +98,7 @@ using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddRelu = ck::tensor_operation::element_wise::AddRelu; using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddSilu = ck::tensor_operation::element_wise::AddSilu; using AddSilu = ck::tensor_operation::element_wise::AddSilu;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_multiply_add_fastgelu_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Row_Tuple,
Row,
BF16,
I8,
BF16_BF16_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>>&);
// GEMM + Multiply + Add + FastGelu
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyAddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyAddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<D0DataType, bhalf_t> && is_same_v<D1DataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -65,6 +65,8 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES ...@@ -65,6 +65,8 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_bf16_i8_bf16_multi_d/device_gemm_xdl_universal_multiply_add_fastgelu_bf16_i8_bf16_mk_kn_mn_mnkpadding_instance.cpp
) )
add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES}) add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using MultiplyAddFastGelu = element_wise::MultiplyAddFastGelu;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
using DsLayout = ck::Tuple<Row, Row>;
template <
typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec
>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <
typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched
>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Row, DsLayout, Row, BF16, I8, DsDType, BF16, F32, F32, PassThrough, PassThrough, CElementwiseOp, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_multiply_add_fastgelu_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
ck::Tuple<Row, Row>,
Row,
BF16,
I8,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<BF16, BF16>,
MultiplyAddFastGelu,
GemmMNKPadding>{});
add_device_operation_instances(
instances,
device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<BF16, BF16>,
MultiplyAddFastGelu,
GemmMNKPadding,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# ckProfiler # ckProfiler
set(PROFILER_SOURCES set(PROFILER_SOURCES
profiler.cpp profiler.cpp
#profile_gemm.cpp profile_gemm.cpp
#profile_reduce.cpp profile_reduce.cpp
#profile_groupnorm_bwd_data.cpp profile_groupnorm_bwd_data.cpp
#profile_groupnorm_fwd.cpp profile_groupnorm_fwd.cpp
#profile_layernorm_bwd_data.cpp profile_layernorm_bwd_data.cpp
#profile_layernorm_bwd_gamma_beta.cpp profile_layernorm_bwd_gamma_beta.cpp
#profile_groupnorm_bwd_gamma_beta.cpp profile_groupnorm_bwd_gamma_beta.cpp
#profile_layernorm_fwd.cpp profile_layernorm_fwd.cpp
#profile_max_pool3d_fwd.cpp profile_max_pool3d_fwd.cpp
#profile_avg_pool3d_bwd.cpp profile_avg_pool3d_bwd.cpp
#profile_max_pool3d_bwd.cpp profile_max_pool3d_bwd.cpp
#profile_softmax.cpp profile_softmax.cpp
#profile_batchnorm_fwd.cpp profile_batchnorm_fwd.cpp
#profile_batchnorm_bwd.cpp profile_batchnorm_bwd.cpp
#profile_batchnorm_infer.cpp profile_batchnorm_infer.cpp
#profile_conv_tensor_rearrange.cpp profile_conv_tensor_rearrange.cpp
#profile_transpose.cpp profile_transpose.cpp
#profile_permute_scale.cpp profile_permute_scale.cpp
profile_gemm_universal.cpp
) )
#if(GPU_TARGETS MATCHES "gfx9") if(GPU_TARGETS MATCHES "gfx9")
# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
# list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
# list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
# endif() endif()
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
# list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp)
# list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
# list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
# endif() endif()
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
# list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
# list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp)
# list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
# list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
# list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
#endif()
# endif()
#if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9")
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9")
# list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
# endif() list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) endif()
# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp)
# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp)
#endif() list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
# endif()
#if(DL_KERNELS)
# list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) if(DL_KERNELS)
# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
#endif() list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
endif()
set(PROFILER_EXECUTABLE ckProfiler) set(PROFILER_EXECUTABLE ckProfiler)
...@@ -77,76 +77,76 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) ...@@ -77,76 +77,76 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
#
#if(GPU_TARGETS MATCHES "gfx9") if(GPU_TARGETS MATCHES "gfx9")
# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
# endif() endif()
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
# endif() endif()
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
#endif() endif()
#
#if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11")
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
# endif() endif()
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
#endif() endif()
#
#if(DL_KERNELS) if(DL_KERNELS)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
#endif() endif()
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment