Commit a30c626b authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Make ComputeDataType an optional argument

parent b019d839
...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc ...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{});
......
...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc ...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{});
......
...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc ...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{});
......
...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc ...@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64, F64,
Empty_Tuple, Empty_Tuple,
F64, F64,
F64,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances) Scale,
F64>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{});
......
...@@ -124,10 +124,10 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -124,10 +124,10 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType, DataType,
DTupleDataType, DTupleDataType,
DataType, DataType,
ComputeDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDElementOp>; CDElementOp,
ComputeDataType>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......
...@@ -75,7 +75,6 @@ template <typename DataTypeA, ...@@ -75,7 +75,6 @@ template <typename DataTypeA,
typename DataTypeB, typename DataTypeB,
typename DataTypeC, typename DataTypeC,
typename DataTypeD, typename DataTypeD,
typename DataTypeCompute,
ck::index_t NumDim> ck::index_t NumDim>
class ContractionDeviceOpWrapper class ContractionDeviceOpWrapper
{ {
...@@ -88,7 +87,6 @@ class ContractionDeviceOpWrapper ...@@ -88,7 +87,6 @@ class ContractionDeviceOpWrapper
DataTypeB, DataTypeB,
ck::Tuple<DataTypeC>, ck::Tuple<DataTypeC>,
DataTypeD, DataTypeD,
DataTypeCompute,
Pass, Pass,
Pass, Pass,
Bilinear>; Bilinear>;
...@@ -131,9 +129,9 @@ TEST(TestContractionInterface, IncorrectNumDims) ...@@ -131,9 +129,9 @@ TEST(TestContractionInterface, IncorrectNumDims)
{ {
std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}}; std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}}; std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
ContractionDeviceOpWrapper<F32, F32, F32, F32, F32, 1> wrapper_1d; ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, F32, 2> wrapper_2d; ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, F32, 3> wrapper_3d; ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;
EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0])); EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0]));
EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1])); EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1]));
EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2])); EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2]));
...@@ -143,8 +141,8 @@ TEST(TestContractionInterface, IncorrectDataTypes) ...@@ -143,8 +141,8 @@ TEST(TestContractionInterface, IncorrectDataTypes)
{ {
std::vector<ck::index_t> Dims = {4, 4, 4, 4}; std::vector<ck::index_t> Dims = {4, 4, 4, 4};
std::vector<ck::index_t> Strides = {64, 16, 4, 1}; std::vector<ck::index_t> Strides = {64, 16, 4, 1};
ContractionDeviceOpWrapper<F32, F32, F64, F64, F32, 2> wrapper_1; ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> wrapper_1;
ContractionDeviceOpWrapper<F64, F64, F32, F32, F32, 2> wrapper_2; ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> wrapper_2;
EXPECT_FALSE(wrapper_1.IsSupportedInstance(Dims, Strides)); EXPECT_FALSE(wrapper_1.IsSupportedInstance(Dims, Strides));
EXPECT_FALSE(wrapper_2.IsSupportedInstance(Dims, Strides)); EXPECT_FALSE(wrapper_2.IsSupportedInstance(Dims, Strides));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment