// test_contraction_interface.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <stdexcept>
#include <vector>

#include "gtest/gtest.h"

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"

#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"

#include "ck/library/utility/device_memory.hpp"

// Short names for the element-wise operations passed as template and runtime
// arguments to the contraction device operations in this file.
using Pass     = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;

// Shorthand for ck::Sequence, a type parameterized by a pack of ck::index_t values.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Scalar data type aliases used to select tensor element types.
using F32 = float;
using F64 = double;

template <ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t CDEBlockTransferScalarPerVector>
class ContractionInstanceWrapper
{
    public:
    static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
    static constexpr ck::index_t NumDim = 2;
    // clang-format off
    using ContractionDeviceInstance = ck::tensor_operation::device::
        //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle|         DsData| EData|            A|           B|          CDE|           GEMM| NumGemmK| Block|  MPer|  NPer|  KPer| AK1| BK1| MPer| NPer| MXdl| NXdl|  ABlockTransfer| ABlockTransfer| ABlockTransfer|             ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|  BBlockTransfer| BBlockTransfer| BBlockTransfer|              BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|    CShuffle|    CShuffle| CBlockTransferClusterLengths|                  CBlockTransfer|
        //#####################################|        |        |        |  Type|  Type|    Type| DataType|           Type|  Type|  Elementwise| Elementwise|  Elementwise| Spacialization| Prefetch|  Size| Block| Block| Block|    |    |  XDL|  XDL|  Per|  Per|   ThreadCluster|  ThreadCluster| SrcAccessOrder|               SrcVectorDim|      SrcScalar|      DstScalar| AddExtraM|   ThreadCluster|  ThreadCluster| SrcAccessOrder|               SrcVectorDim|      SrcScalar|      DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave|         _MBlock_MWaveMPerXdl|                 ScalarPerVector|
        //#####################################|        |        |        |      |      |        |         |               |      |    Operation|   Operation|    Operation|               |    Stage|      |      |      |      |    |    |     |     | Wave| Wave| Lengths_K0_M_K1|   ArrangeOrder|               |                           |      PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|               |                           |      PerVector|   PerVector_K1|          |  PerShuffle|  PerShuffle|         _NBlock_NWaveNPerXdl|                   _NWaveNPerXdl|
        //#####################################|        |        |        |      |      |        |         |               |      |             |            |             |               |         |      |      |      |      |    |    |     |     |     |     |                |               |               |                           |               |               |          |                |               |               |                           |               |               |          |            |            |                             |                                |
        DeviceContractionMultipleD_Xdl_CShuffle<  NumDim,  NumDim,  NumDim,   F32,   F32,     F32,      F32, ck::Tuple<F32>,   F32,         Pass,        Pass,     Bilinear,       GemmSpec,        1,   256,   256,   128,    16,   4,   4,   32,   32,    4,    2,     S<4, 64, 1>,     S<1, 0, 2>,     S<1, 0, 2>, ABlockTransferSrcVectorDim,              4,              4,         1,     S<4, 64, 1>,     S<1, 0, 2>,     S<1, 0, 2>, BBlockTransferSrcVectorDim,              4,              4,         1,           1,           1,              S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
    // clang-format on

    bool isSupported(std::vector<ck::index_t>& ADims,
                     std::vector<ck::index_t>& BDims,
                     std::vector<ck::index_t>& DDims,
                     std::vector<ck::index_t>& EDims,
                     std::vector<ck::index_t>& AStrides,
                     std::vector<ck::index_t>& BStrides,
                     std::vector<ck::index_t>& DStrides,
                     std::vector<ck::index_t>& EStrides) const
    {
        auto contraction = ContractionDeviceInstance{};

        auto argument = contraction.MakeArgument(nullptr,
                                                 nullptr,
                                                 std::array<const void*, 1>{nullptr},
                                                 nullptr,
                                                 ADims,
                                                 AStrides,
                                                 BDims,
                                                 BStrides,
                                                 std::array<std::vector<ck::index_t>, 1>{DDims},
                                                 std::array<std::vector<ck::index_t>, 1>{DStrides},
                                                 EDims,
                                                 EStrides,
                                                 Pass{},
                                                 Pass{},
                                                 Bilinear{1.f, 1.f});
        return contraction.IsSupportedArgument(argument);
    }
};

template <typename DataTypeA,
          typename DataTypeB,
          typename DataTypeC,
          typename DataTypeD,
          ck::index_t NumDim>
class ContractionDeviceOpWrapper
{

    protected:
    using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
                                                                              NumDim,
                                                                              NumDim,
                                                                              DataTypeA,
                                                                              DataTypeB,
                                                                              ck::Tuple<DataTypeC>,
                                                                              DataTypeD,
                                                                              Pass,
                                                                              Pass,
                                                                              Bilinear>;

    public:
    bool IsSupportedInstance(std::vector<ck::index_t>& Dims,
                             std::vector<ck::index_t>& Strides) const
    {

        bool supported     = false;
        const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
            DeviceOp>::GetInstances();

        for(auto& op_ptr : op_ptrs)
        {
            auto argument_ptr =
                op_ptr->MakeArgumentPointer(nullptr,
                                            nullptr,
                                            std::array<const void*, 1>{nullptr},
                                            nullptr,
                                            Dims,
                                            Strides,
                                            Dims,
                                            Strides,
                                            std::array<std::vector<ck::index_t>, 1>{Dims},
                                            std::array<std::vector<ck::index_t>, 1>{Strides},
                                            Dims,
                                            Strides,
                                            Pass{},
                                            Pass{},
                                            Bilinear{1.f, 1.f});

            supported = supported || op_ptr->IsSupportedArgument(argument_ptr.get());
        }
        return supported;
    }
};

// Checks which dimension counts the instance library accepts: of NumDim = 1, 2
// and 3, only NumDim = 2 is expected to have a supporting instance.
TEST(TestContractionInterface, IncorrectNumDims)
{
    // Each lengths/strides vector holds 2 * NumDim entries for its wrapper.
    std::vector<ck::index_t> lengths_1d = {4, 4};
    std::vector<ck::index_t> lengths_2d = {4, 4, 4, 4};
    std::vector<ck::index_t> lengths_3d = {4, 4, 4, 4, 4, 4};

    std::vector<ck::index_t> strides_1d = {1, 1};
    std::vector<ck::index_t> strides_2d = {1, 1, 1, 1};
    std::vector<ck::index_t> strides_3d = {1, 1, 1, 1, 1, 1};

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> one_dim_op;
    EXPECT_FALSE(one_dim_op.IsSupportedInstance(lengths_1d, strides_1d));

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> two_dim_op;
    EXPECT_TRUE(two_dim_op.IsSupportedInstance(lengths_2d, strides_2d));

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> three_dim_op;
    EXPECT_FALSE(three_dim_op.IsSupportedInstance(lengths_3d, strides_3d));
}

// Checks that mixing F32 and F64 between the first two and the last two data
// type arguments yields no supporting instance, in either direction.
TEST(TestContractionInterface, IncorrectDataTypes)
{
    std::vector<ck::index_t> lengths        = {4, 4, 4, 4};
    std::vector<ck::index_t> packed_strides = {64, 16, 4, 1};

    // First two types F32, last two F64.
    ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> f32_then_f64;
    EXPECT_FALSE(f32_then_f64.IsSupportedInstance(lengths, packed_strides));

    // First two types F64, last two F32.
    ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> f64_then_f32;
    EXPECT_FALSE(f64_then_f32.IsSupportedInstance(lengths, packed_strides));
}

// Checks the stride requirements on A and B for source vector dimension 1 and 2.
// A layout without any unit stride must be rejected by both variants; a layout
// whose unit stride sits on index 1 is expected to work with vector dimension 1,
// and one whose unit stride sits on index 3 with vector dimension 2.
TEST(TestContractionSupportedArgs, ABMemoryAccess)
{
    std::vector<ck::index_t> lengths        = {4, 4, 4, 4};
    std::vector<ck::index_t> packed         = {64, 16, 4, 1};
    std::vector<ck::index_t> unit_at_index1 = {4, 1, 64, 16};
    std::vector<ck::index_t> unit_at_index3 = {64, 16, 4, 1};
    std::vector<ck::index_t> no_unit_stride = {4, 4, 4, 4};

    // Vary the strides of A only; B, D and E keep the packed layout.
    {
        ContractionInstanceWrapper<1, 2, 4> a_vector_dim1;
        ContractionInstanceWrapper<2, 2, 4> a_vector_dim2;

        EXPECT_FALSE(a_vector_dim1.isSupported(
            lengths, lengths, lengths, lengths, no_unit_stride, packed, packed, packed));
        EXPECT_FALSE(a_vector_dim2.isSupported(
            lengths, lengths, lengths, lengths, no_unit_stride, packed, packed, packed));
        EXPECT_TRUE(a_vector_dim1.isSupported(
            lengths, lengths, lengths, lengths, unit_at_index1, packed, packed, packed));
        EXPECT_TRUE(a_vector_dim2.isSupported(
            lengths, lengths, lengths, lengths, unit_at_index3, packed, packed, packed));
    }

    // Vary the strides of B only; A, D and E keep the packed layout.
    {
        ContractionInstanceWrapper<2, 1, 4> b_vector_dim1;
        ContractionInstanceWrapper<2, 2, 4> b_vector_dim2;

        EXPECT_FALSE(b_vector_dim1.isSupported(
            lengths, lengths, lengths, lengths, packed, no_unit_stride, packed, packed));
        EXPECT_FALSE(b_vector_dim2.isSupported(
            lengths, lengths, lengths, lengths, packed, no_unit_stride, packed, packed));
        EXPECT_TRUE(b_vector_dim1.isSupported(
            lengths, lengths, lengths, lengths, packed, unit_at_index1, packed, packed));
        EXPECT_TRUE(b_vector_dim2.isSupported(
            lengths, lengths, lengths, lengths, packed, unit_at_index3, packed, packed));
    }
}

// Checks the stride requirements on D and E: with a scalar-per-vector of 4, a
// layout whose last stride is not 1 must be rejected, while the packed layout
// is accepted.
TEST(TestContractionSupportedArgs, DEMemoryAccess)
{
    std::vector<ck::index_t> lengths            = {4, 4, 4, 4};
    std::vector<ck::index_t> packed             = {64, 16, 4, 1};
    std::vector<ck::index_t> last_two_swapped   = {64, 16, 1, 4};
    ContractionInstanceWrapper<2, 2, 4> instance;

    // D tensor: swapped strides rejected, packed strides accepted.
    EXPECT_FALSE(instance.isSupported(
        lengths, lengths, lengths, lengths, packed, packed, last_two_swapped, packed));
    EXPECT_TRUE(
        instance.isSupported(lengths, lengths, lengths, lengths, packed, packed, packed, packed));

    // E tensor: swapped strides rejected, packed strides accepted.
    EXPECT_FALSE(instance.isSupported(
        lengths, lengths, lengths, lengths, packed, packed, packed, last_two_swapped));
    EXPECT_TRUE(
        instance.isSupported(lengths, lengths, lengths, lengths, packed, packed, packed, packed));
}