"vscode:/vscode.git/clone" did not exist on "667e52ccd0b15275c7ef1b6040bb08362293a2f5"
// File: test_contraction_interface.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <stdexcept>
#include <vector>

#include "gtest/gtest.h"

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"

#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"

#include "ck/library/utility/device_memory.hpp"

// Element-wise operation aliases used as template/constructor arguments below.
using Pass     = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;

// Shorthand names for the floating-point element types exercised by the tests.
using F32 = float;
using F64 = double;

/// Test helper that queries the device-operation instance factory for a
/// DeviceContractionMultipleD configuration (PassThrough / PassThrough /
/// Bilinear element-wise ops) and reports whether any registered instance
/// accepts a dummy argument set.
///
/// @tparam DataTypeA Forwarded as the first data-type argument of the device op
///                   (presumably the A tensor element type; confirm against the
///                   DeviceContractionMultipleD declaration).
/// @tparam DataTypeB Forwarded as the second data-type argument of the device op.
/// @tparam DataTypeC Wrapped in a single-element ck::Tuple and forwarded to the
///                   device op.
/// @tparam DataTypeD Forwarded as the data-type argument following the tuple.
/// @tparam NumDim    Used for all three dimension-count parameters of the device
///                   op; every dummy length/stride vector holds 2 * NumDim entries.
template <typename DataTypeA,
          typename DataTypeB,
          typename DataTypeC,
          typename DataTypeD,
          int NumDim>
class ContractionDeviceWrapper
{

    protected:
    using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
                                                                              NumDim,
                                                                              NumDim,
                                                                              DataTypeA,
                                                                              DataTypeB,
                                                                              ck::Tuple<DataTypeC>,
                                                                              DataTypeD,
                                                                              Pass,
                                                                              Pass,
                                                                              Bilinear>;

    public:
    /// Checks whether at least one instance returned by the factory for
    /// DeviceOp accepts the dummy argument built here.
    ///
    /// All tensor pointers are null; only lengths (all 4) and strides (all 1)
    /// are populated, so this exercises the argument-validation path only.
    ///
    /// @return true as soon as one instance reports the argument as supported,
    ///         false if the factory returns no instances or none accepts it.
    bool IsSupported() const
    {
        std::vector<ck::index_t> dummy_dims(NumDim * 2, 4);
        std::vector<ck::index_t> dummy_strides(NumDim * 2, 1);

        // Loop-invariant wrappers for the single tuple-typed tensor; built once
        // instead of copying the vectors on every iteration.
        const std::array<const void*, 1> dummy_ds_ptrs{nullptr};
        const std::array<std::vector<ck::index_t>, 1> dummy_ds_dims{dummy_dims};
        const std::array<std::vector<ck::index_t>, 1> dummy_ds_strides{dummy_strides};

        const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
            DeviceOp>::GetInstances();

        for(const auto& op_ptr : op_ptrs)
        {
            const auto argument_ptr = op_ptr->MakeArgumentPointer(nullptr,
                                                                  nullptr,
                                                                  dummy_ds_ptrs,
                                                                  nullptr,
                                                                  dummy_dims,
                                                                  dummy_strides,
                                                                  dummy_dims,
                                                                  dummy_strides,
                                                                  dummy_ds_dims,
                                                                  dummy_ds_strides,
                                                                  dummy_dims,
                                                                  dummy_strides,
                                                                  Pass{},
                                                                  Pass{},
                                                                  Bilinear{1.f, 1.f});

            // One supporting instance is enough; no need to build arguments
            // for the remaining instances.
            if(op_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                return true;
            }
        }
        return false;
    }
};

// Varies only the dimension count while keeping every data type at F32.
// The expectations below encode that only NumDim == 2 has a supporting instance.
TEST(TestContractionInterface, IncorrectNumDims)
{
    using Wrapper1D = ContractionDeviceWrapper<F32, F32, F32, F32, 1>;
    using Wrapper2D = ContractionDeviceWrapper<F32, F32, F32, F32, 2>;
    using Wrapper3D = ContractionDeviceWrapper<F32, F32, F32, F32, 3>;

    const Wrapper1D wrapper_1d{};
    const Wrapper2D wrapper_2d{};
    const Wrapper3D wrapper_3d{};

    EXPECT_FALSE(wrapper_1d.IsSupported());
    EXPECT_TRUE(wrapper_2d.IsSupported());
    EXPECT_FALSE(wrapper_3d.IsSupported());
}

// Mixes F32 and F64 between the first two and last two data-type parameters
// at NumDim == 2; both combinations are expected to have no supporting instance.
TEST(TestContractionInterface, IncorrectDataTypes)
{
    using WrapperF32InF64Out = ContractionDeviceWrapper<F32, F32, F64, F64, 2>;
    using WrapperF64InF32Out = ContractionDeviceWrapper<F64, F64, F32, F32, 2>;

    const WrapperF32InF64Out wrapper_1{};
    const WrapperF64InF32Out wrapper_2{};

    EXPECT_FALSE(wrapper_1.IsSupported());
    EXPECT_FALSE(wrapper_2.IsSupported());
}

// Placeholder for a corner-case test; not implemented yet (assertion body is empty).
// TEST(TestContractionInterface, CornerCases)
// {
//     EXPECT_FALSE()
// }