Commit 9f1b4276 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

resolved conflicts

parents 711857c4 c7010716
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_INT8
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP64
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -290,6 +290,42 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, BF8, F32, F8, DsLayout, F8, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -10,439 +10,18 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_USE_XDL
#include "grouped_convolution_backward_data_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_backward_data_wmma.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
......@@ -488,9 +67,10 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
......@@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory<
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
......@@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory<
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
}
else if constexpr(NumDimSpatial == 3)
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
......@@ -593,45 +142,29 @@ struct DeviceOperationInstanceFactory<
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
......@@ -640,44 +173,139 @@ struct DeviceOperationInstanceFactory<
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
else if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
is_same_v<ComputeTypeB, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
#endif
}
}
#endif
#ifdef CK_USE_WMMA
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
}
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
op_ptrs);
......@@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory<
#endif
}
}
#endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -12,564 +12,19 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// xdl
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef DL_KERNELS
// dl
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#include "grouped_convolution_backward_weight_dl.inc"
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef CK_USE_XDL
#include "grouped_convolution_backward_weight_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_backward_weight_wmma.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t NumDimSpatial,
typename InLayout,
......@@ -611,6 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef DL_KERNELS
if constexpr(NumDimSpatial == 1)
{
if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> &&
......@@ -621,10 +77,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
......@@ -632,10 +85,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
......@@ -644,19 +94,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
{
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
......@@ -682,7 +127,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
#endif
}
}
......@@ -696,12 +140,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
......@@ -709,12 +149,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
......@@ -723,12 +159,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
......@@ -740,12 +172,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
......@@ -753,12 +181,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
......@@ -767,12 +191,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
......@@ -787,12 +207,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
......@@ -800,15 +216,39 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
}
#endif
}
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
......@@ -818,40 +258,125 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
}
}
#endif // DL_KERNELS
#ifdef CK_USE_XDL
if constexpr(NumDimSpatial == 1)
{
if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, GNWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
}
}
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
}
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
}
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
#endif
......@@ -860,15 +385,39 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
}
#endif
}
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
......@@ -878,11 +427,36 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
op_ptrs);
}
#endif
}
}
#endif
#ifdef CK_USE_WMMA
if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
......@@ -892,23 +466,42 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
}
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>)
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> &&
is_same_v<ComputeTypeA, int8_t> &&
is_same_v<ComputeTypeB, int8_t>)
{
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
}
#endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv3d backward weight
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
BF8,
F8>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -12,907 +12,20 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
// grouped conv1d forward, GNWC/GKXC/GNWK
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef DL_KERNELS
#include "grouped_convolution_forward_dl.inc"
#endif
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef CK_USE_XDL
#include "grouped_convolution_forward_xdl.inc"
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef CK_USE_WMMA
#include "grouped_convolution_forward_wmma.inc"
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough,
F8>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F8,
F8,
Empty_Tuple,
F8,
PassThrough,
PassThrough,
PassThrough,
F8>>>& instances);
#endif
#ifdef CK_ENABLE_BF8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF8,
BF8,
Empty_Tuple,
F8,
PassThrough,
PassThrough,
PassThrough,
BF8>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t NumDimSpatial,
typename InLayout,
......@@ -921,7 +34,8 @@ template <ck::index_t NumDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename ComputeType>
typename AComputeType,
typename BComputeType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
......@@ -935,7 +49,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ComputeType>>
AComputeType,
BComputeType>>
{
using DeviceOp =
DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
......@@ -950,25 +65,73 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ComputeType>;
AComputeType,
BComputeType>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef DL_KERNELS
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
#endif
}
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
}
#endif
}
#endif // DL_KERNELS
#ifdef CK_USE_XDL
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
......@@ -976,14 +139,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
}
......@@ -995,114 +161,60 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
#endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs);
}
#endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs);
}
#endif
}
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
}
#endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
{
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
}
#endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs);
}
#endif
}
......@@ -1111,42 +223,36 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs);
}
#endif
}
......@@ -1156,7 +262,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<ComputeType, float>)
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
}
......@@ -1164,51 +271,160 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, ck::f8_t>)
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, ck::f8_t> &&
is_same_v<BComputeType, ck::f8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(
op_ptrs);
}
if constexpr(is_same_v<InDataType, ck::f8_t> && is_same_v<WeiDataType, ck::f8_t> &&
is_same_v<OutDataType, ck::f8_t> && is_same_v<ComputeType, ck::f8_t>)
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::f8_t> &&
is_same_v<BComputeType, ck::f8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF8
if constexpr(is_same_v<InDataType, ck::bf8_t> && is_same_v<WeiDataType, ck::bf8_t> &&
is_same_v<OutDataType, ck::f8_t> && is_same_v<ComputeType, ck::bf8_t>)
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::bf8_t> &&
is_same_v<BComputeType, ck::bf8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs);
}
#endif
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
if constexpr(is_same_v<InDataType, ck::f8_t> && is_same_v<WeiDataType, ck::bf8_t> &&
is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::f8_t> &&
is_same_v<BComputeType, ck::bf8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, bhalf_t>)
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<ComputeType, int8_t>)
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
}
#endif
}
#endif
#ifdef CK_USE_WMMA
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs);
}
#endif
}
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs);
}
#endif
}
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs);
}
#endif
}
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t> && is_same_v<AComputeType, int8_t> &&
is_same_v<BComputeType, int8_t>)
{
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
......@@ -1217,6 +433,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
#endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
// grouped conv1d forward, GNWC/GKXC/GNWK
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough,
F8>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F8,
F8,
Empty_Tuple,
F8,
PassThrough,
PassThrough,
PassThrough,
F8>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF8,
BF8,
Empty_Tuple,
F8,
PassThrough,
PassThrough,
PassThrough,
BF8>>>& instances);
#endif
#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F8,
BF8,
Empty_Tuple,
F8,
PassThrough,
PassThrough,
PassThrough,
F8,
BF8>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -172,7 +172,6 @@ void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_in
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
......@@ -281,8 +280,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
......@@ -301,8 +300,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
......@@ -311,7 +310,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_nk_mn_instances(
op_ptrs);
......@@ -319,8 +318,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
#if defined(CK_ENABLE_BF16)
else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
......@@ -329,7 +328,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instances(
op_ptrs);
......
......@@ -36,12 +36,27 @@ function(add_instance_library INSTANCE_NAME)
endif()
endforeach()
endif()
# Do not build DL instances if DL_KERNELS macro is not set
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
# Do not build XDL instances if gfx9 targets are not on the target list
foreach(source IN LISTS ARGN)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
# Do not build WMMA instances if gfx11 targets are not on the target list
foreach(source IN LISTS ARGN)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
message("removing wmma instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
......@@ -124,6 +139,26 @@ FOREACH(subdir_path ${dir_list})
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9"))
message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11"))
message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9"))
message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
set(add_inst 0)
endif()
if((add_inst EQUAL 1))
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
......
# ONLY XDL_KERNELS
set(BATCHED_GEMM_INSTANCES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
......
# ONLY XDL_KERNELS
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
......
# ONLY XDL_KERNELS
add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
)
......
# ONLY XDL_KERNELS
add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
......
# ONLY XDL_KERNELS
add_instance_library(device_batched_gemm_reduce_instance
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment