Unverified Commit 74284dd2 authored by rocking's avatar rocking Committed by GitHub
Browse files

Merge branch 'develop' into avgpool_bwd

parents 01aeb901 50643dd5
......@@ -16,7 +16,7 @@ namespace device {
namespace instance {
// conv2d backward data
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -30,7 +30,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -44,7 +44,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -58,7 +58,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -72,7 +72,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -86,7 +86,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -100,6 +100,91 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough,
PassThrough>>>& instances);
// conv3d backward data
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
......@@ -139,43 +224,96 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
else if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
}
}
......
......@@ -164,6 +164,42 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
......@@ -273,8 +309,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
else if constexpr(NumDimSpatial == 3)
{
if(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
......@@ -296,6 +332,29 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
}
}
return op_ptrs;
......
......@@ -12,9 +12,42 @@ set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list})
set(target_dir)
IF(IS_DIRECTORY "${subdir_path}")
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
#message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES")
#message("instance should be built for all types!")
set(add_inst 1)
endif()
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
ENDIF()
ENDFOREACH()
......
......@@ -31,22 +31,24 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcVectorDim| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
//generic
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
......
......@@ -31,23 +31,25 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcVectorDim| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
//generic
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
......
add_instance_library(device_batched_gemm_multi_d_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
)
set(BATCHED_GEMM_MULTID_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
endif()
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
add_instance_library(device_conv2d_bwd_data_instance
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp
)
set(CONV2D_BWD_DATA_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()
add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})
......@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -151,3 +151,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
add_instance_library(device_gemm_instance
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
set(GEMM_INSTANCES)
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
endif()
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
)
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
set(ENABLE_PIPELINE_V2_OPT OFF)
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment