// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"

#include "ck/library/tensor_operation_instance/gpu/reduce/reduce.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp" // Tensor<> host-side container
#include "ck/library/utility/host_tensor_generator.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

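// Compile-time description of one reduction configuration supported by the profiler; the
// static members mirror the template parameters of profile_reduce_impl_impl() below.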
template <index_t Rank,
          index_t NumReduceDim,
          ReduceTensorOp ReduceOpId,
          bool PropagateNan,
          bool UseIndex>
struct ReduceDescription
{
    static constexpr index_t Rank_              = Rank;
    static constexpr index_t NumReduceDim_      = NumReduceDim;
    static constexpr ReduceTensorOp ReduceOpId_ = ReduceOpId;
    static constexpr bool PropagateNan_         = PropagateNan;
    static constexpr bool UseIndex_             = UseIndex;
};
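
// Every (Rank, NumReduceDim, operation, PropagateNan, UseIndex) combination the profiler can
// dispatch to; profile_reduce_impl() matches the runtime arguments against this list.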
using reduce_description_instances =
    std::tuple<ReduceDescription<4, 3, ReduceTensorOp::ADD, false, false>, // for ADD
               ReduceDescription<4, 4, ReduceTensorOp::ADD, false, false>,
               ReduceDescription<4, 1, ReduceTensorOp::ADD, false, false>,
               ReduceDescription<2, 1, ReduceTensorOp::ADD, false, false>,

               ReduceDescription<4, 3, ReduceTensorOp::AVG, false, false>, // for AVG
               ReduceDescription<4, 4, ReduceTensorOp::AVG, false, false>,
               ReduceDescription<4, 1, ReduceTensorOp::AVG, false, false>,
               ReduceDescription<2, 1, ReduceTensorOp::AVG, false, false>,

               ReduceDescription<4, 3, ReduceTensorOp::NORM2, false, false>, // for NORM2
               ReduceDescription<4, 4, ReduceTensorOp::NORM2, false, false>,
               ReduceDescription<4, 1, ReduceTensorOp::NORM2, false, false>,
               ReduceDescription<2, 1, ReduceTensorOp::NORM2, false, false>,

               ReduceDescription<4, 3, ReduceTensorOp::MIN, false, false>, // for MIN
               ReduceDescription<4, 4, ReduceTensorOp::MIN, false, false>,
               ReduceDescription<4, 1, ReduceTensorOp::MIN, false, false>,
               ReduceDescription<2, 1, ReduceTensorOp::MIN, false, false>,
               ReduceDescription<4, 3, ReduceTensorOp::MAX, false, false>, // for MAX
               ReduceDescription<4, 4, ReduceTensorOp::MAX, false, false>,
               ReduceDescription<4, 1, ReduceTensorOp::MAX, false, false>,
               ReduceDescription<2, 1, ReduceTensorOp::MAX, false, false>,
               ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, false>, // for AMAX
               ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, false>,
               ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, false>,
               ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, false>,

               ReduceDescription<4, 3, ReduceTensorOp::MIN, false, true>, // for MIN
               ReduceDescription<4, 4, ReduceTensorOp::MIN, false, true>,
               ReduceDescription<4, 1, ReduceTensorOp::MIN, false, true>,
               ReduceDescription<2, 1, ReduceTensorOp::MIN, false, true>,
               ReduceDescription<4, 3, ReduceTensorOp::MAX, false, true>, // for MAX
               ReduceDescription<4, 4, ReduceTensorOp::MAX, false, true>,
               ReduceDescription<4, 1, ReduceTensorOp::MAX, false, true>,
               ReduceDescription<2, 1, ReduceTensorOp::MAX, false, true>,
               ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, true>, // for AMAX
               ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, true>,
               ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, true>,
               ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, true>>;

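// Returns true when the compile-time configuration carried by DescriptionType matches the
// reduction requested at runtime (rank, reduce-dim count, operation, NaN and index options).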
template <typename DescriptionType>
bool description_match(const DescriptionType& description,
                       int Rank,
                       const std::vector<int>& reduceDims,
                       ReduceTensorOp ReduceOpId,
                       bool PropagateNan,
                       bool UseIndex)
{
    if(description.Rank_ != Rank || description.ReduceOpId_ != ReduceOpId ||
       description.PropagateNan_ != PropagateNan || description.UseIndex_ != UseIndex)
        return (false);

    if(DescriptionType::NumReduceDim_ != reduceDims.size())
        return (false);

    return (true);
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

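// Derives the indices of the non-reduced (invariant) dimensions from the list of reduced ones.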
template <int Rank, int NumReduceDim>
static inline std::array<int, Rank - NumReduceDim>
get_invariant_dims(const std::array<int, NumReduceDim>& reduceDims)
{
    int reduceFlag = 0;

    // flag the bits for the reduceDims
    for(int i = 0; i < NumReduceDim; i++)
    {
        reduceFlag |= 1 << reduceDims[i];
    };

    std::array<int, Rank - NumReduceDim> invariantDims;

    // collect invariant dimensions
    int dim = 0;
    for(int i = 0; i < Rank; i++)
        if((reduceFlag & (1 << i)) == 0)
        {
            invariantDims[dim] = i;
            dim++;
        };

    return invariantDims;
};

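// Runs every pre-built DeviceReduce instance that matches one compile-time configuration on the
// given problem, optionally verifying against the host reference and dumping the buffers.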
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          int NumReduceDim,
          ReduceTensorOp ReduceOpId,
          bool PropagateNan,
          bool UseIndex>
bool profile_reduce_impl_impl(bool do_verification,
                              int init_method,
                              bool do_dumpout,
                              bool time_kernel,
                              const std::vector<size_t>& inLengths,
                              const std::array<int, NumReduceDim>& reduceDims,
                              float alpha,
                              float beta)
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;
    using ck::host_common::dumpBufferToFile;

    constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
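
    // MIN/MAX/AMAX are the only operations that can also return the index of the selected element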
    constexpr bool op_support_indices =
        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
         ReduceOpId == ReduceTensorOp::AMAX);

    constexpr bool OutputIndex = (op_support_indices && UseIndex);

    // 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations
    // 2) If InDataType is half_t, must use float as AccDataType for non-indexable reduction
    // operations
    constexpr bool invalid_reduce_1 =
        std::is_same<InDataType, half_t>::value &&
        ((!op_support_indices && !std::is_same<AccDataType, float>::value) ||
         (op_support_indices && !std::is_same<AccDataType, half_t>::value));

    // 1) If InDataType is float, must use float as AccDataType for indexable reduction operations
    constexpr bool invalid_reduce_2 =
        std::is_same<InDataType, float>::value &&
        (op_support_indices && !std::is_same<AccDataType, float>::value);

    // 1) The indices can only be used when the reduction operation is indexable
    constexpr bool invalid_reduce_3 = (!op_support_indices && UseIndex);

    // 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
    // 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
    // operations
    constexpr bool invalid_reduce_4 =
        std::is_same<InDataType, int8_t>::value &&
        ((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
         (op_support_indices && !std::is_same<AccDataType, int8_t>::value));

    // 1) If InDataType is int8_t, the supported operation must be either indexable operations or
    // ADD/AVG
    constexpr bool invalid_reduce_5 = std::is_same<InDataType, int8_t>::value &&
                                      (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
                                       ReduceOpId != ReduceTensorOp::AVG);

    // 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
    constexpr bool invalid_reduce_6 =
        std::is_same<InDataType, bhalf_t>::value && !std::is_same<AccDataType, float>::value;

    constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 ||
                                     invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6);

    int num_kernel = 0;
    bool pass      = true;

    if constexpr(!invalid_reduce)
    {
        Tensor<InDataType> in(inLengths);

        std::vector<size_t> outLengths;

        const auto invariantDims = get_invariant_dims<Rank, NumReduceDim>(reduceDims);

        if(reduceDims.size() == Rank)
            outLengths.push_back(1);
        else
            for(auto dim : invariantDims)
                outLengths.push_back(inLengths[dim]);

        Tensor<OutDataType> out_ref(outLengths);
        Tensor<OutDataType> out(outLengths);
        Tensor<int32_t> out_indices_ref(outLengths);
        Tensor<int32_t> out_indices(outLengths);

        auto inStrides  = in.mDesc.GetStrides();
        auto outStrides = out.mDesc.GetStrides();

        size_t invariant_total_length = out.mDesc.GetElementSize();
        size_t reduce_total_length    = in.mDesc.GetElementSize() / invariant_total_length;

        std::size_t num_thread = 1;

        if(do_verification)
        {
            switch(init_method)
            {
            case 0: break;
            case 1:
                in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
                if(beta != 0.0f)
                    out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
                break;
            case 2:
                in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
                if(beta != 0.0f)
                    out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
                break;
            default:
                in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
                if(beta != 0.0f)
                    out_ref.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0},
                                                num_thread);
            }

            if(beta != 0.0f)
                for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
                    out.mData[i] = out_ref.mData[i];
        };

        // these buffers are usually provided by the user application
        DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
        DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());

        in_dev.ToDevice(in.mData.data());

        if(beta != 0.0f)
            out_dev.ToDevice(out.mData.data());

        size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int) : 0;

        DeviceMem out_indices_dev(indicesSizeInBytes);

        float best_avg_time   = 0;
        float best_gb_per_sec = 0;

        using InElementwiseOperation =
            typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
        using AccElementwiseOperation =
            typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;

        using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;

        InElementwiseOperation in_elementwise_op;
        AccElementwiseOperation acc_elementwise_op;

        std::tie(in_elementwise_op, acc_elementwise_op) =
            reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
                static_cast<int32_t>(reduce_total_length));
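
        // Query the instance factory for all pre-built DeviceReduce kernels that match this
        // data-type / rank / operation configuration.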
        using ReduceOp = ck::tensor_operation::device::DeviceReduce<InDataType,
                                                                    AccDataType,
                                                                    OutDataType,
                                                                    Rank,
                                                                    NumReduceDim,
                                                                    ReduceOperation,
                                                                    InElementwiseOperation,
                                                                    AccElementwiseOperation,
                                                                    PropagateNan,
                                                                    OutputIndex>;
        const auto reduce_ptrs =
            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
                ReduceOp>::GetInstances();

        if(reduce_ptrs.empty())
        {
            throw std::runtime_error("Wrong! No device REDUCE instance found");
        };

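        // The device and reference APIs take fixed-length std::array arguments, so copy the
        // runtime-sized length/stride vectors into arrays of the expected rank.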
        std::array<index_t, Rank> arrInLengths;
        std::array<index_t, Rank> arrInStrides;
        std::array<index_t, NumOutDim> arrOutLengths;
        std::array<index_t, NumOutDim> arrOutStrides;

        ck::ranges::copy(inLengths, arrInLengths.begin());
        ck::ranges::copy(inStrides, arrInStrides.begin());
        ck::ranges::copy(outLengths, arrOutLengths.begin());
        ck::ranges::copy(outStrides, arrOutStrides.begin());

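        // Compute the reference result (and, for indexable operations, the reference indices)
        // on the host so the device outputs can be checked below.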
        if(do_verification)
        {
            using ReferenceReduceInstance =
                ck::tensor_operation::host::ReferenceReduce<InDataType,
                                                            AccDataType,
                                                            OutDataType,
                                                            Rank,
                                                            NumReduceDim,
                                                            ReduceOperation,
                                                            InElementwiseOperation,
                                                            AccElementwiseOperation,
                                                            PropagateNan,
                                                            OutputIndex>;

            auto reduce_ref = ReferenceReduceInstance{};

            auto argument_ptr_ref = reduce_ref.MakeArgumentPointer(arrInLengths,
                                                                   arrInStrides,
                                                                   arrOutLengths,
                                                                   arrOutStrides,
                                                                   reduceDims,
                                                                   static_cast<double>(alpha),
                                                                   static_cast<double>(beta),
                                                                   in.mData.data(),
                                                                   nullptr,
                                                                   out_ref.mData.data(),
                                                                   out_indices_ref.mData.data(),
                                                                   in_elementwise_op,
                                                                   acc_elementwise_op);

            if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get()))
            {
                std::cout
                    << "The runtime parameters are not supported by the reduce reference, exiting!"
                    << std::endl;
                return (false);
            };

            auto invoker_ptr_ref = reduce_ref.MakeInvokerPointer();

            (void)invoker_ptr_ref->Run(argument_ptr_ref.get());
        };

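        // Profile every applicable instance: build its argument, run it (optionally timed),
        // report bandwidth, and verify/dump the results when requested.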
        for(auto& reduce_ptr : reduce_ptrs)
        {
            auto argument_ptr = reduce_ptr->MakeArgumentPointer(arrInLengths,
                                                                arrInStrides,
                                                                arrOutLengths,
                                                                arrOutStrides,
                                                                reduceDims,
                                                                static_cast<double>(alpha),
                                                                static_cast<double>(beta),
                                                                in_dev.GetDeviceBuffer(),
                                                                nullptr,
                                                                out_dev.GetDeviceBuffer(),
                                                                out_indices_dev.GetDeviceBuffer(),
                                                                in_elementwise_op,
                                                                acc_elementwise_op);

            if(!reduce_ptr->IsSupportedArgument(argument_ptr.get()))
                continue;
            else
                num_kernel++;

            std::string reduce_name = reduce_ptr->GetTypeString();

            auto invoker_ptr = reduce_ptr->MakeInvokerPointer();

            float avg_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

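            // Bandwidth model: the whole input is read once and the reduced output written once.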
            std::size_t num_bytes =
                invariant_total_length * reduce_total_length * sizeof(InDataType) +
                invariant_total_length * sizeof(OutDataType);

            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            if(time_kernel)
                std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, "
                          << reduce_name << std::endl;

            if(gb_per_sec > best_gb_per_sec)
            {
                best_avg_time   = avg_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                bool single_pass;

                out_dev.FromDevice(out.mData.data());
                single_pass = ck::utils::check_err(out, out_ref);

                if(OutputIndex)
                {
                    out_indices_dev.FromDevice(out_indices.mData.data());
                    single_pass = single_pass && ck::utils::check_err(out_indices, out_indices_ref);
                };

                if(!single_pass)
                {
                    std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl;
                }

                pass = pass && single_pass;
            };

            if(do_dumpout)
            {
                dumpBufferToFile("dump_in.bin", in.mData.data(), in.mDesc.GetElementSize());
                dumpBufferToFile("dump_out.bin", out.mData.data(), out.mDesc.GetElementSize());
                dumpBufferToFile(
                    "dump_out_host.bin", out_ref.mData.data(), out_ref.mDesc.GetElementSize());
                if(OutputIndex)
                {
                    dumpBufferToFile("dump_indices.bin",
                                     out_indices.mData.data(),
                                     out_indices.mDesc.GetElementSize());
                    dumpBufferToFile("dump_indices_host.bin",
                                     out_indices_ref.mData.data(),
                                     out_indices_ref.mDesc.GetElementSize());
                };
            };
        };

        if(time_kernel && num_kernel > 0)
            std::cout << "Best Perf: " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s"
                      << std::endl;
    }
    else
    {
        throw std::runtime_error(
            "The requested reduction operation is not supported, please check!");
    };

    if(num_kernel == 0)
    {
        std::cout << "Error: No kernel is applicable" << std::endl;
        return false;
    };

    return pass;
};

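// Dispatches a runtime reduction request to the matching compile-time configuration above.
// Example call (hypothetical argument values, shown only to illustrate the parameter order):
//
//   ck::profiler::profile_reduce_impl<ck::half_t, float, ck::half_t>(
//       true,                      // do_verification
//       2,                         // init_method
//       false,                     // do_dumpout
//       true,                      // time_kernel
//       {64, 4, 280, 82},          // inLengths
//       {0, 1, 2},                 // reduceDims
//       ck::ReduceTensorOp::ADD,   // ReduceOpId
//       false,                     // PropagateNan
//       false,                     // UseIndex
//       1.0f,                      // alpha
//       0.0f);                     // beta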
template <typename InDataType, typename AccDataType, typename OutDataType>
bool profile_reduce_impl(bool do_verification,
                         int init_method,
                         bool do_dumpout,
                         bool time_kernel,
                         const std::vector<size_t>& inLengths,
                         const std::vector<int>& reduceDims,
                         ReduceTensorOp ReduceOpId,
                         bool PropagateNan,
                         bool UseIndex,
                         float alpha,
                         float beta)
{
    bool matched = false;
    bool pass    = true;

    using tuple_of_description_instances =
        tensor_operation::device::instance::reduce_description_instances;

    const auto tuple_object = tuple_of_description_instances{};

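    // Walk the description list at compile time and forward to the first entry whose static
    // configuration matches the runtime arguments.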
    static_for<0, std::tuple_size<tuple_of_description_instances>::value, 1>{}([&](auto i) {
        if(matched)
            return;

        using descType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;

        if(!description_match(
               descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex))
            return;

        std::array<ck::index_t, descType::NumReduceDim_> arrReduceDims;

        ck::ranges::copy(reduceDims, arrReduceDims.begin());

        pass = pass && profile_reduce_impl_impl<InDataType,
                                                AccDataType,
                                                OutDataType,
                                                descType::Rank_,
                                                descType::NumReduceDim_,
                                                static_cast<ReduceTensorOp>(descType::ReduceOpId_),
                                                descType::PropagateNan_,
                                                descType::UseIndex_>(do_verification,
                                                                     init_method,
                                                                     do_dumpout,
                                                                     time_kernel,
                                                                     inLengths,
                                                                     arrReduceDims,
                                                                     alpha,
                                                                     beta);

        matched = true;
    });

    return pass;
};

} // namespace profiler
} // namespace ck