Commit a30c626b authored by Bartlomiej Wroblewski
Browse files

Make ComputeDataType an optional argument

parent b019d839
...@@ -30,7 +30,6 @@ using CShuffleDataType = F32; ...@@ -30,7 +30,6 @@ using CShuffleDataType = F32;
using DDataType = F32; using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>; using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32; using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2; static constexpr ck::index_t NumDimN = 2;
...@@ -141,7 +140,6 @@ int main(int argc, char* argv[]) ...@@ -141,7 +140,6 @@ int main(int argc, char* argv[])
BDataType, BDataType,
ck::Tuple<DDataType>, ck::Tuple<DDataType>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>; ck::tensor_operation::element_wise::Bilinear>;
......
...@@ -30,7 +30,6 @@ using CShuffleDataType = F64; ...@@ -30,7 +30,6 @@ using CShuffleDataType = F64;
using DDataType = F64; using DDataType = F64;
using DsDataType = ck::Tuple<DDataType>; using DsDataType = ck::Tuple<DDataType>;
using EDataType = F64; using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2; static constexpr ck::index_t NumDimN = 2;
...@@ -186,7 +185,6 @@ int main(int argc, char* argv[]) ...@@ -186,7 +185,6 @@ int main(int argc, char* argv[])
BDataType, BDataType,
ck::Tuple<DDataType>, ck::Tuple<DDataType>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>; ck::tensor_operation::element_wise::Bilinear>;
......
...@@ -29,7 +29,6 @@ using AccDataType = F32; ...@@ -29,7 +29,6 @@ using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>; using DsDataType = ck::Tuple<>;
using EDataType = F32; using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2; static constexpr ck::index_t NumDimN = 2;
...@@ -128,7 +127,6 @@ int main(int argc, char* argv[]) ...@@ -128,7 +127,6 @@ int main(int argc, char* argv[])
BDataType, BDataType,
ck::Tuple<>, ck::Tuple<>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>; ck::tensor_operation::element_wise::Scale>;
......
...@@ -29,7 +29,6 @@ using AccDataType = F64; ...@@ -29,7 +29,6 @@ using AccDataType = F64;
using CShuffleDataType = F64; using CShuffleDataType = F64;
using DsDataType = ck::Tuple<>; using DsDataType = ck::Tuple<>;
using EDataType = F64; using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2; static constexpr ck::index_t NumDimN = 2;
...@@ -176,7 +175,6 @@ int main(int argc, char* argv[]) ...@@ -176,7 +175,6 @@ int main(int argc, char* argv[])
BDataType, BDataType,
ck::Tuple<>, ck::Tuple<>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>; ck::tensor_operation::element_wise::Scale>;
......
...@@ -31,10 +31,10 @@ template <index_t NumDimM, ...@@ -31,10 +31,10 @@ template <index_t NumDimM,
typename BDataType, typename BDataType,
typename DsDataType, typename DsDataType,
typename EDataType, typename EDataType,
typename ComputeDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceContractionMultipleD : public BaseOperator struct DeviceContractionMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
......
...@@ -155,10 +155,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -155,10 +155,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BDataType, BDataType,
DsDataType, DsDataType,
EDataType, EDataType,
ComputeDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation,
ComputeDataType>
{ {
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
......
...@@ -25,10 +25,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn ...@@ -25,10 +25,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -38,10 +38,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn ...@@ -38,10 +38,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -51,10 +51,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn ...@@ -51,10 +51,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -64,10 +64,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn ...@@ -64,10 +64,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -77,10 +77,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -77,10 +77,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -90,10 +90,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -90,10 +90,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -103,10 +103,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -103,10 +103,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -116,10 +116,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -116,10 +116,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -129,10 +129,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -129,10 +129,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -142,10 +142,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -142,10 +142,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -155,10 +155,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -155,10 +155,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -168,10 +168,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -168,10 +168,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
BF16>>>& instances);
#endif // CK_ENABLE_FP32 #endif // CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64 #ifdef CK_ENABLE_FP64
...@@ -183,10 +183,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn ...@@ -183,10 +183,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -196,10 +196,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn ...@@ -196,10 +196,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -209,10 +209,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn ...@@ -209,10 +209,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -222,10 +222,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn ...@@ -222,10 +222,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -235,10 +235,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp ...@@ -235,10 +235,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -248,10 +248,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp ...@@ -248,10 +248,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -261,10 +261,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp ...@@ -261,10 +261,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -274,10 +274,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp ...@@ -274,10 +274,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64, F64,
F64_Tuple, F64_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
#endif // CK_ENABLE_FP64 #endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
...@@ -289,10 +289,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -289,10 +289,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -302,10 +302,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -302,10 +302,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -315,10 +315,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -315,10 +315,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -328,10 +328,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -328,10 +328,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
#endif // CK_ENABLE_FP16 #endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
...@@ -343,10 +343,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -343,10 +343,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -356,10 +356,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -356,10 +356,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -369,10 +369,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -369,10 +369,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -382,10 +382,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -382,10 +382,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear,
F32>>>& instances);
#endif // CK_ENABLE_BF16 #endif // CK_ENABLE_BF16
// Contraction + Bilinear // Contraction + Bilinear
...@@ -405,10 +405,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -405,10 +405,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType, BDataType,
ck::Tuple<DDataType>, ck::Tuple<DDataType>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>> ck::tensor_operation::element_wise::Bilinear,
ComputeDataType>>
{ {
using DeviceOp = DeviceContractionMultipleD<NumDimM, using DeviceOp = DeviceContractionMultipleD<NumDimM,
NumDimN, NumDimN,
...@@ -417,10 +417,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -417,10 +417,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType, BDataType,
ck::Tuple<DDataType>, ck::Tuple<DDataType>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>; ck::tensor_operation::element_wise::Bilinear,
ComputeDataType>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -25,10 +25,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc ...@@ -25,10 +25,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -38,10 +38,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc ...@@ -38,10 +38,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -51,10 +51,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc ...@@ -51,10 +51,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -64,10 +64,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc ...@@ -64,10 +64,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -77,10 +77,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16 ...@@ -77,10 +77,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -90,10 +90,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16 ...@@ -90,10 +90,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -103,10 +103,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16 ...@@ -103,10 +103,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -116,10 +116,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16 ...@@ -116,10 +116,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -129,10 +129,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1 ...@@ -129,10 +129,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -142,10 +142,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1 ...@@ -142,10 +142,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -155,10 +155,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1 ...@@ -155,10 +155,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -168,10 +168,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1 ...@@ -168,10 +168,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
BF16>>>& instances);
#endif // CK_ENABLE_FP32 #endif // CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64 #ifdef CK_ENABLE_FP64
...@@ -183,10 +183,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc ...@@ -183,10 +183,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -196,10 +196,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc ...@@ -196,10 +196,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -209,10 +209,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc ...@@ -209,10 +209,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -222,10 +222,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc ...@@ -222,10 +222,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -235,10 +235,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32 ...@@ -235,10 +235,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -248,10 +248,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32 ...@@ -248,10 +248,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -261,10 +261,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32 ...@@ -261,10 +261,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -274,10 +274,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32 ...@@ -274,10 +274,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
#endif // CK_ENABLE_FP64 #endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
...@@ -289,10 +289,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32 ...@@ -289,10 +289,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -302,10 +302,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32 ...@@ -302,10 +302,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -315,10 +315,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32 ...@@ -315,10 +315,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -328,10 +328,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32 ...@@ -328,10 +328,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
#endif // CK_ENABLE_FP16 #endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
...@@ -343,10 +343,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_ ...@@ -343,10 +343,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16, BF16,
Empty_Tuple, Empty_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -356,10 +356,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_ ...@@ -356,10 +356,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16, BF16,
Empty_Tuple, Empty_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -369,10 +369,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_ ...@@ -369,10 +369,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16, BF16,
Empty_Tuple, Empty_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -382,10 +382,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_ ...@@ -382,10 +382,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16, BF16,
Empty_Tuple, Empty_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale,
F32>>>& instances);
#endif // CK_ENABLE_BF16 #endif // CK_ENABLE_BF16
// Contraction + Scale // Contraction + Scale
...@@ -404,10 +404,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -404,10 +404,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType, BDataType,
ck::Tuple<>, ck::Tuple<>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>> ck::tensor_operation::element_wise::Scale,
ComputeDataType>>
{ {
using DeviceOp = DeviceContractionMultipleD<NumDimM, using DeviceOp = DeviceContractionMultipleD<NumDimM,
NumDimN, NumDimN,
...@@ -416,10 +416,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -416,10 +416,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType, BDataType,
ck::Tuple<>, ck::Tuple<>,
EDataType, EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>; ck::tensor_operation::element_wise::Scale,
ComputeDataType>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_ ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16, BF16,
BF16_Tuple, BF16_Tuple,
BF16, BF16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16, F16,
F16_Tuple, F16_Tuple,
F16, F16,
F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
F32>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
BF16>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
BF16>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
BF16>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp ...@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32, F32,
F32_Tuple, F32_Tuple,
F32, F32,
BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances) Bilinear,
BF16>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment