"test/vscode:/vscode.git/clone" did not exist on "1bdd010291e4878ad5768392ec6a7cf8acd79be9"
reduce_blockwise.cpp 12.4 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
5
6
7
8
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
9

Chao Liu's avatar
Chao Liu committed
10
11
12
13
14
15
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"

#include "ck/library/utility/check_err.hpp"
16
17
18
19
20
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_reduction.hpp"
21
22
23
24

using namespace ck;
using namespace ck::tensor_operation::device;

25
26
using InDataType  = ck::half_t;
using OutDataType = ck::half_t;
27
28
using AccDataType = float;

Qianfeng's avatar
Qianfeng committed
29
30
constexpr int Rank         = 4;
constexpr int NumReduceDim = 3;
31

32
constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
33
34
constexpr bool PropagateNan         = true;
constexpr bool OutputIndex          = false;
35

36
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
37
using InElementwiseOperation =
38
    typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
39
using AccElementwiseOperation =
40
    typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
41

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
using DeviceReduceInstance = DeviceReduceMultiBlock<InDataType,
                                                    AccDataType,
                                                    OutDataType,
                                                    Rank,
                                                    NumReduceDim,
                                                    ReduceOperation,
                                                    InElementwiseOperation,
                                                    AccElementwiseOperation,
                                                    InMemoryDataOperationEnum::Set,
                                                    PropagateNan,
                                                    OutputIndex,
                                                    false, // HaveIndexInputIfOutputIndex
                                                    256,
                                                    4,
                                                    64,
                                                    1,
                                                    1,
                                                    0,
                                                    1,
                                                    1>;
62
63
64
65
66
67
68
69
70
71
72
73

// Long-option table handed to getopt_long(); each entry maps a long name to
// the short-option character that the parsing switch dispatches on.
static struct option long_options[] = {
    // --inLengths <list>  -> 'D'
    {"inLengths", required_argument, nullptr, 'D'},
    // --verify <0|1>      -> 'v'
    {"verify", required_argument, nullptr, 'v'},
    // --help              -> '?'
    {"help", no_argument, nullptr, '?'},
    // all-zero terminator required by getopt_long()
    {nullptr, 0, nullptr, 0},
};

class SimpleAppArgs
{
    private:
    int option_index = 0;

    public:
74
75
    std::vector<size_t> inLengths = {16, 64, 32, 960};
    std::vector<float> scales     = {1.0f, 0.0f};
76

JD's avatar
JD committed
77
78
    bool do_verification = true;
    int init_method      = 1;
79
    bool time_kernel     = true;
80
81
82
83
84
85
86
87
88
89

    public:
    void show_usage(const char* cmd)
    {
        std::cout << "Usage of " << cmd << std::endl;
        std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
                  << std::endl;
        std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
                     "comparing with the host-based reduction"
                  << std::endl;
90
91
92
        std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
                     "value, 3=decimal value)"
                  << std::endl;
93
        std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
94
95
96
97
    };

    int processArgs(int argc, char* argv[])
    {
98
99
        using ck::host_common::getTypeValuesFromString;

100
        int ch;
101
102
103

        while(1)
        {
104
            ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
            if(ch == -1)
                break;
            switch(ch)
            {
            case 'D':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                inLengths = getTypeValuesFromString<size_t>(optarg);
                break;
            case 'v':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                do_verification = static_cast<bool>(std::atoi(optarg));
                break;
            case '?':
                if(std::string(long_options[option_index].name) == "help")
                {
                    show_usage(argv[0]);
                    return (-1);
                };
                break;
            default: show_usage(argv[0]); return (-1);
            };
        };

        if(optind + 2 > argc)
            throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");

        init_method = std::atoi(argv[optind++]);
136
        time_kernel = static_cast<bool>(std::atoi(argv[optind]));
137
138
139
140
141
142
143
144
145
146
147
148
149

        if(scales.empty())
        {
            scales.push_back(1.0f);
            scales.push_back(0.0f);
        };

        return (0);
    };
};

int main(int argc, char* argv[])
{
Qianfeng's avatar
Qianfeng committed
150
151
152
    const std::vector<int> reduceDims{0, 1, 2};
    const std::vector<int> invariantDims{3};

153
154
    SimpleAppArgs args;

155
156
157
158
159
    if(argc > 1)
    {
        if(args.processArgs(argc, argv) < 0)
            return (-1);
    };
160
161

    constexpr bool op_support_indices =
162
163
        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
         ReduceOpId == ReduceTensorOp::AMAX);
164
165
166
167
168
169
170
171
172
173
174
175
176
177

    // if input is half type, no reason to use float for indiced reduction operation and must use
    // float for non-indiced reduction operation for accuracy
    constexpr bool invalid_reduce_1 =
        std::is_same<InDataType, ck::half_t>::value &&
        ((!op_support_indices && !std::is_same<AccDataType, float>::value) ||
         (op_support_indices && !std::is_same<AccDataType, ck::half_t>::value));

    // if input is float type, no reason to use double for indiced reduction operation
    constexpr bool invalid_reduce_2 =
        std::is_same<InDataType, float>::value &&
        (op_support_indices && !std::is_same<AccDataType, float>::value);

    // indices option can only be used when it is really needed
178
    constexpr bool invalid_reduce_3 = (!op_support_indices && OutputIndex);
179
180
181
182
183
184
185
186
187
188

    constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3);

    if constexpr(invalid_reduce)
        std::cout << "Reduction setting is not supported, exiting!" << std::endl;

    Tensor<InDataType> in(args.inLengths);

    std::vector<size_t> outLengths;

Qianfeng's avatar
Qianfeng committed
189
    if(invariantDims.empty())
190
191
        outLengths.push_back(1);
    else
Qianfeng's avatar
Qianfeng committed
192
        for(auto dim : invariantDims)
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
            outLengths.push_back(args.inLengths[dim]);

    Tensor<OutDataType> out_ref(outLengths);
    Tensor<OutDataType> out(outLengths);
    Tensor<int> out_indices_ref(outLengths);
    Tensor<int> out_indices(outLengths);

    auto inStrides  = in.mDesc.GetStrides();
    auto outStrides = out.mDesc.GetStrides();

    size_t invariant_total_length = out.mDesc.GetElementSize();
    size_t reduce_total_length    = in.mDesc.GetElementSize() / invariant_total_length;

    float alpha = args.scales[0];
    float beta  = args.scales[1];

209
    std::size_t num_thread = 1;
210
211
212
213
214

    if(args.do_verification)
    {
        switch(args.init_method)
        {
215
216
217
        case 0: break;
        case 1:
            in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
218
            if(beta != 0.0f)
219
                out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
220
            break;
221
        case 2:
222
223
224
225
226
            in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
            if(beta != 0.0f)
                out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
            break;
        default:
227
            in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
228
            if(beta != 0.0f)
229
                out_ref.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
230
231
232
        }

        if(beta != 0.0f)
233
            for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
234
235
236
237
                out.mData[i] = out_ref.mData[i];
    };

    // these buffers are usually provided by the user application
238
239
    DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
    DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
240
241
242
243
244
245

    in_dev.ToDevice(in.mData.data());

    if(beta != 0.0f)
        out_dev.ToDevice(out.mData.data());

246
    size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
247

248
    DeviceMem out_index_dev(indicesSizeInBytes);
249

250
251
252
253
254
255
256
    InElementwiseOperation in_elementwise_op;
    AccElementwiseOperation acc_elementwise_op;

    std::tie(in_elementwise_op, acc_elementwise_op) =
        reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
            static_cast<int32_t>(reduce_total_length));

257
258
    if(args.do_verification)
    {
259
260
261
        ReductionHost<InDataType,
                      AccDataType,
                      OutDataType,
262
263
264
                      ReduceOperation,
                      InElementwiseOperation,
                      AccElementwiseOperation,
265
266
267
                      Rank,
                      NumReduceDim,
                      PropagateNan,
268
                      OutputIndex>
Qianfeng's avatar
Qianfeng committed
269
            hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
270

271
272
273
274
275
276
277
        hostReduce.Run(alpha,
                       in.mData.data(),
                       beta,
                       out_ref.mData.data(),
                       out_indices_ref.mData.data(),
                       in_elementwise_op,
                       acc_elementwise_op);
278
279
    };

280
281
282
283
284
285
286
287
288
    std::vector<ck::index_t> i_inLengths;
    std::vector<ck::index_t> i_inStrides;
    std::vector<ck::index_t> i_outLengths;
    std::vector<ck::index_t> i_outStrides;

    i_inLengths.assign(args.inLengths.begin(), args.inLengths.end());
    i_inStrides.assign(inStrides.begin(), inStrides.end());
    i_outLengths.assign(outLengths.begin(), outLengths.end());
    i_outStrides.assign(outStrides.begin(), outStrides.end());
289
290
291

    auto reduce = DeviceReduceInstance{};

292
293
294
295
296
297
298
299
300
301
302
303
304
    auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths,
                                                   i_inStrides,
                                                   i_outLengths,
                                                   i_outStrides,
                                                   reduceDims,
                                                   alpha,
                                                   beta,
                                                   in_dev.GetDeviceBuffer(),
                                                   nullptr,
                                                   out_dev.GetDeviceBuffer(),
                                                   out_index_dev.GetDeviceBuffer(),
                                                   in_elementwise_op,
                                                   acc_elementwise_op);
305
306
307
308
309
310
311
312
313
314
315
316

    if(!reduce.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout
            << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
            << std::endl;
    };

    std::string reduce_name = reduce.GetTypeString();

    auto invoker_ptr = reduce.MakeInvokerPointer();

JD's avatar
JD committed
317
    float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});
318
319
320
321
322
323
324
325
326

    std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) +
                            invariant_total_length * sizeof(OutDataType);

    float gb_per_sec = num_bytes / 1.E6 / avg_time;

    std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
              << std::endl;

Anthony Chang's avatar
Anthony Chang committed
327
    bool pass = true;
328

329
330
331
    if(args.do_verification)
    {
        out_dev.FromDevice(out.mData.data());
332
        pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
333

334
        if(OutputIndex)
335
        {
336
337
            out_index_dev.FromDevice(out_indices.mData.data());
            pass = pass && ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
338
339
        };
    };
340
341

    return (pass ? 0 : 1);
342
}