client_app_impl.hpp 6.6 KB
Newer Older
Jehandad Khan's avatar
Jehandad Khan committed
1
2
#pragma once

Jehandad Khan's avatar
Jehandad Khan committed
3
#include "host_interface.hpp"

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
Jehandad Khan's avatar
Jehandad Khan committed
4

5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/// Data-type selector for the convolution profiler.
/// profile_conv_fwd_impl() uses it to pick which set of device convolution
/// instances gets registered. The numeric values are spelled out explicitly.
enum ConvDataType
{
    F32_F32_F32    = 0,
    F16_F16_F16    = 1,
    BF16_BF16_BF16 = 2,
    INT8_INT8_INT8 = 3,
};

/// Layout tag for the convolution input tensor.
/// The enumerator name spells the dimension order (assumed N = batch,
/// C = channel, H = height, W = width -- confirm against users of this enum).
enum ConvInputLayout
{
    NCHW = 0,
    NHWC = 1,
};

/// Layout tag for the convolution weight (filter) tensor.
/// The enumerator name spells the dimension order (assumed K = output channel,
/// C = input channel, Y/X = filter height/width -- confirm against users).
enum ConvWeightLayout
{
    KCYX = 0,
    KYXC = 1,
};

/// Layout tag for the convolution output tensor.
/// The enumerator name spells the dimension order (assumed N = batch,
/// K = output channel, H = height, W = width -- confirm against users).
enum ConvOutputLayout
{
    NKHW = 0,
    NHWK = 1,
};
Jehandad Khan's avatar
Jehandad Khan committed
30

31
void check_hip_error(void)
32
33
{
    hipError_t err = hipGetLastError();
Jehandad Khan's avatar
Jehandad Khan committed
34
    if(err != hipSuccess)
35
    {
Jehandad Khan's avatar
Jehandad Khan committed
36
37
        std::cerr << "Error: " << hipGetErrorString(err) << std::endl;
        exit(err);
38
39
40
41
42
43
    }
}
std::string getDeviceName(int device)
{
    struct hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, device);
44
    check_hip_error();
45
46
47
48
49
50
51
    return std::string(prop.name);
}

int getDriver(void)
{
    int driver;
    hipDriverGetVersion(&driver);
52
    check_hip_error();
53
54
55
    return driver;
}

Jehandad Khan's avatar
Jehandad Khan committed
56
57
namespace ck {
namespace app {
Jehandad Khan's avatar
Jehandad Khan committed
58
59
/// Owner of one device-side buffer of `mMemSize` bytes.
///
/// The constructor allocates the buffer and the destructor frees it, so the
/// object must have exactly one owner. Copying is therefore disabled (a copy
/// would free the same pointer twice); ownership can be transferred by move
/// construction, which leaves the source holding a null buffer.
struct DeviceMem
{
    DeviceMem() = delete;

    /// Allocates `mem_size` bytes of device memory.
    DeviceMem(std::size_t mem_size);

    // Copying would make two objects free the same device pointer.
    DeviceMem(const DeviceMem&)            = delete;
    DeviceMem& operator=(const DeviceMem&) = delete;

    /// Takes over the buffer of `other`; `other` is left empty (null, size 0).
    DeviceMem(DeviceMem&& other) noexcept
        : mpDeviceBuf(other.mpDeviceBuf), mMemSize(other.mMemSize)
    {
        other.mpDeviceBuf = nullptr;
        other.mMemSize    = 0;
    }
    DeviceMem& operator=(DeviceMem&&) = delete;

    /// Returns the raw device pointer (still owned by this object).
    void* GetDeviceBuffer();

    /// Copies `mMemSize` bytes from host pointer `p` into the device buffer.
    void ToDevice(const void* p);

    /// Copies `mMemSize` bytes from the device buffer into host pointer `p`.
    void FromDevice(void* p);

    /// Releases the device buffer.
    ~DeviceMem();

    void* mpDeviceBuf    = nullptr; // device allocation owned by this object
    std::size_t mMemSize = 0;       // size of the allocation in bytes
};
Jehandad Khan's avatar
Jehandad Khan committed
70

Jehandad Khan's avatar
Jehandad Khan committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/// Allocates `mem_size` bytes of device memory with hipMalloc().
///
/// @param mem_size  Size of the buffer in bytes.
/// @throws std::runtime_error if hipMalloc() does not return hipSuccess. The
///         previous code discarded the status, leaving `mpDeviceBuf`
///         indeterminate and later handing it to hipFree().
///
/// `inline` because the definition lives in a header.
inline DeviceMem::DeviceMem(std::size_t mem_size) : mpDeviceBuf(nullptr), mMemSize(mem_size)
{
    const hipError_t status = hipMalloc(&mpDeviceBuf, mMemSize);
    if(status != hipSuccess)
    {
        mpDeviceBuf = nullptr;
        throw std::runtime_error(std::string("hipMalloc failed: ") + hipGetErrorString(status));
    }
}

/// Returns the raw device pointer held by this object; ownership is not
/// transferred. `inline` because the definition lives in a header.
inline void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }

void DeviceMem::ToDevice(const void* p)
{
    hipGetErrorString(
        hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}

void DeviceMem::FromDevice(void* p)
{
    hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}

/// Frees the device buffer with hipFree().
///
/// A destructor must not throw, so a failing hipFree() is reported on
/// std::cerr instead of being silently discarded as before.
/// `inline` because the definition lives in a header.
inline DeviceMem::~DeviceMem()
{
    const hipError_t status = hipFree(mpDeviceBuf);
    if(status != hipSuccess)
    {
        std::cerr << "Error: hipFree failed: " << hipGetErrorString(status) << std::endl;
    }
}
Jehandad Khan's avatar
Jehandad Khan committed
90
91
92
93
94

void profile_conv_fwd_impl(int do_verification,
                           int init_method,
                           bool do_log,
                           int nrepeat,
95
                           ConvDataType data_type,
Jehandad Khan's avatar
Jehandad Khan committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
                           ck::index_t N,
                           ck::index_t K,
                           ck::index_t C,
                           std::vector<ck::index_t> input_spatial_lengths,
                           std::vector<ck::index_t> filter_spatial_lengths,
                           std::vector<ck::index_t> output_spatial_lengths,
                           std::vector<ck::index_t> conv_filter_strides,
                           std::vector<ck::index_t> conv_filter_dilations,
                           std::vector<ck::index_t> input_left_pads,
                           std::vector<ck::index_t> input_right_pads)
{
    const ck::index_t Y = filter_spatial_lengths[0];
    const ck::index_t X = filter_spatial_lengths[1];

    const ck::index_t Hi = input_spatial_lengths[0];
    const ck::index_t Wi = input_spatial_lengths[1];

    const ck::index_t Ho = output_spatial_lengths[0];
    const ck::index_t Wo = output_spatial_lengths[1];

Jehandad Khan's avatar
Jehandad Khan committed
116
    const auto in_sz  = N * C * Hi * Wi;
117
118
    const auto wei_sz = K * C * Y * X;
    const auto out_sz = N * K * Ho * Wo;
Jehandad Khan's avatar
Jehandad Khan committed
119

Jehandad Khan's avatar
Jehandad Khan committed
120
    using WeiDataType = float;
Jehandad Khan's avatar
Jehandad Khan committed
121
    using InDataType  = float;
Jehandad Khan's avatar
Jehandad Khan committed
122
123
    using OutDataType = float;

Jehandad Khan's avatar
Jehandad Khan committed
124
125
126
    app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz);
    app::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz);
    app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz);
Jehandad Khan's avatar
Jehandad Khan committed
127
128
129
    // data is already on device!

    // add device Conv instances
Jehandad Khan's avatar
Jehandad Khan committed
130
    std::vector<DeviceConvFwdPtr_t> conv_ptrs;
131
132
133
134
135
136
137
138
139
140
141
142
143
    if(data_type == F16_F16_F16)
    {
        add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
    }
    else if(data_type == BF16_BF16_BF16)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(conv_ptrs);
    else if(data_type == F32_F32_F32)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs);
    else if(data_type == INT8_INT8_INT8)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(conv_ptrs);
    else
        throw std::runtime_error("wrong! Invalid data type");
Jehandad Khan's avatar
Jehandad Khan committed
144
    if(conv_ptrs.empty())
Jehandad Khan's avatar
Jehandad Khan committed
145
146
147
148
149
150
151
152
    {
        throw std::runtime_error("wrong! no device Conv instance found");
    }

    std::string best_conv_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;
Jehandad Khan's avatar
Jehandad Khan committed
153
    int deviceIndex       = 0;
154
    hipSetDevice(deviceIndex);
155
    check_hip_error();
156
157
158

    hipStream_t stream_id = nullptr;
    hipStreamCreate(&stream_id);
159
    check_hip_error();
Jehandad Khan's avatar
Jehandad Khan committed
160
161
162
163

    // profile device Conv instances
    for(auto& conv_ptr : conv_ptrs)
    {
Jehandad Khan's avatar
Jehandad Khan committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
        auto argument_ptr =
            conv_ptr.MakeArgumentPointer(static_cast<void*>(in_device_buf.GetDeviceBuffer()),
                                         static_cast<void*>(wei_device_buf.GetDeviceBuffer()),
                                         static_cast<void*>(out_device_buf.GetDeviceBuffer()),
                                         N,
                                         K,
                                         C,
                                         input_spatial_lengths,
                                         filter_spatial_lengths,
                                         output_spatial_lengths,
                                         conv_filter_strides,
                                         conv_filter_dilations,
                                         input_left_pads,
                                         input_right_pads);
Jehandad Khan's avatar
Jehandad Khan committed
178

Jehandad Khan's avatar
Jehandad Khan committed
179
        auto invoker_ptr = conv_ptr.MakeInvokerPointer();
Jehandad Khan's avatar
Jehandad Khan committed
180

Jehandad Khan's avatar
Jehandad Khan committed
181
        if(conv_ptr.IsSupportedArgument(argument_ptr.get()))
Jehandad Khan's avatar
Jehandad Khan committed
182
        {
Jehandad Khan's avatar
Jehandad Khan committed
183
            std::string conv_name = conv_ptr.GetTypeString();
Jehandad Khan's avatar
Jehandad Khan committed
184
            float ave_time        = invoker_ptr->Run(argument_ptr.get(), nrepeat, stream_id, true);
Jehandad Khan's avatar
Jehandad Khan committed
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212

            std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;

            std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
                                    sizeof(WeiDataType) * (K * C * Y * X) +
                                    sizeof(OutDataType) * (N * K * Ho * Wo);

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << conv_name << std::endl;

            if(tflops > best_tflops)
            {
                best_conv_name  = conv_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
}

Jehandad Khan's avatar
Jehandad Khan committed
213
} // namespace app
Jehandad Khan's avatar
Jehandad Khan committed
214
} // namespace ck