// File: test_contraction_interface.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <stdexcept>
#include <vector>

#include "gtest/gtest.h"

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"

#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"

#include "ck/library/utility/device_memory.hpp"

using Pass     = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;

using F32 = float;
using F64 = double;

template <typename DataTypeA,
          typename DataTypeB,
          typename DataTypeC,
          typename DataTypeD,
26
          ck::index_t NumDim>
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class ContractionDeviceWrapper
{

    protected:
    using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
                                                                              NumDim,
                                                                              NumDim,
                                                                              DataTypeA,
                                                                              DataTypeB,
                                                                              ck::Tuple<DataTypeC>,
                                                                              DataTypeD,
                                                                              Pass,
                                                                              Pass,
                                                                              Bilinear>;

    public:
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    ContractionDeviceWrapper(std::vector<ck::index_t>& Dims, std::vector<ck::index_t>& Strides)
        : InputDims_(Dims), OutputDims_(Dims), InputStrides_(Strides), OutputStrides_(Strides)
    {
    }
    ContractionDeviceWrapper(std::vector<ck::index_t>& InDims,
                             std::vector<ck::index_t>& OutDims,
                             std::vector<ck::index_t>& InStrides,
                             std::vector<ck::index_t>& OutStrides)
        : InputDims_(InDims),
          OutputDims_(OutDims),
          InputStrides_(InStrides),
          OutputStrides_(OutStrides)
    {
    }

    std::vector<ck::index_t>& InputDims_;
    std::vector<ck::index_t>& OutputDims_;
    std::vector<ck::index_t>& InputStrides_;
    std::vector<ck::index_t>& OutputStrides_;
62
63
64
65
66
67
68
69
70
71
72
73
74
75
    bool IsSupported() const
    {

        bool supported     = false;
        const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
            DeviceOp>::GetInstances();

        for(auto& op_ptr : op_ptrs)
        {
            auto argument_ptr =
                op_ptr->MakeArgumentPointer(nullptr,
                                            nullptr,
                                            std::array<const void*, 1>{nullptr},
                                            nullptr,
76
77
78
79
80
81
82
83
                                            InputStrides_,
                                            InputStrides_,
                                            InputStrides_,
                                            InputStrides_,
                                            std::array<std::vector<ck::index_t>, 1>{InputStrides_},
                                            std::array<std::vector<ck::index_t>, 1>{InputStrides_},
                                            OutputDims_,
                                            OutputStrides_,
84
85
86
87
88
89
90
91
92
93
94
95
                                            Pass{},
                                            Pass{},
                                            Bilinear{1.f, 1.f});

            supported = supported || op_ptr->IsSupportedArgument(argument_ptr.get());
        }
        return supported;
    }
};

// Instantiates the wrapper with NumDim = 1, 2 and 3 and expects that only the
// NumDim = 2 variant finds a supporting instance. Each wrapper receives 2 * NumDim
// lengths (all 4) and 2 * NumDim strides (all 1).
TEST(TestContractionInterface, IncorrectNumDims)
{
    std::vector<std::vector<ck::index_t>> Dims    = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}};
    std::vector<std::vector<ck::index_t>> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}};
    ContractionDeviceWrapper<F32, F32, F32, F32, 1> wrapper_1d(Dims[0], Strides[0]);
    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper_2d(Dims[1], Strides[1]);
    ContractionDeviceWrapper<F32, F32, F32, F32, 3> wrapper_3d(Dims[2], Strides[2]);

    EXPECT_FALSE(wrapper_1d.IsSupported());
    EXPECT_TRUE(wrapper_2d.IsSupported());
    EXPECT_FALSE(wrapper_3d.IsSupported());
}

// Mixes float and double between the A/B types and the remaining two types and
// expects that neither combination finds a supporting instance.
TEST(TestContractionInterface, IncorrectDataTypes)
{
    std::vector<ck::index_t> Dims    = {4, 4, 4, 4};
    std::vector<ck::index_t> Strides = {64, 16, 4, 1};
    ContractionDeviceWrapper<F32, F32, F64, F64, 2> wrapper_1(Dims, Strides);
    ContractionDeviceWrapper<F64, F64, F32, F32, 2> wrapper_2(Dims, Strides);

    EXPECT_FALSE(wrapper_1.IsSupported());
    EXPECT_FALSE(wrapper_2.IsSupported());
}
115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Describes input and output tensors with different lengths and strides and
// expects that no instance accepts the combination.
// NOTE(review): judging by the test name, the rejection presumably comes from the
// gridwise GEMM validity check -- confirm against the device op implementation.
TEST(TestContractionInterface, GridwiseGemm)
{
    // Input side: lengths and their strides.
    std::vector<ck::index_t> in_lengths{1, 2, 3, 4};
    std::vector<ck::index_t> in_strides{24, 12, 4, 1};

    // Output side: reversed lengths with their own strides.
    std::vector<ck::index_t> out_lengths{4, 3, 2, 1};
    std::vector<ck::index_t> out_strides{6, 2, 1, 1};

    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(
        in_lengths, out_lengths, in_strides, out_strides);

    EXPECT_FALSE(wrapper.IsSupported());
}

// Uses 4x4x4x4 tensors whose strides are {4, 16, 64, 256} (no stride equals 1) and
// expects that no instance accepts the layout.
// NOTE(review): judging by the test name, this presumably targets the memory-access
// (vector load/store) requirements of the instances -- confirm.
TEST(TestContractionInterface, MemoryAccess)
{
    std::vector<ck::index_t> lengths{4, 4, 4, 4};
    std::vector<ck::index_t> strides{4, 16, 64, 256};

    ContractionDeviceWrapper<F32, F32, F32, F32, 2> wrapper(lengths, strides);

    EXPECT_FALSE(wrapper.IsSupported());
}