profile_convnd_bwd_data_impl.hpp 18.3 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#pragma once
Chao Liu's avatar
Chao Liu committed
5
6
7

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
8
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
Chao Liu's avatar
Chao Liu committed
9
10
11
12
13
14
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
15
16
17
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/utility/ranges.hpp"
18
19
20

// Short element-type aliases used to select the get_device_conv_bwd_data_op_ptr
// specializations in this header.
// NOTE(review): these aliases are declared in the global namespace of a header, so they
// are visible to every file that includes it.
using F32  = float;
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using INT8 = int8_t;
Chao Liu's avatar
Chao Liu committed
23

24
25
26
namespace ck {
namespace tensor_operation {
namespace device {
27
namespace instance {
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

using DeviceConvBwdDataNoOpPtr =
    DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
                         ck::tensor_operation::element_wise::PassThrough,
                         ck::tensor_operation::element_wise::PassThrough>;
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);

void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
    std::vector<DeviceConvBwdDataNoOpPtr>&);
59
} // namespace instance
60
61
62
63
64
65
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {
66
// Make the PassThrough backward-data instance pointer type nameable without qualification
// inside ck::profiler.
using DeviceConvBwdDataNoOpPtr =
    ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;

/// Build the host tensor descriptor for the convolution input tensor.
///
/// @tparam InLayout        Layout tag forwarded (default-constructed) to
///                         ck::utils::conv::get_host_tensor_descriptor.
/// @param dims             Logical dimensions of the tensor.
/// @param num_dim_spatial  Number of spatial dimensions; only 1, 2 and 3 are accepted.
/// @return The descriptor produced by ck::utils::conv::get_host_tensor_descriptor.
/// @throws std::runtime_error if num_dim_spatial is outside [1, 3].
template <typename InLayout>
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
                                                      int num_dim_spatial = 2)
{
    // The layout tag alone selects the descriptor; the spatial rank is only validated.
    if(num_dim_spatial < 1 || num_dim_spatial > 3)
    {
        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
    }
    return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
}
/// Build the host tensor descriptor for the convolution filter (weight) tensor.
///
/// @tparam WeiLayout       Layout tag forwarded (default-constructed) to
///                         ck::utils::conv::get_host_tensor_descriptor.
/// @param dims             Logical dimensions of the tensor.
/// @param num_dim_spatial  Number of spatial dimensions; only 1, 2 and 3 are accepted.
/// @return The descriptor produced by ck::utils::conv::get_host_tensor_descriptor.
/// @throws std::runtime_error if num_dim_spatial is outside [1, 3].
template <typename WeiLayout>
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
                                                        int num_dim_spatial = 2)
{
    // The layout tag alone selects the descriptor; the spatial rank is only validated.
    if(num_dim_spatial < 1 || num_dim_spatial > 3)
    {
        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
    }
    return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
}
/// Build the host tensor descriptor for the convolution output tensor.
///
/// The function name is missing a 't' ("ensor"); it is kept as-is because callers refer
/// to it by this name.
///
/// @tparam OutLayout       Layout tag forwarded (default-constructed) to
///                         ck::utils::conv::get_host_tensor_descriptor.
/// @param dims             Logical dimensions of the tensor.
/// @param num_dim_spatial  Number of spatial dimensions; only 1, 2 and 3 are accepted.
/// @return The descriptor produced by ck::utils::conv::get_host_tensor_descriptor.
/// @throws std::runtime_error if num_dim_spatial is outside [1, 3].
template <typename OutLayout>
HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::size_t>& dims,
                                                      int num_dim_spatial = 2)
{
    // The layout tag alone selects the descriptor; the spatial rank is only validated.
    if(num_dim_spatial < 1 || num_dim_spatial > 3)
    {
        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
    }
    return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
}
/// Fallback for data-type combinations that have no explicit specialization in this
/// header: prints a diagnostic to stdout and terminates the process with exit status 1.
/// The parameters are unnamed because they only drive overload/specialization selection.
template <typename InDataType, typename WeiDataType, typename OutDataType>
void get_device_conv_bwd_data_op_ptr(InDataType,
                                     WeiDataType,
                                     OutDataType,
                                     std::vector<DeviceConvBwdDataNoOpPtr>&,
                                     int)
{
    constexpr const char* unsupported_msg = "can not find device conv bwd data";
    std::cout << unsupported_msg << std::endl;
    exit(1);
}
/// F32 specialization: append the fp32 XDL backward-data instances for the requested
/// spatial rank (1, 2 or 3) to `conv_ptrs`. Any other rank leaves `conv_ptrs` unchanged.
///
/// Declared `inline`: an explicit specialization of a function template is an ordinary
/// function, so defining it in a header without `inline` violates the one-definition rule
/// as soon as the header is included by more than one translation unit.
template <>
inline void get_device_conv_bwd_data_op_ptr(
    F32, F32, F32, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
    namespace inst = ck::tensor_operation::device::instance;

    switch(num_dim_spatial)
    {
    case 1: inst::add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); break;
    case 2: inst::add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); break;
    case 3:
        inst::add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
        break;
    default: break;
    }
}
/// F16 specialization: append the fp16 XDL backward-data instances for the requested
/// spatial rank (1, 2 or 3) to `conv_ptrs`. Any other rank leaves `conv_ptrs` unchanged.
///
/// Declared `inline`: an explicit specialization of a function template is an ordinary
/// function, so defining it in a header without `inline` violates the one-definition rule
/// as soon as the header is included by more than one translation unit.
template <>
inline void get_device_conv_bwd_data_op_ptr(
    F16, F16, F16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
    namespace inst = ck::tensor_operation::device::instance;

    switch(num_dim_spatial)
    {
    case 1: inst::add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); break;
    case 2: inst::add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); break;
    case 3:
        inst::add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
        break;
    default: break;
    }
}
/// BF16 specialization: append the bf16 XDL backward-data instances for the requested
/// spatial rank (1, 2 or 3) to `conv_ptrs`. Any other rank leaves `conv_ptrs` unchanged.
///
/// Declared `inline`: an explicit specialization of a function template is an ordinary
/// function, so defining it in a header without `inline` violates the one-definition rule
/// as soon as the header is included by more than one translation unit.
template <>
inline void get_device_conv_bwd_data_op_ptr(
    BF16, BF16, BF16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
    namespace inst = ck::tensor_operation::device::instance;

    switch(num_dim_spatial)
    {
    case 1: inst::add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); break;
    case 2:
        inst::add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
        break;
    case 3:
        inst::add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
        break;
    default: break;
    }
}
/// INT8 specialization: append the int8 XDL backward-data instances for the requested
/// spatial rank (1, 2 or 3) to `conv_ptrs`. Any other rank leaves `conv_ptrs` unchanged.
///
/// Declared `inline`: an explicit specialization of a function template is an ordinary
/// function, so defining it in a header without `inline` violates the one-definition rule
/// as soon as the header is included by more than one translation unit.
template <>
inline void get_device_conv_bwd_data_op_ptr(
    INT8, INT8, INT8, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
{
    namespace inst = ck::tensor_operation::device::instance;

    switch(num_dim_spatial)
    {
    case 1: inst::add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); break;
    case 2:
        inst::add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
        break;
    case 3:
        inst::add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
        break;
    default: break;
    }
}

template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
    float max_diff = 1e-6;

231
    for(std::size_t i = 0; i < ref.mData.size(); ++i)
232
233
234
235
236
237
238
239
240
241
242
243
244
    {
        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        if(max_diff < diff)
        {
            return false;
        }
    }
    return true;
}
/// Print a tensor to stdout as nested brackets.
///
/// For a rank-4 tensor whose logical dimension order is (N, C, H, W), the elements are
/// walked in N, H, W, C order, i.e. printed as NHWC. Any other rank, e.g. a 1-D or 3-D
/// convolution tensor, has no NHWC interpretation, and the 4-index accessor would
/// overrun its lengths. Such tensors are printed flat, in storage order.
///
/// Values are converted with ck::type_convert<float> so that storage-only types such as
/// ck::bhalf_t print their numeric value rather than their raw bits.
template <typename DataType>
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
{
    const auto& lengths = nhwc.GetLengths();

    if(lengths.size() != 4)
    {
        std::cout << "[";
        for(const auto& value : nhwc.mData)
        {
            std::cout << ck::type_convert<float>(value) << "  ";
        }
        std::cout << "]";
        return;
    }

    const std::size_t n_len = lengths[0];
    const std::size_t c_len = lengths[1];
    const std::size_t h_len = lengths[2];
    const std::size_t w_len = lengths[3];

    std::cout << "[";
    for(std::size_t n = 0; n < n_len; n++)
    {
        std::cout << "[";
        for(std::size_t hi = 0; hi < h_len; hi++)
        {
            std::cout << "[";
            for(std::size_t wi = 0; wi < w_len; wi++)
            {
                std::cout << "[";
                for(std::size_t c = 0; c < c_len; c++)
                {
                    std::cout << ck::type_convert<float>(nhwc(n, c, hi, wi)) << "  ";
                }
                std::cout << "]";
            }
            std::cout << "]";
        }
        std::cout << "]";
    }
    std::cout << "]";
}

template <int NDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename AccDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
bool profile_convnd_bwd_data_impl(int do_verification,
                                  int init_method,
                                  bool do_log,
JD's avatar
JD committed
278
                                  bool time_kernel,
279
280
281
                                  ck::index_t N,
                                  ck::index_t K,
                                  ck::index_t C,
ltqin's avatar
ltqin committed
282
283
284
285
286
287
288
                                  const std::vector<ck::index_t>& input_spatial_lengths,
                                  const std::vector<ck::index_t>& filter_spatial_lengths,
                                  const std::vector<ck::index_t>& output_spatial_lengths,
                                  const std::vector<ck::index_t>& conv_filter_strides,
                                  const std::vector<ck::index_t>& conv_filter_dilations,
                                  const std::vector<ck::index_t>& input_left_pads,
                                  const std::vector<ck::index_t>& input_right_pads)
289
290
291
292
293
294
295
296
297
{
    using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
    using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
    using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto in_element_op  = InElementOp{};
    const auto wei_element_op = WeiElementOp{};
    const auto out_element_op = OutElementOp{};

298
    auto input_dims = ck::ranges::to<std::vector<std::size_t>>({N, C});
299
300
301
    input_dims.insert(
        std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths));

302
    auto filter_dims = ck::ranges::to<std::vector<std::size_t>>({K, C});
303
304
305
306
    filter_dims.insert(std::end(filter_dims),
                       std::begin(filter_spatial_lengths),
                       std::end(filter_spatial_lengths));

307
    auto output_dims = ck::ranges::to<std::vector<std::size_t>>({N, K});
308
309
310
311
    output_dims.insert(std::end(output_dims),
                       std::begin(output_spatial_lengths),
                       std::end(output_spatial_lengths));

