Commit 32806d5f authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents e70a4d19 d0f355a3
...@@ -345,7 +345,11 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( ...@@ -345,7 +345,11 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances); DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances( void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances); DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
...@@ -575,7 +579,8 @@ struct DeviceOperationInstanceFactory< ...@@ -575,7 +579,8 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
{ {
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
......
...@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
BF16, BF16,
BF16, BF16,
...@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
F16, F16,
F16, F16,
...@@ -59,7 +59,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -59,7 +59,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
F32, F32,
F32, F32,
...@@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
int8_t, int8_t,
int8_t, int8_t,
...@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> && if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>) is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
DLayouts::Size() == 2 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK> &&
is_same_v<tuple_element_t<1, DLayouts>, G_K>)
{ {
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
// FP32
void add_device_groupnorm_bwd_data_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
5,
3>>
{
using DeviceOp = DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
5,
3>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
is_same_v<GammaDataType, F32> && is_same_v<MeanInvStdDataType, F32> &&
is_same_v<DXDataType, F32>)
{
add_device_groupnorm_bwd_data_f32_instances(op_ptrs);
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
// FP16
void add_device_layernorm2d_bwd_data_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F16, F16, F16, F16, F16, 2, 1>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_layernorm2d_bwd_data_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 2, 1>>>&);
#endif
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DYDataType, F16> && is_same_v<XDataType, F16> &&
is_same_v<GammaDataType, F16> && is_same_v<MeanInvStdDataType, F16> &&
is_same_v<DXDataType, F16>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
add_device_layernorm2d_bwd_data_f16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
is_same_v<GammaDataType, F32> && is_same_v<MeanInvStdDataType, F32> &&
is_same_v<DXDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
add_device_layernorm2d_bwd_data_f32_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -20,15 +20,15 @@ namespace instance { ...@@ -20,15 +20,15 @@ namespace instance {
// FP16 // FP16
void add_device_normalization_fwd_rank_2_1_f16_instances( void add_device_normalization_fwd_rank_2_1_f16_instances(
std::vector< std::vector<
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&); std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, PassThrough, 2, 1>>>&);
void add_device_normalization_fwd_rank_4_3_f16_instances( void add_device_normalization_fwd_rank_4_3_f16_instances(
std::vector< std::vector<
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&); std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, PassThrough, 4, 3>>>&);
void add_device_normalization_fwd_rank_5_3_f16_instances( void add_device_normalization_fwd_rank_5_3_f16_instances(
std::vector< std::vector<
std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&); std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, PassThrough, 5, 3>>>&);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
// FP32 // FP32
...@@ -76,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal ...@@ -76,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> && if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> && is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>) is_same_v<SaveMeanInvStdDataType, F16>)
{ {
if constexpr(Rank == 2 && NumReduceDim == 1) if constexpr(Rank == 2 && NumReduceDim == 1)
{ {
......
...@@ -19,7 +19,7 @@ namespace instance { ...@@ -19,7 +19,7 @@ namespace instance {
// FP16 // FP16
void add_device_normalization_fwd_rank_5_3_swish_f16_instances( void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&); std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F16, Swish, 5, 3>>>&);
// FP32 // FP32
void add_device_normalization_fwd_rank_5_3_swish_f32_instances( void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
...@@ -61,7 +61,7 @@ struct DeviceOperationInstanceFactory< ...@@ -61,7 +61,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> && if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> && is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<SaveMeanInvStdDataType, F32>) is_same_v<SaveMeanInvStdDataType, F16>)
{ {
if constexpr(Rank == 5 && NumReduceDim == 3) if constexpr(Rank == 5 && NumReduceDim == 3)
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_permute_scale_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
ck::Tuple<F16>,
PassThrough,
element_wise::UnarySquare,
Scale,
4>>>&);
void add_device_permute_scale_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
ck::Tuple<F32>,
PassThrough,
element_wise::UnarySquare,
Scale,
4>>>&);
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t NumDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
NumDim>>
{
using DeviceOp = DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
NumDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_permute_scale_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_permute_scale_f16_instances(op_ptrs);
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -101,7 +101,8 @@ list(APPEND GEMM_INSTANCES ...@@ -101,7 +101,8 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
......
...@@ -16,6 +16,7 @@ namespace tensor_operation { ...@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
......
...@@ -16,6 +16,7 @@ namespace tensor_operation { ...@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"
#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"
#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<MNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -34,22 +34,19 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = ...@@ -34,22 +34,19 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 16, 16, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 16, 16, 16, 16, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on // clang-format on
>; >;
......
...@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
BF16, BF16,
BF16, BF16,
...@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwdDefault>{}); ConvFwdDefault>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1P0>{}); ConvFwd1x1P0>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1S1P0>{}); ConvFwd1x1S1P0>{});
} }
......
...@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
F16, F16,
F16, F16,
...@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwdDefault>{}); ConvFwdDefault>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1P0>{}); ConvFwd1x1P0>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1S1P0>{}); ConvFwd1x1S1P0>{});
} }
......
...@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
F32, F32,
F32, F32,
...@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwdDefault>{}); ConvFwdDefault>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1P0>{}); ConvFwd1x1P0>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1S1P0>{}); ConvFwd1x1S1P0>{});
} }
......
...@@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
int8_t, int8_t,
int8_t, int8_t,
...@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwdDefault>{}); ConvFwdDefault>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1P0>{}); ConvFwd1x1P0>{});
add_device_operation_instances( add_device_operation_instances(
...@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw ...@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
ck::Tuple<NDHWGK, NDHWGK>, ck::Tuple<NDHWGK, G_K>,
NDHWGK, NDHWGK,
ConvFwd1x1S1P0>{}); ConvFwd1x1S1P0>{});
} }
......
set(DEVICE_NORMALIZATION_bwd_data_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_bwd_data_INSTANCES
device_groupnorm_bwd_data_f32_instance.cpp
device_layernorm2d_bwd_data_f16_instance.cpp
device_layernorm2d_bwd_data_f32_instance.cpp)
add_instance_library(device_normalization_bwd_data_instance ${DEVICE_NORMALIZATION_bwd_data_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_data_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_groupnorm_bwd_data_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_generic_instance{});
add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_bwd_data_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_data_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{
add_device_operation_instances(instances,
device_layernorm_bwd_data_f16_generic_instance<2, 1>{});
add_device_operation_instances(instances, device_layernorm_bwd_data_f16_instances<2, 1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment