profile_reduce_impl.hpp 20.2 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#pragma once
5

Chao Liu's avatar
Chao Liu committed
6
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
7
#include "ck/utility/reduction_enums.hpp"
Chao Liu's avatar
Chao Liu committed
8
9

#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp"
10
11
12

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
13
14
15
16
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_reduction.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
17
18
19
20

namespace ck {
namespace tensor_operation {
namespace device {
21
namespace instance {
22

23
/// Compile-time description of one pre-instantiated reduction configuration.
///
/// @tparam Rank          rank of the input tensor
/// @tparam NumReduceDim  number of dimensions being reduced
/// @tparam ReduceOpId    numeric id of the reduction operation (matches ReduceTensorOp)
/// @tparam PropagateNan  whether NaN values are propagated by the operation
/// @tparam UseIndex      whether the reduction also outputs element indices
template <int Rank, int NumReduceDim, int ReduceOpId, bool PropagateNan, bool UseIndex>
struct ReduceDescription
{
    static constexpr int Rank_         = Rank;
    static constexpr int NumReduceDim_ = NumReduceDim;
    static constexpr int ReduceOpId_   = ReduceOpId;

    // Keep the boolean template parameters as bool (they were previously stored as
    // int); all users either compare them or static_cast them, so this is
    // backward-compatible while preserving the declared type.
    static constexpr bool PropagateNan_ = PropagateNan;
    static constexpr bool UseIndex_     = UseIndex;
};

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// The full list of (Rank, NumReduceDim, ReduceOpId, PropagateNan, UseIndex) configurations
// for which device reduce instances have been pre-compiled. profile_reduce_impl() dispatches
// a runtime request against this list. Per the groupings below, the numeric ReduceOpId maps
// onto ReduceTensorOp as: 0=ADD, 2=MIN, 3=MAX, 4=AMAX, 5=AVG, 7=NORM2. The last three groups
// repeat MIN/MAX/AMAX with UseIndex=true (index-producing variants).
using reduce_description_instances =
    std::tuple<ReduceDescription<4, 3, 0, false, false>, // for ADD
               ReduceDescription<4, 4, 0, false, false>,
               ReduceDescription<4, 1, 0, false, false>,
               ReduceDescription<2, 1, 0, false, false>,

               ReduceDescription<4, 3, 5, false, false>, // for AVG
               ReduceDescription<4, 4, 5, false, false>,
               ReduceDescription<4, 1, 5, false, false>,
               ReduceDescription<2, 1, 5, false, false>,

               ReduceDescription<4, 3, 7, false, false>, // for NORM2
               ReduceDescription<4, 4, 7, false, false>,
               ReduceDescription<4, 1, 7, false, false>,
               ReduceDescription<2, 1, 7, false, false>,

               ReduceDescription<4, 3, 2, false, false>, // for MIN
               ReduceDescription<4, 4, 2, false, false>,
               ReduceDescription<4, 1, 2, false, false>,
               ReduceDescription<2, 1, 2, false, false>,
               ReduceDescription<4, 3, 3, false, false>, // for MAX
               ReduceDescription<4, 4, 3, false, false>,
               ReduceDescription<4, 1, 3, false, false>,
               ReduceDescription<2, 1, 3, false, false>,
               ReduceDescription<4, 3, 4, false, false>, // for AMAX
               ReduceDescription<4, 4, 4, false, false>,
               ReduceDescription<4, 1, 4, false, false>,
               ReduceDescription<2, 1, 4, false, false>,

               ReduceDescription<4, 3, 2, false, true>, // for MIN
               ReduceDescription<4, 4, 2, false, true>,
               ReduceDescription<4, 1, 2, false, true>,
               ReduceDescription<2, 1, 2, false, true>,
               ReduceDescription<4, 3, 3, false, true>, // for MAX
               ReduceDescription<4, 4, 3, false, true>,
               ReduceDescription<4, 1, 3, false, true>,
               ReduceDescription<2, 1, 3, false, true>,
               ReduceDescription<4, 3, 4, false, true>, // for AMAX
               ReduceDescription<4, 4, 4, false, true>,
               ReduceDescription<4, 1, 4, false, true>,
               ReduceDescription<2, 1, 4, false, true>>;
74
75
76
77

template <typename DescriptionType>
bool description_match(const DescriptionType& description,
                       int Rank,
Qianfeng's avatar
Qianfeng committed
78
                       const std::vector<int>& reduceDims,
79
                       ReduceTensorOp ReduceOpId,
80
81
                       bool PropagateNan,
                       bool UseIndex)
82
83
{
    if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast<int>(ReduceOpId) ||
84
85
       description.PropagateNan_ != static_cast<int>(PropagateNan) ||
       description.UseIndex_ != static_cast<int>(UseIndex))
86
87
        return (false);

Qianfeng's avatar
Qianfeng committed
88
    if(DescriptionType::NumReduceDim_ != reduceDims.size())
89
90
91
92
93
94
95
        return (false);

    bool result = true;

    return (result);
};

96
} // namespace instance
97
98
99
100
101
102
103
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

Qianfeng's avatar
Qianfeng committed
104
105
template <index_t Rank, index_t NumReduceDim>
static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
106
{
Qianfeng's avatar
Qianfeng committed
107
    assert(NumReduceDim == reduceDims.size());
108

Qianfeng's avatar
Qianfeng committed
109
    int reduceFlag = 0;
110

Qianfeng's avatar
Qianfeng committed
111
112
    // flag the bits for the reduceDims
    for(int i = 0; i < NumReduceDim; i++)
113
    {
Qianfeng's avatar
Qianfeng committed
114
        reduceFlag |= 1 << reduceDims[i];
115
116
    };

Qianfeng's avatar
Qianfeng committed
117
118
119
120
121
122
123
124
125
126
    std::vector<int> invariantDims;

    // collect invariant dimensions
    for(int i = 0; i < Rank; i++)
        if((reduceFlag & (1 << i)) == 0)
        {
            invariantDims.push_back(i);
        };

    return invariantDims;
127
128
129
130
131
132
};

/// Profile one fully-specified reduction configuration: instantiate all matching device
/// reduce kernels, run each supported one, report bandwidth, and (optionally) verify the
/// device result against a host reference and/or dump buffers to files.
///
/// @param do_verification  when true, run a host-side reduction and compare results
/// @param init_method      0: leave buffers as-is, 1: constant 1, 2: random ints in [-5,5],
///                         otherwise random floats in [-5.0, 5.0]
/// @param do_dumpout       when true, dump input/output (and index) buffers to .bin files
/// @param time_kernel      when true, time the kernels and print per-instance/best perf
/// @param inLengths        lengths of the input tensor (size must equal Rank)
/// @param reduceDims       dimensions to reduce (size must equal NumReduceDim)
/// @param alpha, beta      scaling factors forwarded to the device/host reduction
///                         (NOTE(review): presumably out = alpha*reduce(in) + beta*out —
///                         beta != 0 causes the prior output contents to be uploaded)
/// @return true when all verified instances matched the host reference (or trivially true
///         when verification is off); also true for an invalid type/op combination (a
///         message is printed instead)
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          int NumReduceDim,
          ReduceTensorOp ReduceOpId,
          bool PropagateNan,
          bool UseIndex>
bool profile_reduce_impl_impl(bool do_verification,
                              int init_method,
                              bool do_dumpout,
                              bool time_kernel,
                              const std::vector<size_t>& inLengths,
                              const std::vector<int>& reduceDims,
                              float alpha,
                              float beta)
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;
    using ck::host_common::dumpBufferToFile;

    // Only MIN/MAX/AMAX can produce element indices.
    constexpr bool op_support_indices =
        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
         ReduceOpId == ReduceTensorOp::AMAX);

    constexpr bool OutputIndex = (op_support_indices && UseIndex);

    // Multiblock atomic-add kernels are usable only for float output and for
    // operations that are neither indexable nor NORM2.
    constexpr bool out_support_atomic_add = std::is_same<OutDataType, float>::value;
    constexpr bool op_support_atomic_add =
        !op_support_indices && ReduceOpId != ReduceTensorOp::NORM2;
    constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add);

    // 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations
    // 2) If InDataType is half_t, must use float as AccDataType for non-indexable reduction
    // operations
    constexpr bool invalid_reduce_1 =
        std::is_same<InDataType, half_t>::value &&
        ((!op_support_indices && !std::is_same<AccDataType, float>::value) ||
         (op_support_indices && !std::is_same<AccDataType, half_t>::value));

    // 1) If InDataType is float, must use float as AccDataType for indexable reduction operations
    constexpr bool invalid_reduce_2 =
        std::is_same<InDataType, float>::value &&
        (op_support_indices && !std::is_same<AccDataType, float>::value);

    // 1) The indices can only be used when the reduction operation is indexable
    constexpr bool invalid_reduce_3 = (!op_support_indices && UseIndex);

    // 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
    // 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
    // operations
    constexpr bool invalid_reduce_4 =
        std::is_same<InDataType, int8_t>::value &&
        ((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
         (op_support_indices && !std::is_same<AccDataType, int8_t>::value));

    // 1) If InDataType is int8_t, the supported operation must be either indexable operations or
    // ADD/AVG
    constexpr bool invalid_reduce_5 = std::is_same<InDataType, int8_t>::value &&
                                      (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
                                       ReduceOpId != ReduceTensorOp::AVG);

    // 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
    constexpr bool invalid_reduce_6 =
        std::is_same<InDataType, bhalf_t>::value && !std::is_same<AccDataType, float>::value;

    constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 ||
                                     invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6);

    bool pass = true;

    if constexpr(!invalid_reduce)
    {
        Tensor<InDataType> in(inLengths);

        std::vector<size_t> outLengths;

        const auto invariantDims = get_invariant_dims<Rank, NumReduceDim>(reduceDims);

        // When every dimension is reduced the output degenerates to a single element.
        if(reduceDims.size() == Rank)
            outLengths.push_back(1);
        else
            for(auto dim : invariantDims)
                outLengths.push_back(inLengths[dim]);

        Tensor<OutDataType> out_ref(outLengths);
        Tensor<OutDataType> out(outLengths);
        Tensor<int32_t> out_indices_ref(outLengths);
        Tensor<int32_t> out_indices(outLengths);

        auto inStrides  = in.GetStrides();
        auto outStrides = out.GetStrides();

        // Each output element is the reduction of reduce_total_length input elements.
        size_t invariant_total_length = out.GetElementSize();
        size_t reduce_total_length    = in.GetElementSize() / invariant_total_length;

        std::size_t num_thread = 1;

        if(do_verification)
        {
            // Initialize input (and, when beta participates, the prior output contents).
            switch(init_method)
            {
            case 0: break;
            case 1:
                in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
                if(beta != 0.0f)
                    out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
                break;
            case 2:
                in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
                if(beta != 0.0f)
                    out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
                break;
            default:
                in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
                if(beta != 0.0f)
                    out_ref.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0},
                                                num_thread);
            }

            // Device and host must start from the same prior output when beta != 0.
            if(beta != 0.0f)
            {
                ck::ranges::copy(out_ref, out.begin());
            }
        };

        // these buffers are usually provided by the user application
        DeviceMem in_dev(in.GetMemorySize());
        DeviceMem out_dev(out.GetMemorySize());

        in_dev.ToDevice(in.data());

        if(beta != 0.0f)
            out_dev.ToDevice(out.data());

        // Index buffer is only allocated when indices are actually produced.
        size_t indicesSizeInBytes = OutputIndex ? out.GetElementSize() * sizeof(int) : 0;

        DeviceMem out_indices_dev(indicesSizeInBytes);

        float best_avg_time   = 0;
        float best_gb_per_sec = 0;

        using InElementwiseOperation =
            typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
        using AccElementwiseOperation =
            typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;

        using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;

        InElementwiseOperation in_elementwise_op;
        AccElementwiseOperation acc_elementwise_op;

        // Some elementwise ops (e.g. averaging) need the reduce length at runtime.
        std::tie(in_elementwise_op, acc_elementwise_op) =
            reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
                static_cast<int32_t>(reduce_total_length));

        using DeviceReduceInstPtr0 =
            DeviceReducePtr<InElementwiseOperation, AccElementwiseOperation>;

        std::vector<DeviceReduceInstPtr0> reduce0_ptrs;

        // Collect every kernel variant applicable to this configuration.
        add_device_reduce_instance_threadwise<InDataType,
                                              AccDataType,
                                              OutDataType,
                                              Rank,
                                              NumReduceDim,
                                              ReduceOpId,
                                              PropagateNan,
                                              UseIndex>(reduce0_ptrs);

        add_device_reduce_instance_blockwise<InDataType,
                                             AccDataType,
                                             OutDataType,
                                             Rank,
                                             NumReduceDim,
                                             ReduceOpId,
                                             PropagateNan,
                                             UseIndex>(reduce0_ptrs);

        if constexpr(use_atomic_add)
        {
            add_device_reduce_instance_multiblock_atomic_add<InDataType,
                                                             AccDataType,
                                                             OutDataType,
                                                             Rank,
                                                             NumReduceDim,
                                                             ReduceOpId,
                                                             PropagateNan,
                                                             UseIndex>(reduce0_ptrs);
        }

        if(reduce0_ptrs.empty())
        {
            throw std::runtime_error("Wrong! No device REDUCE instance found");
        };

        if(do_verification)
        {
            // Compute the reference result (and reference indices) on the host.
            ReductionHost<InDataType,
                          AccDataType,
                          OutDataType,
                          ReduceOperation,
                          InElementwiseOperation,
                          AccElementwiseOperation,
                          Rank,
                          NumReduceDim,
                          PropagateNan,
                          OutputIndex>
                hostReduce(in.GetDesc(), out_ref.GetDesc(), invariantDims, reduceDims);

            hostReduce.Run(alpha,
                           in.data(),
                           beta,
                           out_ref.data(),
                           out_indices_ref.data(),
                           in_elementwise_op,
                           acc_elementwise_op);
        };

        // Device API expects ck::index_t length/stride vectors; convert from size_t.
        std::vector<ck::index_t> i_inLengths;
        std::vector<ck::index_t> i_inStrides;
        std::vector<ck::index_t> i_outLengths;
        std::vector<ck::index_t> i_outStrides;

        i_inLengths.assign(inLengths.begin(), inLengths.end());
        i_inStrides.assign(inStrides.begin(), inStrides.end());
        i_outLengths.assign(outLengths.begin(), outLengths.end());
        i_outStrides.assign(outStrides.begin(), outStrides.end());

        for(auto& reduce_ptr : reduce0_ptrs)
        {
            auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths,
                                                                i_inStrides,
                                                                i_outLengths,
                                                                i_outStrides,
                                                                reduceDims,
                                                                alpha,
                                                                beta,
                                                                in_dev.GetDeviceBuffer(),
                                                                nullptr,
                                                                out_dev.GetDeviceBuffer(),
                                                                out_indices_dev.GetDeviceBuffer(),
                                                                in_elementwise_op,
                                                                acc_elementwise_op);

            // Skip instances that cannot handle this problem shape.
            if(!reduce_ptr->IsSupportedArgument(argument_ptr.get()))
                continue;

            std::string reduce_name = reduce_ptr->GetTypeString();

            auto invoker_ptr = reduce_ptr->MakeInvokerPointer();

            float avg_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

            // Effective bandwidth: all input read once plus all output written once.
            std::size_t num_bytes =
                invariant_total_length * reduce_total_length * sizeof(InDataType) +
                invariant_total_length * sizeof(OutDataType);

            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            if(time_kernel)
                std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, "
                          << reduce_name << std::endl;

            if(gb_per_sec > best_gb_per_sec)
            {
                best_avg_time   = avg_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                bool single_pass;

                out_dev.FromDevice(out.data());
                single_pass = ck::utils::check_err(out, out_ref);

                if(OutputIndex)
                {
                    out_indices_dev.FromDevice(out_indices.data());
                    single_pass = single_pass && ck::utils::check_err(out_indices, out_indices_ref);
                };

                if(!single_pass)
                {
                    std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl;
                }

                pass = pass && single_pass;
            };

            if(do_dumpout)
            {
                dumpBufferToFile("dump_in.bin", in.data(), in.GetElementSize());
                dumpBufferToFile("dump_out.bin", out.data(), out.GetElementSize());
                dumpBufferToFile("dump_out_host.bin", out_ref.data(), out_ref.GetElementSize());
                if(OutputIndex)
                {
                    dumpBufferToFile(
                        "dump_indices.bin", out_indices.data(), out_indices.GetElementSize());
                    dumpBufferToFile("dump_indices_host.bin",
                                     out_indices_ref.data(),
                                     out_indices_ref.GetElementSize());
                };
            };
        };

        if(time_kernel)
            std::cout << "Best Perf: " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s"
                      << std::endl;
    }
    else
    {
        std::cout << "The requested reduction operation is not supported, please check !!!"
                  << std::endl;
    };

    return pass;
};

/// Entry point of the reduce profiler: translate a runtime reduction request into one of
/// the pre-compiled compile-time configurations in reduce_description_instances and run
/// profile_reduce_impl_impl() for the FIRST matching description only.
///
/// @param inLengths   input tensor lengths (its size is the rank matched against the list)
/// @param reduceDims  dimensions to reduce
/// @param ReduceOpId  requested reduction operation
/// @param alpha, beta scaling factors forwarded to the underlying implementation
/// @return the pass/fail result of the dispatched configuration; NOTE(review): returns
///         true (and runs nothing) when no description matches the request
template <typename InDataType, typename AccDataType, typename OutDataType>
bool profile_reduce_impl(bool do_verification,
                         int init_method,
                         bool do_dumpout,
                         bool time_kernel,
                         const std::vector<size_t>& inLengths,
                         const std::vector<int>& reduceDims,
                         ReduceTensorOp ReduceOpId,
                         bool PropagateNan,
                         bool UseIndex,
                         float alpha,
                         float beta)
{
    bool matched = false;
    bool pass    = true;

    using tuple_of_description_instances =
        tensor_operation::device::instance::reduce_description_instances;

    const auto tuple_object = tuple_of_description_instances{};

    // Compile-time loop over every description; at runtime only the first match is used.
    static_for<0, std::tuple_size<tuple_of_description_instances>::value, 1>{}([&](auto i) {
        if(matched)
            return;

        using descType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;

        if(!description_match(
               descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex))
            return;

        // Re-materialize the description's fields as template arguments of the
        // fully-specified implementation.
        pass = pass &&
               profile_reduce_impl_impl<InDataType,
                                        AccDataType,
                                        OutDataType,
                                        descType::Rank_,
                                        descType::NumReduceDim_,
                                        static_cast<ReduceTensorOp>(descType::ReduceOpId_),
                                        static_cast<bool>(descType::PropagateNan_),
                                        static_cast<bool>(descType::UseIndex_)>(do_verification,
                                                                                init_method,
                                                                                do_dumpout,
                                                                                time_kernel,
                                                                                inLengths,
                                                                                reduceDims,
                                                                                alpha,
                                                                                beta);

        matched = true;
    });

    return pass;
};

} // namespace profiler
} // namespace ck