Commit a30c626b authored by Bartlomiej Wroblewski
Browse files

Make ComputeDataType an optional argument

parent b019d839
......@@ -30,7 +30,6 @@ using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -141,7 +140,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<DDataType>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
......
......@@ -30,7 +30,6 @@ using CShuffleDataType = F64;
using DDataType = F64;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -186,7 +185,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<DDataType>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
......
......@@ -29,7 +29,6 @@ using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -128,7 +127,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>;
......
......@@ -29,7 +29,6 @@ using AccDataType = F64;
using CShuffleDataType = F64;
using DsDataType = ck::Tuple<>;
using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -176,7 +175,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>;
......
......@@ -31,10 +31,10 @@ template <index_t NumDimM,
typename BDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
......
......@@ -155,10 +155,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
......
......@@ -25,10 +25,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
F32,
F32_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -38,10 +38,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
F32,
F32_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -51,10 +51,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn
F32,
F32_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -64,10 +64,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
F32,
F32_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -77,10 +77,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -90,10 +90,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -103,10 +103,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -116,10 +116,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -129,10 +129,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -142,10 +142,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -155,10 +155,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -168,10 +168,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
BF16>>>& instances);
#endif // CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64
......@@ -183,10 +183,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn
F64,
F64_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -196,10 +196,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn
F64,
F64_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -209,10 +209,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn
F64,
F64_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -222,10 +222,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
F64,
F64_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F64>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -235,10 +235,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64,
F64_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -248,10 +248,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64,
F64_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -261,10 +261,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64,
F64_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -274,10 +274,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
F64,
F64_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
......@@ -289,10 +289,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -302,10 +302,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -315,10 +315,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -328,10 +328,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
......@@ -343,10 +343,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -356,10 +356,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -369,10 +369,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -382,10 +382,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances);
Bilinear,
F32>>>& instances);
#endif // CK_ENABLE_BF16
// Contraction + Bilinear
......@@ -405,10 +405,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType,
ck::Tuple<DDataType>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>>
ck::tensor_operation::element_wise::Bilinear,
ComputeDataType>>
{
using DeviceOp = DeviceContractionMultipleD<NumDimM,
NumDimN,
......@@ -417,10 +417,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType,
ck::Tuple<DDataType>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
ck::tensor_operation::element_wise::Bilinear,
ComputeDataType>;
static auto GetInstances()
{
......
......@@ -25,10 +25,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc
F32,
Empty_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -38,10 +38,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc
F32,
Empty_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -51,10 +51,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc
F32,
Empty_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -64,10 +64,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
F32,
Empty_Tuple,
F32,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -77,10 +77,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32,
Empty_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -90,10 +90,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32,
Empty_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -103,10 +103,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32,
Empty_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -116,10 +116,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16
F32,
Empty_Tuple,
F32,
F16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -129,10 +129,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32,
Empty_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -142,10 +142,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32,
Empty_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -155,10 +155,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32,
Empty_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -168,10 +168,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf1
F32,
Empty_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
BF16>>>& instances);
#endif // CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64
......@@ -183,10 +183,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -196,10 +196,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -209,10 +209,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -222,10 +222,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F64>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -235,10 +235,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64,
Empty_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -248,10 +248,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64,
Empty_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -261,10 +261,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64,
Empty_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -274,10 +274,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
F64,
Empty_Tuple,
F64,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
......@@ -289,10 +289,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16,
Empty_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -302,10 +302,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16,
Empty_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -315,10 +315,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16,
Empty_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -328,10 +328,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
F16,
Empty_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
......@@ -343,10 +343,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16,
Empty_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -356,10 +356,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16,
Empty_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -369,10 +369,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16,
Empty_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -382,10 +382,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_
BF16,
Empty_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Scale>>>& instances);
Scale,
F32>>>& instances);
#endif // CK_ENABLE_BF16
// Contraction + Scale
......@@ -404,10 +404,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType,
ck::Tuple<>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>>
ck::tensor_operation::element_wise::Scale,
ComputeDataType>>
{
using DeviceOp = DeviceContractionMultipleD<NumDimM,
NumDimN,
......@@ -416,10 +416,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
BDataType,
ck::Tuple<>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>;
ck::tensor_operation::element_wise::Scale,
ComputeDataType>;
static auto GetInstances()
{
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment