// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iomanip>
#include <iostream>
#include <typeinfo>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/convolution_forward.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
namespace ck {
namespace profiler {

Chao Liu's avatar
Chao Liu committed
27
// FIXME: only support NCHW and NHWC layout, need to be more general
Chao Liu's avatar
Chao Liu committed
28
29
30
31
32
33
34
35
36
37
38
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
int profile_conv_fwd_impl(int do_verification,
                          int init_method,
                          bool do_log,
                          bool time_kernel,
Chao Liu's avatar
Chao Liu committed
39
                          const ck::tensor_operation::device::ConvParams& params)
Chao Liu's avatar
Chao Liu committed
40
41
42
{
    bool pass = true;

Chao Liu's avatar
Chao Liu committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    // make host tensor descritpor
    auto f_nhwc_host_tensor_descriptor =
        [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
            std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
                                                  static_cast<std::size_t>(c)};
            nhwc_lengths.insert(
                nhwc_lengths.begin() + 1, spatial_lengths.begin(), spatial_lengths.end());

            return HostTensorDescriptor(nhwc_lengths);
        };

    auto f_nchw_host_tensor_descriptor =
        [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
            std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
                                                  static_cast<std::size_t>(c)};
            nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());

            return HostTensorDescriptor(nchw_lengths);
Chao Liu's avatar
Chao Liu committed
61
62
        };

Chao Liu's avatar
Chao Liu committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
    HostTensorDescriptor in_desc, wei_desc, out_desc;

    // FIXME: properly implement "make host descriptor" for different layout
    if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
                 is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
                 is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
    {
        in_desc =
            f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
    }
    else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
                      is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
                      is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
    {
        in_desc =
            f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
    }

    // FIXME: properly implement "make host descriptor" for different layout
    if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
                 is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
                 is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
    {
        wei_desc =
            f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
    }
    else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
                      is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
                      is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
    {
        wei_desc =
            f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
    }

    // FIXME: properly implement "make host descriptor" for different layout
    if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
                 is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
                 is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
    {
        out_desc =
            f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
    }
    else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
                      is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
                      is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
    {
        out_desc =
            f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
    }

    Tensor<InDataType> input(in_desc);
    Tensor<WeiDataType> weight(wei_desc);
    Tensor<OutDataType> host_output(out_desc);
    Tensor<OutDataType> device_output(out_desc);
Chao Liu's avatar
Chao Liu committed
117

Chao Liu's avatar
Chao Liu committed
118
119
120
    std::cout << "input: " << input.mDesc << std::endl;
    std::cout << "weight: " << weight.mDesc << std::endl;
    std::cout << "output: " << host_output.mDesc << std::endl;
Chao Liu's avatar
Chao Liu committed
121
122
123
124
125

    switch(init_method)
    {
    case 0: break;
    case 1:
Chao Liu's avatar
Chao Liu committed
126
127
        input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
        weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
Chao Liu's avatar
Chao Liu committed
128
129
        break;
    default:
Chao Liu's avatar
Chao Liu committed
130
131
        input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
        weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
Chao Liu's avatar
Chao Liu committed
132
133
    }

Chao Liu's avatar
Chao Liu committed
134
135
136
    using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
    using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
    using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
Chao Liu's avatar
Chao Liu committed
137

Chao Liu's avatar
Chao Liu committed
138
139
140
    const auto in_element_op  = InElementOp{};
    const auto wei_element_op = WeiElementOp{};
    const auto out_element_op = OutElementOp{};
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
143
144
    DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
    DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace());
    DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace());
Chao Liu's avatar
Chao Liu committed
145

Chao Liu's avatar
Chao Liu committed
146
147
    in_device_buf.ToDevice(input.mData.data());
    wei_device_buf.ToDevice(weight.mData.data());
Chao Liu's avatar
Chao Liu committed
148

Chao Liu's avatar
Chao Liu committed
149
150
151
152
153
154
155
156
157
158
    using DeviceOp = ck::tensor_operation::device::DeviceConvFwd<NumDimSpatial,
                                                                 InLayout,
                                                                 WeiLayout,
                                                                 OutLayout,
                                                                 InDataType,
                                                                 WeiDataType,
                                                                 OutDataType,
                                                                 InElementOp,
                                                                 WeiElementOp,
                                                                 OutElementOp>;
Chao Liu's avatar
Chao Liu committed
159
160
161
162
163
164
165

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

Chao Liu's avatar
Chao Liu committed
166
    // run reference op
Chao Liu's avatar
Chao Liu committed
167
168
    if(do_verification)
    {
Chao Liu's avatar
Chao Liu committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
        auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NumDimSpatial,
                                                                     InLayout,
                                                                     WeiLayout,
                                                                     OutLayout,
                                                                     InDataType,
                                                                     WeiDataType,
                                                                     OutDataType,
                                                                     InElementOp,
                                                                     WeiElementOp,
                                                                     OutElementOp>{};

        auto ref_invoker  = ref_conv.MakeInvoker();
        auto ref_argument = ref_conv.MakeArgument(input,
                                                  weight,
                                                  host_output,
                                                  params.conv_filter_strides_,
                                                  params.conv_filter_dilations_,
                                                  params.input_left_pads_,
                                                  params.input_right_pads_,
                                                  in_element_op,
                                                  wei_element_op,
                                                  out_element_op);

        // init host output to zero
        host_output.SetZero();
Chao Liu's avatar
Chao Liu committed
194
195
196
197
198

        ref_invoker.Run(ref_argument);
    }

    std::string best_op_name;
Chao Liu's avatar
Chao Liu committed
199
    float best_avg_time   = 0;
Chao Liu's avatar
Chao Liu committed
200
201
202
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

Chao Liu's avatar
Chao Liu committed
203
    // profile device op instances
Chao Liu's avatar
Chao Liu committed
204
205
206
    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr =
Chao Liu's avatar
Chao Liu committed
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
            op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                                        static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                                        static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
                                        params.N_,
                                        params.K_,
                                        params.C_,
                                        params.input_spatial_lengths_,
                                        params.filter_spatial_lengths_,
                                        params.GetOutputSpatialLengths(),
                                        params.conv_filter_strides_,
                                        params.conv_filter_dilations_,
                                        params.input_left_pads_,
                                        params.input_right_pads_,
                                        in_element_op,
                                        wei_element_op,
                                        out_element_op);
Chao Liu's avatar
Chao Liu committed
223
224
225
226
227

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
Chao Liu's avatar
Chao Liu committed
228
229
            // re-init output to zero before profiling next kernel
            out_device_buf.SetZero();
Chao Liu's avatar
Chao Liu committed
230
231
232

            std::string op_name = op_ptr->GetTypeString();

Chao Liu's avatar
Chao Liu committed
233
            float avg_time =
Chao Liu's avatar
Chao Liu committed
234
235
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

Chao Liu's avatar
Chao Liu committed
236
237
            std::size_t flop      = params.GetFlops();
            std::size_t num_btype = params.GetByte<InDataType, WeiDataType, OutDataType>();
Chao Liu's avatar
Chao Liu committed
238

Chao Liu's avatar
Chao Liu committed
239
            float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
Chao Liu's avatar
Chao Liu committed
240

Chao Liu's avatar
Chao Liu committed
241
            float gb_per_sec = num_btype / 1.E6 / avg_time;
Chao Liu's avatar
Chao Liu committed
242

Chao Liu's avatar
Chao Liu committed
243
            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
244
245
246
247
248
249
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_name    = op_name;
                best_tflops     = tflops;
Chao Liu's avatar
Chao Liu committed
250
                best_avg_time   = avg_time;
Chao Liu's avatar
Chao Liu committed
251
252
253
254
255
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
Chao Liu's avatar
Chao Liu committed
256
                out_device_buf.FromDevice(device_output.mData.data());
Chao Liu's avatar
Chao Liu committed
257

Chao Liu's avatar
Chao Liu committed
258
                pass = pass & ck::utils::check_err(device_output.mData, host_output.mData);
Chao Liu's avatar
Chao Liu committed
259
260
261

                if(do_log)
                {
Chao Liu's avatar
Chao Liu committed
262
263
264
                    LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "host_output  : ", host_output.mData, ",")
Chao Liu's avatar
Chao Liu committed
265
                        << std::endl;
Chao Liu's avatar
Chao Liu committed
266
                    LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
Chao Liu's avatar
Chao Liu committed
267
268
269
270
271
272
273
274
275
276
                        << std::endl;
                }
            }
        }
        else
        {
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
        }
    }

Chao Liu's avatar
Chao Liu committed
277
278
279
    std::cout << "Best configuration parameters:"
              << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
              << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
Chao Liu's avatar
Chao Liu committed
280

Chao Liu's avatar
Chao Liu committed
281
    return 0;
Chao Liu's avatar
Chao Liu committed
282
283
284
285
}

} // namespace profiler
} // namespace ck