Commit a30c626b authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Make ComputeDataType an optional argument

parent b019d839
......@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances)
Scale,
F64>>>& instances)
{
add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{});
......
......@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances)
Scale,
F64>>>& instances)
{
add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{});
......
......@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances)
Scale,
F64>>>& instances)
{
add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{});
......
......@@ -41,10 +41,10 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64,
Empty_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Scale>>>& instances)
Scale,
F64>>>& instances)
{
add_device_operation_instances(
instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{});
......
......@@ -124,10 +124,10 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType,
DTupleDataType,
DataType,
ComputeDataType,
AElementOp,
BElementOp,
CDElementOp>;
CDElementOp,
ComputeDataType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......
......@@ -75,7 +75,6 @@ template <typename DataTypeA,
typename DataTypeB,
typename DataTypeC,
typename DataTypeD,
typename DataTypeCompute,
ck::index_t NumDim>
class ContractionDeviceOpWrapper
{
......@@ -88,7 +87,6 @@ class ContractionDeviceOpWrapper
DataTypeB,
ck::Tuple<DataTypeC>,
DataTypeD,
DataTypeCompute,
Pass,
Pass,
Bilinear>;
......@@ -131,9 +129,9 @@ TEST(TestContractionInterface, IncorrectNumDims)
{
std::vector<std::vector<ck::index_t>> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
ContractionDeviceOpWrapper<F32, F32, F32, F32, F32, 1> wrapper_1d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, F32, 2> wrapper_2d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, F32, 3> wrapper_3d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;
EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0]));
EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1]));
EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2]));
......@@ -143,8 +141,8 @@ TEST(TestContractionInterface, IncorrectDataTypes)
{
std::vector<ck::index_t> Dims = {4, 4, 4, 4};
std::vector<ck::index_t> Strides = {64, 16, 4, 1};
ContractionDeviceOpWrapper<F32, F32, F64, F64, F32, 2> wrapper_1;
ContractionDeviceOpWrapper<F64, F64, F32, F32, F32, 2> wrapper_2;
ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> wrapper_1;
ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> wrapper_2;
EXPECT_FALSE(wrapper_1.IsSupportedInstance(Dims, Strides));
EXPECT_FALSE(wrapper_2.IsSupportedInstance(Dims, Strides));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment