#pragma once

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "host_interface.hpp"

/// Data-type selector for a convolution problem.
/// The numeric values are spelled out so they stay fixed if entries are reordered.
enum ConvDataType
{
    F32_F32_F32    = 0,
    F16_F16_F16    = 1,
    BF16_BF16_BF16 = 2,
    INT8_INT8_INT8 = 3,
};

/// Layout selector for the convolution input tensor.
enum ConvInputLayout
{
    NCHW = 0,
    NHWC = 1,
};

/// Layout selector for the convolution weight tensor.
enum ConvWeightLayout
{
    KCYX = 0,
    KYXC = 1,
};

/// Layout selector for the convolution output tensor.
enum ConvOutputLayout
{
    NKHW = 0,
    NHWK = 1,
};

// Code to check CUDA errors
void check_cuda_error(void)
{
    hipError_t err = hipGetLastError();
Jehandad Khan's avatar
Jehandad Khan committed
35
    if(err != hipSuccess)
36
    {
Jehandad Khan's avatar
Jehandad Khan committed
37
38
        std::cerr << "Error: " << hipGetErrorString(err) << std::endl;
        exit(err);
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
    }
}
std::string getDeviceName(int device)
{
    struct hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, device);
    check_cuda_error();
    return std::string(prop.name);
}

/// Returns the driver version reported by hipDriverGetVersion().
///
/// Errors are handled by check_cuda_error(), which prints the error and exits
/// the process. Declared inline because the definition lives in a header.
inline int getDriver(void)
{
    int driver = 0;
    // The returned status is intentionally dropped: check_cuda_error() inspects
    // the runtime's last error immediately afterwards.
    static_cast<void>(hipDriverGetVersion(&driver));
    check_cuda_error();
    return driver;
}

namespace ck {
namespace app {
/// Owner of one raw device allocation of a fixed byte size.
///
/// The constructor allocates `mem_size` bytes with hipMalloc and the destructor
/// releases them with hipFree. ToDevice/FromDevice copy exactly mMemSize bytes
/// between a host pointer and the device buffer.
struct DeviceMem
{
    DeviceMem() = delete;
    DeviceMem(std::size_t mem_size);

    // The destructor frees mpDeviceBuf, so a copy would free the same pointer
    // twice. Copying is therefore disabled.
    DeviceMem(const DeviceMem&) = delete;
    DeviceMem& operator=(const DeviceMem&) = delete;

    /// Returns the raw device pointer.
    void* GetDeviceBuffer();
    /// Copies mMemSize bytes from host pointer `p` into the device buffer.
    void ToDevice(const void* p);
    /// Copies mMemSize bytes from the device buffer to host pointer `p`.
    void FromDevice(void* p);
    ~DeviceMem();

    // Starts as nullptr so a failed allocation never leaves an indeterminate
    // pointer for the destructor to free.
    void* mpDeviceBuf    = nullptr;
    std::size_t mMemSize = 0;
};

/// Allocates `mem_size` bytes of device memory with hipMalloc.
///
/// @throws std::runtime_error if hipMalloc does not return hipSuccess. The
///         previous code passed the status to hipGetErrorString and discarded
///         the result, so a failed allocation went unnoticed.
///
/// Declared inline because the definition lives in a header.
inline DeviceMem::DeviceMem(std::size_t mem_size) : mpDeviceBuf(nullptr), mMemSize(mem_size)
{
    const hipError_t status = hipMalloc(&mpDeviceBuf, mMemSize);
    if(status != hipSuccess)
    {
        mpDeviceBuf = nullptr;
        throw std::runtime_error(std::string("hipMalloc failed: ") + hipGetErrorString(status));
    }
}

/// Returns the raw device pointer held by this object (ownership is not transferred).
/// Declared inline because the definition lives in a header.
inline void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }

void DeviceMem::ToDevice(const void* p)
{
    hipGetErrorString(
        hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}

void DeviceMem::FromDevice(void* p)
{
    hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}

/// Releases the device buffer with hipFree.
/// Declared inline because the definition lives in a header.
inline DeviceMem::~DeviceMem()
{
    // A destructor must not throw, so the hipFree status is deliberately ignored.
    static_cast<void>(hipFree(mpDeviceBuf));
}

void profile_conv_fwd_impl(int do_verification,
                           int init_method,
                           bool do_log,
                           int nrepeat,
96
                           ConvDataType data_type,
Jehandad Khan's avatar
Jehandad Khan committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
                           ck::index_t N,
                           ck::index_t K,
                           ck::index_t C,
                           std::vector<ck::index_t> input_spatial_lengths,
                           std::vector<ck::index_t> filter_spatial_lengths,
                           std::vector<ck::index_t> output_spatial_lengths,
                           std::vector<ck::index_t> conv_filter_strides,
                           std::vector<ck::index_t> conv_filter_dilations,
                           std::vector<ck::index_t> input_left_pads,
                           std::vector<ck::index_t> input_right_pads)
{
    const ck::index_t Y = filter_spatial_lengths[0];
    const ck::index_t X = filter_spatial_lengths[1];

    const ck::index_t Hi = input_spatial_lengths[0];
    const ck::index_t Wi = input_spatial_lengths[1];

    const ck::index_t Ho = output_spatial_lengths[0];
    const ck::index_t Wo = output_spatial_lengths[1];

Jehandad Khan's avatar
Jehandad Khan committed
117
    const auto in_sz  = N * C * Hi * Wi;
118
119
    const auto wei_sz = K * C * Y * X;
    const auto out_sz = N * K * Ho * Wo;
Jehandad Khan's avatar
Jehandad Khan committed
120

Jehandad Khan's avatar
Jehandad Khan committed
121
    using WeiDataType = float;
Jehandad Khan's avatar
Jehandad Khan committed
122
    using InDataType  = float;
Jehandad Khan's avatar
Jehandad Khan committed
123
124
    using OutDataType = float;

Jehandad Khan's avatar
Jehandad Khan committed
125
126
127
    app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz);
    app::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz);
    app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz);
Jehandad Khan's avatar
Jehandad Khan committed
128
129
130
    // data is already on device!

    // add device Conv instances
Jehandad Khan's avatar
Jehandad Khan committed
131
    std::vector<DeviceConvFwdPtr_t> conv_ptrs;
132
133
134
135
136
137
138
139
140
141
142
143
144
    if(data_type == F16_F16_F16)
    {
        add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
    }
    else if(data_type == BF16_BF16_BF16)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(conv_ptrs);
    else if(data_type == F32_F32_F32)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs);
    else if(data_type == INT8_INT8_INT8)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(conv_ptrs);
    else
        throw std::runtime_error("wrong! Invalid data type");
Jehandad Khan's avatar
Jehandad Khan committed
145
    if(conv_ptrs.empty())
Jehandad Khan's avatar
Jehandad Khan committed
146
147
148
149
150
151
152
153
    {
        throw std::runtime_error("wrong! no device Conv instance found");
    }

    std::string best_conv_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;
Jehandad Khan's avatar
Jehandad Khan committed
154
    int deviceIndex       = 0;
155
156
157
158
159
160
    hipSetDevice(deviceIndex);
    check_cuda_error();

    hipStream_t stream_id = nullptr;
    hipStreamCreate(&stream_id);
    check_cuda_error();
Jehandad Khan's avatar
Jehandad Khan committed
161
162
163
164

    // profile device Conv instances
    for(auto& conv_ptr : conv_ptrs)
    {
Jehandad Khan's avatar
Jehandad Khan committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
        auto argument_ptr =
            conv_ptr.MakeArgumentPointer(static_cast<void*>(in_device_buf.GetDeviceBuffer()),
                                         static_cast<void*>(wei_device_buf.GetDeviceBuffer()),
                                         static_cast<void*>(out_device_buf.GetDeviceBuffer()),
                                         N,
                                         K,
                                         C,
                                         input_spatial_lengths,
                                         filter_spatial_lengths,
                                         output_spatial_lengths,
                                         conv_filter_strides,
                                         conv_filter_dilations,
                                         input_left_pads,
                                         input_right_pads);
Jehandad Khan's avatar
Jehandad Khan committed
179

Jehandad Khan's avatar
Jehandad Khan committed
180
        auto invoker_ptr = conv_ptr.MakeInvokerPointer();
Jehandad Khan's avatar
Jehandad Khan committed
181

Jehandad Khan's avatar
Jehandad Khan committed
182
        if(conv_ptr.IsSupportedArgument(argument_ptr.get()))
Jehandad Khan's avatar
Jehandad Khan committed
183
        {
Jehandad Khan's avatar
Jehandad Khan committed
184
            std::string conv_name = conv_ptr.GetTypeString();
Jehandad Khan's avatar
Jehandad Khan committed
185
            float ave_time        = invoker_ptr->Run(argument_ptr.get(), nrepeat, stream_id, true);
Jehandad Khan's avatar
Jehandad Khan committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

            std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;

            std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
                                    sizeof(WeiDataType) * (K * C * Y * X) +
                                    sizeof(OutDataType) * (N * K * Ho * Wo);

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << conv_name << std::endl;

            if(tflops > best_tflops)
            {
                best_conv_name  = conv_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
}

} // namespace app
} // namespace ck