ltqin's avatar
ltqin committed
312
    Tensor<InDataType> input_host_result(
313
        get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
ltqin's avatar
ltqin committed
314
    Tensor<InDataType> input_device_result(
315
        get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
ltqin's avatar
ltqin committed
316
    Tensor<WeiDataType> weights(
317
        get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
ltqin's avatar
ltqin committed
318
    Tensor<OutDataType> output(
319
320
        get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));

321
322
323
    std::cout << "input: " << input_host_result.GetDesc() << std::endl;
    std::cout << "weights: " << weights.GetDesc() << std::endl;
    std::cout << "output: " << output.GetDesc() << std::endl;
324
325
326
327
328

    switch(init_method)
    {
    case 0: break;
    case 1:
ltqin's avatar
ltqin committed
329
330
        output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
        weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
331
332
        break;
    default:
ltqin's avatar
ltqin committed
333
334
        output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
        weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
335
336
    }

337
338
339
    DeviceMem in_device_buf(input_device_result.GetMemorySize());
    DeviceMem wei_device_buf(weights.GetMemorySize());
    DeviceMem out_device_buf(output.GetMemorySize());
340

341
342
    out_device_buf.ToDevice(output.data());
    wei_device_buf.ToDevice(weights.data());
343
344

    // reset input to zero
ltqin's avatar
ltqin committed
345
    in_device_buf.SetZero();
346
347
348
349
350
351

    if(do_verification)
    {
        auto RunReference = [&](auto& ref_conv) {
            auto ref_invoker = ref_conv.MakeInvoker();

ltqin's avatar
ltqin committed
352
353
354
            auto ref_argument = ref_conv.MakeArgument(input_host_result,
                                                      weights,
                                                      output,
355
356
357
358
359
360
361
362
363
                                                      conv_filter_strides,
                                                      conv_filter_dilations,
                                                      input_left_pads,
                                                      input_right_pads,
                                                      InElementOp{},
                                                      WeiElementOp{},
                                                      OutElementOp{});
            ref_invoker.Run(ref_argument);
        };
ltqin's avatar
ltqin committed
364
365
366
367
368
369
370
371
372
373

        auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
                                                                         WeiDataType,
                                                                         OutDataType,
                                                                         AccDataType,
                                                                         InElementOp,
                                                                         WeiElementOp,
                                                                         OutElementOp,
                                                                         NDimSpatial>();
        RunReference(ref_conv);
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
    }

    // add device Conv instances
    std::vector<DeviceConvBwdDataNoOpPtr> conv_ptrs;
    get_device_conv_bwd_data_op_ptr(
        InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial);

    if(conv_ptrs.size() <= 0)
    {
        throw std::runtime_error("wrong! no device Conv instance found");
    }

    std::string best_conv_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device Conv instances
    bool success = true;
    for(auto& conv_ptr : conv_ptrs)
    {
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
        auto argument_ptr = conv_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
                                                          wei_device_buf.GetDeviceBuffer(),
                                                          out_device_buf.GetDeviceBuffer(),
                                                          N,
                                                          K,
                                                          C,
                                                          input_spatial_lengths,
                                                          filter_spatial_lengths,
                                                          output_spatial_lengths,
                                                          conv_filter_strides,
                                                          conv_filter_dilations,
                                                          input_left_pads,
                                                          input_right_pads,
                                                          in_element_op,
                                                          wei_element_op,
                                                          out_element_op);
411
412
413
414
415
416
417

        auto invoker_ptr = conv_ptr->MakeInvokerPointer();

        if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            std::string conv_name = conv_ptr->GetTypeString();

JD's avatar
JD committed
418
419
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
420
421

            std::size_t flop =
422
423
424
425
                ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
            std::size_t num_btype =
                ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
                    N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442

            float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s" << std::endl;

            if(tflops > best_tflops)
            {
                best_conv_name  = conv_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
443
                in_device_buf.FromDevice(input_device_result.data());
444

ltqin's avatar
ltqin committed
445
                if(!check_out(input_host_result, input_device_result))
446
447
448
449
450
451
452
453
454
455
                {
                    std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;

                    success = false;
                }
                else
                {
                    std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
                }

456
                success = ck::utils::check_err(input_host_result, input_device_result);
457
458
459
460

                if(do_log)
                {
                    std::cout << "in : ";
ltqin's avatar
ltqin committed
461
                    show_data_nhwc_layout(output);
462
463
464
                    std::cout << std::endl;

                    std::cout << "wei: ";
ltqin's avatar
ltqin committed
465
                    show_data_nhwc_layout(weights);
466
467
468
                    std::cout << std::endl;

                    std::cout << "out_host  : ";
ltqin's avatar
ltqin committed
469
                    show_data_nhwc_layout(input_host_result);
470
471
472
                    std::cout << std::endl;

                    std::cout << "out_device: ";
ltqin's avatar
ltqin committed
473
                    show_data_nhwc_layout(input_device_result);
474
475
476
477
478
479
480
481
482
483
484
485
486
                    std::cout << std::endl;
                }
            }
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
    return success;
}

} // namespace profiler
} // namespace ck