Commit defa2071 authored by Adam Osewski

Merge branch 'develop' into aosewski/ggemm_multi_d2

parents 28a68428 f2398f61
@@ -7,7 +7,7 @@
 #include <memory>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -18,25 +18,31 @@ namespace device {
 namespace instance {
 #ifdef CK_ENABLE_FP16
 // FP16
-void add_device_normalization_rank_2_1_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
+void add_device_normalization_fwd_rank_2_1_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
-void add_device_normalization_rank_4_3_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
+void add_device_normalization_fwd_rank_4_3_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
-void add_device_normalization_rank_5_3_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
 #endif
 #ifdef CK_ENABLE_FP32
 // FP32
-void add_device_normalization_rank_2_1_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
+void add_device_normalization_fwd_rank_2_1_f32_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
-void add_device_normalization_rank_4_3_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+void add_device_normalization_fwd_rank_4_3_f32_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
-void add_device_normalization_rank_5_3_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_f32_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
 #endif
 template <typename XDataType,
           typename GammaDataType,
@@ -45,7 +51,7 @@ template <typename XDataType,
           typename SaveMeanInvStdDataType,
           index_t Rank,
           index_t NumReduceDim>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalization<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalizationFwd<
     XDataType,
     GammaDataType,
     BetaDataType,
@@ -55,14 +61,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
     Rank,
     NumReduceDim>>
 {
-    using DeviceOp = DeviceNormalization<XDataType,
+    using DeviceOp = DeviceNormalizationFwd<XDataType,
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
                                          SaveMeanInvStdDataType,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          Rank,
                                          NumReduceDim>;
     static auto GetInstances()
     {
@@ -74,15 +80,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
             {
-                add_device_normalization_rank_2_1_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_2_1_f16_instances(op_ptrs);
             }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_4_3_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_4_3_f16_instances(op_ptrs);
             }
             else if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_f16_instances(op_ptrs);
             }
         }
 #endif
@@ -93,15 +99,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
             {
-                add_device_normalization_rank_2_1_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_2_1_f32_instances(op_ptrs);
             }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_4_3_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_4_3_f32_instances(op_ptrs);
             }
             else if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_f32_instances(op_ptrs);
             }
         }
 #endif
...
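Reviewer note: client code only needs the rename; the factory is queried as before. A minimal sketch of a caller (assuming the usual CK BaseOperator interface with GetTypeString(); the F16/F32/PassThrough aliases are the ones this header already uses):

#include <iostream>

// Hypothetical caller: enumerate all FP16 rank-2 forward-normalization instances.
using NormFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationFwd<
        F16, F16, F16, F16, F32, ck::tensor_operation::element_wise::PassThrough, 2, 1>>;

void list_normalization_instances()
{
    const auto op_ptrs = NormFactory::GetInstances();
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl; // print each kernel's identity
}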
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -18,16 +18,16 @@ namespace device {
 namespace instance {
 // FP16
-void add_device_normalization_rank_5_3_swish_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
 // FP32
-void add_device_normalization_rank_5_3_swish_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
 // [x, gamma, beta, y] = [f16, f32, f32, f16]
-void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
 template <typename XDataType,
           typename GammaDataType,
@@ -37,23 +37,23 @@ template <typename XDataType,
           index_t Rank,
           index_t NumReduceDim>
 struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceNormalization<XDataType,
+    ck::tensor_operation::device::DeviceNormalizationFwd<XDataType,
                                                       GammaDataType,
                                                       BetaDataType,
                                                       YDataType,
                                                       SaveMeanInvStdDataType,
                                                       ck::tensor_operation::element_wise::Swish,
                                                       Rank,
                                                       NumReduceDim>>
 {
-    using DeviceOp = DeviceNormalization<XDataType,
+    using DeviceOp = DeviceNormalizationFwd<XDataType,
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
                                          SaveMeanInvStdDataType,
                                          ck::tensor_operation::element_wise::Swish,
                                          Rank,
                                          NumReduceDim>;
     static auto GetInstances()
     {
@@ -65,7 +65,7 @@ struct DeviceOperationInstanceFactory<
         {
             if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_swish_f16_instances(op_ptrs);
             }
         }
         else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
@@ -74,7 +74,7 @@ struct DeviceOperationInstanceFactory<
         {
             if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_swish_f32_instances(op_ptrs);
             }
         }
         else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
@@ -83,7 +83,7 @@ struct DeviceOperationInstanceFactory<
         {
             if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
             }
         }
...
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -20,96 +20,96 @@ namespace instance {
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
     std::vector<
-        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_F32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Activation_Mul2_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 GK_GK_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 I32_F32_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Add_Activation_Mul2_Clamp<Relu>>>>&
         instances);
 void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
     std::vector<
-        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_F32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
         instances);
 #endif
 void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
     std::vector<
-        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_F32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Activation_Mul2_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 GK_GK_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 I32_F32_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Add_Activation_Mul2_Clamp<Relu>>>>&
         instances);
 void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
     std::vector<
-        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_F32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
         instances);
 // piecewise activation function
@@ -123,7 +123,7 @@ template <ck::index_t NumDimSpatial,
           typename DsDataType,
           typename OutDataType,
           typename Activation>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
     NumDimSpatial,
     InLayout,
     WeiLayout,
@@ -137,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     ck::tensor_operation::element_wise::PassThrough,
     Add_Activation_Mul2_Clamp<Activation>>>
 {
-    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
+    using DeviceOp =
+        DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         DsLayout,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         DsDataType,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         Add_Activation_Mul2_Clamp<Activation>>;
     static auto GetInstances()
     {
@@ -193,7 +194,7 @@ template <ck::index_t NumDimSpatial,
           typename DsDataType,
           typename OutDataType,
           typename Activation>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
     NumDimSpatial,
     InLayout,
     WeiLayout,
@@ -207,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     ck::tensor_operation::element_wise::PassThrough,
     Add_Mul2_Activation_Mul_Clamp<Activation>>>
 {
-    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
+    using DeviceOp =
+        DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         DsLayout,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         DsDataType,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         Add_Mul2_Activation_Mul_Clamp<Activation>>;
     static auto GetInstances()
     {
...
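Reviewer note: a hedged usage sketch for the specialization above, requesting the int8 per-channel-quantization instances with a Relu activation. Only names declared in this header are used; Relu is assumed to be ck::tensor_operation::element_wise::Relu as elsewhere in the library:

// Hypothetical caller of the factory specialization above.
using ConvDeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2, NHWGC, GKYXC, GK_GK_Tuple, NHWGK,
    int8_t, int8_t, I32_F32_Tuple, int8_t,
    PassThrough, PassThrough, Add_Activation_Mul2_Clamp<Relu>>;

const auto conv_op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<ConvDeviceOp>::GetInstances();
// Each returned instance can then be benchmarked and the fastest one selected.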
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -20,94 +20,96 @@ namespace instance {
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
     std::vector<
-        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Activation_Mul_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 GK_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 I32_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Add_Activation_Mul_Clamp<Relu>>>>&
         instances);
 void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Mul_Activation_Mul_Clamp<TanH>>>>&
         instances);
 #endif
 void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
     std::vector<
-        std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Activation_Mul_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 GK_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 I32_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Add_Activation_Mul_Clamp<Relu>>>>&
         instances);
 void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         I32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Add_Mul_Activation_Mul_Clamp<TanH>>>>&
         instances);
 // piecewise activation function
@@ -121,7 +123,7 @@ template <ck::index_t NumDimSpatial,
           typename DsDataType,
           typename OutDataType,
           typename Activation>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
     NumDimSpatial,
     InLayout,
     WeiLayout,
@@ -135,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     ck::tensor_operation::element_wise::PassThrough,
     Add_Activation_Mul_Clamp<Activation>>>
 {
-    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
+    using DeviceOp =
+        DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         DsLayout,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         DsDataType,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         Add_Activation_Mul_Clamp<Activation>>;
     static auto GetInstances()
     {
@@ -191,7 +194,7 @@ template <ck::index_t NumDimSpatial,
           typename DsDataType,
           typename OutDataType,
           typename Activation>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
     NumDimSpatial,
     InLayout,
     WeiLayout,
@@ -205,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     ck::tensor_operation::element_wise::PassThrough,
     Add_Mul_Activation_Mul_Clamp<Activation>>>
 {
-    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
+    using DeviceOp =
+        DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         DsLayout,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         DsDataType,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         Add_Mul_Activation_Mul_Clamp<Activation>>;
     static auto GetInstances()
     {
...
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -19,63 +19,65 @@ namespace instance {
 #ifdef DL_KERNELS
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_perchannel_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         F32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Activation_Mul2_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 GK_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 F32_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Activation_Mul2_Clamp<Relu>>>>&
         instances);
 #endif
 void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         GK_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         F32_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Activation_Mul2_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 GK_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 F32_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Activation_Mul2_Clamp<Relu>>>>&
         instances);
 template <ck::index_t NumDimSpatial,
@@ -88,7 +90,7 @@ template <ck::index_t NumDimSpatial,
           typename DsDataType,
           typename OutDataType,
           typename Activation>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
     NumDimSpatial,
     InLayout,
     WeiLayout,
@@ -102,18 +104,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     ck::tensor_operation::element_wise::PassThrough,
     Activation_Mul2_Clamp<Activation>>>
 {
-    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
+    using DeviceOp =
+        DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         GK_Tuple,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         F32_Tuple,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         Activation_Mul2_Clamp<Activation>>;
     static auto GetInstances()
     {
...
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -19,63 +19,65 @@ namespace instance {
 #ifdef DL_KERNELS
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         Empty_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         Empty_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Activation_Mul_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 Empty_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 Empty_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Activation_Mul_Clamp<Relu>>>>&
         instances);
 #endif
 void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<
+        std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                         NHWGC,
                                                         GKYXC,
                                                         Empty_Tuple,
                                                         NHWGK,
                                                         int8_t,
                                                         int8_t,
                                                         Empty_Tuple,
                                                         int8_t,
                                                         PassThrough,
                                                         PassThrough,
                                                         Activation_Mul_Clamp<PassThrough>>>>&
         instances);
 void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 Empty_Tuple,
                                                                 NHWGK,
                                                                 int8_t,
                                                                 int8_t,
                                                                 Empty_Tuple,
                                                                 int8_t,
                                                                 PassThrough,
                                                                 PassThrough,
                                                                 Activation_Mul_Clamp<Relu>>>>&
         instances);
 template <ck::index_t NumDimSpatial,
@@ -86,7 +88,7 @@ template <ck::index_t NumDimSpatial,
           typename WeiDataType,
           typename OutDataType,
           typename Activation>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
     NumDimSpatial,
     InLayout,
     WeiLayout,
@@ -100,18 +102,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     ck::tensor_operation::element_wise::PassThrough,
     Activation_Mul_Clamp<Activation>>>
 {
-    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
+    using DeviceOp =
+        DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         Empty_Tuple,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         Empty_Tuple,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         Activation_Mul_Clamp<Activation>>;
     static auto GetInstances()
     {
...
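Reviewer note: the headers themselves suggest the split between the variants: the per-layer factories carry Empty_Tuple/I32_Tuple auxiliary tensors and a *_Mul_Clamp epilogue (a single scalar requantization scale), while the per-channel ones carry F32_Tuple/I32_F32_Tuple and a *_Mul2_Clamp epilogue (a per-output-channel scale tensor). A standalone scalar sketch of the assumed per-layer pattern, not CK's actual functor:

#include <algorithm>
#include <cstdint>

// Assumed per-layer int8 requantization: y = clamp(activation(acc) * scale).
// 'acc' is the int32 conv accumulator converted to float; 'scale' is one value per layer.
inline std::int8_t requantize_perlayer_relu(float acc, float scale)
{
    float v = std::max(acc, 0.0f) * scale;     // Relu, then rescale
    v       = std::clamp(v, -128.0f, 127.0f);  // clamp to the int8 range
    return static_cast<std::int8_t>(v);
}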
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using device_transpose_f16_instances = std::tuple<
// FOR 16, 32, 16, 32, 16
// clang-format off
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 1, 1, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 2, 2, 1, 8, 4, 4, ck::Sequence<1>, ck::Sequence<1>>
// clang-format on
>;
using device_transpose_f32_instances = std::tuple<
// for 16, 8, 16, 32, 8 -> test with instances for fp16
// clang-format off
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 4, ck::Sequence<1>, ck::Sequence<1>>,
DeviceElementwise3dImpl<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<1>, ck::Sequence<1>>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_transpose_f16_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 5>>>&
instances);
void add_device_transpose_f32_instances(
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, PassThrough, 5>>>&
instances);
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>
{
using DeviceOp =
DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F32>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F32>>)
{
add_device_transpose_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataTypeTuple, ck::Tuple<F16>> &&
is_same_v<OutDataTypeTuple, ck::Tuple<F16>>)
{
add_device_transpose_f16_instances(op_ptrs);
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
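Reviewer note: the new header follows the same factory pattern as the normalization and convolution ones; a minimal sketch of a hypothetical caller that fetches the 5-D F16 transpose instances registered above:

using TransposeOp = ck::tensor_operation::device::
    DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, PassThrough, 5>;

const auto transpose_op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<TransposeOp>::GetInstances();
// Dispatches to add_device_transpose_f16_instances via the if-constexpr chain above.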
@@ -22,7 +22,7 @@ static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNu
     std::ofstream outFile(fileName, std::ios::binary);
     if(outFile)
     {
-        outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
+        outFile.write(reinterpret_cast<const char*>(data), dataNumItems * sizeof(T));
         outFile.close();
         std::cout << "Write output to file " << fileName << std::endl;
     }
...
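Reviewer note: with the const-qualified cast, the helper can also accept a read-only buffer; a self-contained sketch of that (hypothetical) const-tightened variant:

#include <cstddef>
#include <fstream>
#include <iostream>

template <typename T>
static inline void dumpBufferToFile(const char* fileName, const T* data, std::size_t dataNumItems)
{
    std::ofstream outFile(fileName, std::ios::binary);
    if(outFile)
    {
        // The const cast now matches the const input pointer.
        outFile.write(reinterpret_cast<const char*>(data), dataNumItems * sizeof(T));
        std::cout << "Write output to file " << fileName << std::endl;
    }
}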
@@ -200,10 +200,11 @@ struct GeneratorTensor_3<ck::bf8_t>
 template <typename T>
 struct GeneratorTensor_4
 {
-    std::default_random_engine generator;
+    std::mt19937 generator;
     std::normal_distribution<float> distribution;
-    GeneratorTensor_4(float mean, float stddev) : generator(1), distribution(mean, stddev){};
+    GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1)
+        : generator(seed), distribution(mean, stddev){};
     template <typename... Is>
     T operator()(Is...)
...
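Reviewer note: the std::mt19937 switch plus the defaulted seed keeps existing call sites unchanged while allowing reproducible variation; a short usage sketch (assuming operator() ignores its index arguments, as the declaration above suggests):

GeneratorTensor_4<float> g_default{0.0f, 1.0f};      // seed = 1, same stream as before
GeneratorTensor_4<float> g_seeded{0.0f, 1.0f, 42u};  // hypothetical explicit seed
float sample = g_seeded(0, 0);                       // one normal(mean=0, stddev=1) draw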
@@ -69,6 +69,10 @@ FOREACH(subdir_path ${dir_list})
         message("fp8 instance found!")
         set(add_inst 1)
     endif()
+    if(("${cmake_instance}" MATCHES "_bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8")
+        message("bf8 instance found!")
+        set(add_inst 1)
+    endif()
     if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
         message("fp16 instance found!")
         set(add_inst 1)
...
 add_instance_library(device_column_to_image_instance
-    device_column_to_image_nhwc_1d_instance.cpp
-    device_column_to_image_nhwc_2d_instance.cpp
-    device_column_to_image_nhwc_3d_instance.cpp
+    device_column_to_image_gnwc_1d_instance.cpp
+    device_column_to_image_gnhwc_2d_instance.cpp
+    device_column_to_image_gndhwc_3d_instance.cpp
+    device_column_to_image_nwgc_1d_instance.cpp
+    device_column_to_image_nhwgc_2d_instance.cpp
+    device_column_to_image_ndhwgc_3d_instance.cpp
 )
@@ -11,7 +11,7 @@ namespace instance {
 using namespace ck::conv_tensor_rearrange_op;
-void add_device_column_to_image_ndhwc_3d_bf16_instances(
+void add_device_column_to_image_gndhwc_3d_bf16_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ColumnToImage>>>&
         instances)
 {
@@ -22,7 +22,7 @@ void add_device_column_to_image_ndhwc_3d_bf16_instances(
 #endif
 }
-void add_device_column_to_image_ndhwc_3d_f16_instances(
+void add_device_column_to_image_gndhwc_3d_f16_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ColumnToImage>>>&
         instances)
 {
@@ -33,7 +33,7 @@ void add_device_column_to_image_ndhwc_3d_f16_instances(
 #endif
 }
-void add_device_column_to_image_ndhwc_3d_f32_instances(
+void add_device_column_to_image_gndhwc_3d_f32_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ColumnToImage>>>&
         instances)
 {
@@ -44,7 +44,7 @@ void add_device_column_to_image_ndhwc_3d_f32_instances(
 #endif
 }
-void add_device_column_to_image_ndhwc_3d_i8_instances(
+void add_device_column_to_image_gndhwc_3d_i8_instances(
     std::vector<
         std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ColumnToImage>>>&
         instances)
...
@@ -11,7 +11,7 @@ namespace instance {
 using namespace ck::conv_tensor_rearrange_op;
-void add_device_column_to_image_nhwc_2d_bf16_instances(
+void add_device_column_to_image_gnhwc_2d_bf16_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ColumnToImage>>>&
         instances)
 {
@@ -22,7 +22,7 @@ void add_device_column_to_image_nhwc_2d_bf16_instances(
 #endif
 }
-void add_device_column_to_image_nhwc_2d_f16_instances(
+void add_device_column_to_image_gnhwc_2d_f16_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ColumnToImage>>>&
         instances)
 {
@@ -33,7 +33,7 @@ void add_device_column_to_image_nhwc_2d_f16_instances(
 #endif
 }
-void add_device_column_to_image_nhwc_2d_f32_instances(
+void add_device_column_to_image_gnhwc_2d_f32_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ColumnToImage>>>&
         instances)
 {
@@ -44,7 +44,7 @@ void add_device_column_to_image_nhwc_2d_f32_instances(
 #endif
 }
-void add_device_column_to_image_nhwc_2d_i8_instances(
+void add_device_column_to_image_gnhwc_2d_i8_instances(
     std::vector<
         std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ColumnToImage>>>&
         instances)
...
@@ -11,7 +11,7 @@ namespace instance {
 using namespace ck::conv_tensor_rearrange_op;
-void add_device_column_to_image_nwc_1d_bf16_instances(
+void add_device_column_to_image_gnwc_1d_bf16_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ColumnToImage>>>&
         instances)
 {
@@ -22,7 +22,7 @@ void add_device_column_to_image_nwc_1d_bf16_instances(
 #endif
 }
-void add_device_column_to_image_nwc_1d_f16_instances(
+void add_device_column_to_image_gnwc_1d_f16_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ColumnToImage>>>&
         instances)
 {
@@ -33,7 +33,7 @@ void add_device_column_to_image_nwc_1d_f16_instances(
 #endif
 }
-void add_device_column_to_image_nwc_1d_f32_instances(
+void add_device_column_to_image_gnwc_1d_f32_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ColumnToImage>>>&
         instances)
 {
@@ -44,7 +44,7 @@ void add_device_column_to_image_nwc_1d_f32_instances(
 #endif
 }
-void add_device_column_to_image_nwc_1d_i8_instances(
+void add_device_column_to_image_gnwc_1d_i8_instances(
     std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ColumnToImage>>>&
         instances)
 {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::conv_tensor_rearrange_op;
void add_device_column_to_image_ndhwgc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, BF16, BF16, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_BF16
add_device_operation_instances(instances, device_column_to_image_bf16_instances<3, NDHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_ndhwgc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F16, F16, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_column_to_image_f16_instances<3, NDHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_ndhwgc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F32, F32, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_column_to_image_f32_instances<3, NDHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_ndhwgc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, int8_t, int8_t, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_INT8
add_device_operation_instances(instances, device_column_to_image_i8_instances<3, NDHWGC>{});
#else
ignore = instances;
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::conv_tensor_rearrange_op;
void add_device_column_to_image_nhwgc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, BF16, BF16, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_BF16
add_device_operation_instances(instances, device_column_to_image_bf16_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_nhwgc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F16, F16, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_column_to_image_f16_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_nhwgc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F32, F32, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_column_to_image_f32_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_nhwgc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, int8_t, int8_t, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_INT8
add_device_operation_instances(instances, device_column_to_image_i8_instances<2, NHWGC>{});
#else
ignore = instances;
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::conv_tensor_rearrange_op;
void add_device_column_to_image_nwgc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, BF16, BF16, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_BF16
add_device_operation_instances(instances, device_column_to_image_bf16_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_nwgc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F16, F16, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_FP16
add_device_operation_instances(instances, device_column_to_image_f16_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_nwgc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F32, F32, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_FP32
add_device_operation_instances(instances, device_column_to_image_f32_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
void add_device_column_to_image_nwgc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, int8_t, int8_t, ColumnToImage>>>&
instances)
{
#ifdef CK_ENABLE_INT8
add_device_operation_instances(instances, device_column_to_image_i8_instances<1, NWGC>{});
#else
ignore = instances;
#endif
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
 set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
-#float
+# FP32
 list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
                                                   device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
                                                   device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
                                                   device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
+list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp)
+list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp)
-#double
+# FP64
 list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
                                                   device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
                                                   device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
                                                   device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
+list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp)
+# FP16
+list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp)
+# BF16
+list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp
+                                                  device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp)
 add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using the
// default setting. Don't use this hack unless absolutely necessary!
// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// k/k/n/n are the fast changing dimension for A/B/D/E
using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance =
device_contraction_kk_instance<BF16,
BF16,
F32,
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>;
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
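Reviewer note: in equation form, the kknn instance added above computes (with $\alpha$ and $\beta$ taken from the Bilinear element-wise op, and the sum accumulated in F32 per the compute_f32 suffix; both readings follow CK's naming and are assumptions, not taken from this diff):

$$E_{m_0 m_1 n_0 n_1} \;=\; \alpha \sum_{k_0, k_1} A_{m_0 m_1 k_0 k_1}\, B_{n_0 n_1 k_0 k_1} \;+\; \beta\, D_{m_0 m_1 n_0 n_1}$$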