quantization.cpp 44.4 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
#include <iostream>
#include <iterator>
#include <numeric>
#include <utility>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>

kahmed10's avatar
kahmed10 committed
17
18
19
// Builds a clip of `input` in the main module of `p` using add_instruction.
//
// The scalar bounds are added as literals, multibroadcast to the lens of
// `input`, and passed to op::clip as (input, min, max).
//
// Note the parameter order: `max` comes before `min`.
// Returns the instruction_ref of the clip instruction.
migraphx::instruction_ref
create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction_ref input)
{
    auto* mm        = p.get_main_module();
    auto input_lens = input->get_shape().lens();
    // The literal for max is added before the literal for min; the expected
    // programs built with this helper rely on that order.
    auto max_val    = mm->add_literal(max);
    auto min_val    = mm->add_literal(min);
    max_val         = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
    min_val         = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
    return mm->add_instruction(migraphx::op::clip{}, input, min_val, max_val);
}

// Builds a clip of `input` in the main module of `p`, inserting the broadcast
// and clip instructions at `insert_loc` via insert_instruction.
//
// The scalar bounds are added as literals, multibroadcast to the lens of
// `input`, and passed to op::clip as (input, min, max).
//
// Note the parameter order: `max` comes before `min`.
// Returns the instruction_ref of the clip instruction.
migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
                                         migraphx::program& p,
                                         float max,
                                         float min,
                                         migraphx::instruction_ref input)
{
    auto* mm        = p.get_main_module();
    auto input_lens = input->get_shape().lens();
    // The literal for max is added before the literal for min; the expected
    // programs built with this helper rely on that order.
    auto max_val    = mm->add_literal(max);
    auto min_val    = mm->add_literal(min);
    max_val = mm->insert_instruction(insert_loc, migraphx::op::multibroadcast{input_lens}, max_val);
    min_val = mm->insert_instruction(insert_loc, migraphx::op::multibroadcast{input_lens}, min_val);
    return mm->insert_instruction(insert_loc, migraphx::op::clip{}, input, min_val, max_val);
}

Shucai Xiao's avatar
Shucai Xiao committed
44
45
// quantize_fp16 on a single float add of two parameters: the result must equal
// a hand-built program that converts both parameters to half, adds them, and
// converts the sum back to float. Checked with and without an explicit return
// instruction, and with both the default op list and {"add"}.
TEST_CASE(param_add)
{
    // Reference float program: x + y, optionally terminated by a return.
    auto create_program_float = [](bool add_return = false) {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto p2  = mm->add_parameter("y", s);
        auto sum = mm->add_instruction(migraphx::op::add{}, p1, p2);
        if(add_return)
        {
            mm->add_return({sum});
        }

        return p;
    };

    // Expected program after fp16 quantization. Each convert is inserted right
    // after the parameter it converts. The half target type is spelled out
    // explicitly, matching the other tests in this file.
    auto create_program_half = [](bool add_return = false) {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto hp1 = mm->insert_instruction(
            std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2);
        auto hs  = mm->add_instruction(migraphx::op::add{}, hp1, hp2);
        auto res = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs);
        if(add_return)
        {
            mm->add_return({res});
        }

        return p;
    };

    // Default op list, no return instruction.
    {
        auto p1 = create_program_float();
        auto p2 = create_program_half();

        migraphx::quantize_fp16(p1);
        EXPECT(p1 == p2);
    }

    // Only "add" quantized, no return instruction.
    {
        auto p1 = create_program_float();
        auto p2 = create_program_half();

        migraphx::quantize_fp16(p1, {"add"});
        EXPECT(p1 == p2);
    }

    // Default op list, with return instruction.
    {
        auto p1 = create_program_float(true);
        auto p2 = create_program_half(true);

        migraphx::quantize_fp16(p1);
        EXPECT(p1 == p2);
    }

    // Only "add" quantized, with return instruction.
    {
        auto p1 = create_program_float(true);
        auto p2 = create_program_half(true);

        migraphx::quantize_fp16(p1, {"add"});
        EXPECT(p1 == p2);
    }
}

// quantize_fp16 on the chain (x + y) - y + x, quantizing only "add", only
// "sub", or every op. Each case is compared against a hand-built program.
TEST_CASE(param_add_sub)
{
    // Reference float program.
    auto create_program_float = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1   = mm->add_parameter("x", s);
        auto p2   = mm->add_parameter("y", s);
        auto sum  = mm->add_instruction(migraphx::op::add{}, p1, p2);
        auto diff = mm->add_instruction(migraphx::op::sub{}, sum, p2);
        mm->add_instruction(migraphx::op::add{}, diff, p1);

        return p;
    };

    // Expected program when only the add instructions run in half; the sub
    // stays in float, so there are converts on either side of it.
    auto create_program_half_add = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto hp1 = mm->insert_instruction(
            std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2);
        auto hsum  = mm->add_instruction(migraphx::op::add{}, hp1, hp2);
        auto sum   = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hsum);
        auto diff  = mm->add_instruction(migraphx::op::sub{}, sum, p2);
        auto hdiff = mm->add_instruction(migraphx::op::convert{migraphx::shape::half_type}, diff);
        auto res   = mm->add_instruction(migraphx::op::add{}, hdiff, hp1);
        mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, res);

        return p;
    };

    // Expected program when only the sub instruction runs in half.
    auto create_program_half_sub = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2);
        auto sum   = mm->add_instruction(migraphx::op::add{}, p1, p2);
        auto hsum  = mm->add_instruction(migraphx::op::convert{migraphx::shape::half_type}, sum);
        auto hdiff = mm->add_instruction(migraphx::op::sub{}, hsum, hp2);
        auto diff  = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hdiff);
        mm->add_instruction(migraphx::op::add{}, diff, p1);

        return p;
    };

    // Expected program when every op runs in half: one convert per parameter
    // and a single convert back to float at the end.
    auto create_program_half_all = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto hp1 = mm->insert_instruction(
            std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2);
        auto hsum  = mm->add_instruction(migraphx::op::add{}, hp1, hp2);
        auto hdiff = mm->add_instruction(migraphx::op::sub{}, hsum, hp2);
        auto hres  = mm->add_instruction(migraphx::op::add{}, hdiff, hp1);
        mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hres);

        return p;
    };

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half_add();

        migraphx::quantize_fp16(p1, {"add"});
        EXPECT(p1 == p2);
    }

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half_sub();

        migraphx::quantize_fp16(p1, {"sub"});
        EXPECT(p1 == p2);
    }

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half_all();

        migraphx::quantize_fp16(p1);
        // Dead code elimination is run on the quantized program before it is
        // compared with the hand-built one.
        migraphx::run_passes(*p1.get_main_module(), {migraphx::dead_code_elimination{}});

        EXPECT(p1 == p2);
    }
}

// quantize_fp16 on an add of two float literals. After constant propagation
// and dead code elimination on both sides, the quantized program must equal a
// program built from half literals directly.
TEST_CASE(literal_add)
{
    // Reference float program: two identical literals holding 1..6, added.
    auto create_program_float = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        std::vector<float> data(2 * 3);
        std::iota(data.begin(), data.end(), 1.0f);
        auto l1 = mm->add_literal(migraphx::literal(s, data));
        auto l2 = mm->add_literal(migraphx::literal(s, data));
        mm->add_instruction(migraphx::op::add{}, l1, l2);

        return p;
    };

    // Expected program: the same values stored as half literals, added in
    // half, then converted back to float.
    auto create_program_half = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::half_type, {2, 3}};
        std::vector<migraphx::half> data(2 * 3);
        std::iota(data.begin(), data.end(), 1.0f);
        auto l1 = mm->add_literal(migraphx::literal(s, data));
        auto l2 = mm->add_literal(migraphx::literal(s, data));
        auto hs = mm->add_instruction(migraphx::op::add{}, l1, l2);
        mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs);

        return p;
    };

    // Quantize with the {"all"} op list.
    {
        auto p1 = create_program_float();
        auto p2 = create_program_half();

        migraphx::quantize_fp16(p1, {"all"});
        migraphx::run_passes(*p1.get_main_module(),
                             {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
        migraphx::run_passes(*p2.get_main_module(),
                             {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});

        EXPECT(p1 == p2);
    }

    // Quantize only "add".
    {
        auto p1 = create_program_float();
        auto p2 = create_program_half();

        migraphx::quantize_fp16(p1, {"add"});
        migraphx::run_passes(*p1.get_main_module(),
                             {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
        migraphx::run_passes(*p2.get_main_module(),
                             {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
        EXPECT(p1 == p2);
    }
}

265
266
// capture_arguments on a program with two dot instructions: the result must
// equal a hand-built program in which every dot input is wrapped in an
// op::capture with indices 0..3.
TEST_CASE(op_capture)
{
    // No-op callback handed to op::capture in the expected program. It captures
    // nothing and ignores both arguments.
    auto test_func = [](std::size_t, const std::vector<migraphx::argument>&) {};

    // Reference program: pa = x + y; ps = dot(pa, b, c); dot(pa, ps).
    auto create_program_float = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
        migraphx::shape s2{migraphx::shape::float_type, {3, 6}};

        auto p1 = mm->add_parameter("x", s1);
        auto p2 = mm->add_parameter("y", s1);
        auto pb = mm->add_parameter("b", s2);
        auto pc = mm->add_parameter("c", s2);
        auto pa = mm->add_instruction(migraphx::op::add{}, p1, p2);
        auto ps = mm->add_instruction(migraphx::op::dot{}, pa, pb, pc);
        mm->add_instruction(migraphx::op::dot{}, pa, ps);

        return p;
    };

    // Expected program. Captures of parameters are inserted right after the
    // parameter; captures of computed values are appended.
    auto create_program_op = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
        migraphx::shape s2{migraphx::shape::float_type, {3, 6}};

        auto p1  = mm->add_parameter("x", s1);
        auto p2  = mm->add_parameter("y", s1);
        auto pb  = mm->add_parameter("b", s2);
        auto pc  = mm->add_parameter("c", s2);
        auto pa  = mm->add_instruction(migraphx::op::add{}, p1, p2);
        auto opb = mm->insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb);
        auto opc = mm->insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc);
        auto opa = mm->add_instruction(migraphx::op::capture{0, test_func}, pa);
        auto ps  = mm->add_instruction(migraphx::op::dot{}, opa, opb, opc);
        auto ops = mm->add_instruction(migraphx::op::capture{3, test_func}, ps);
        mm->add_instruction(migraphx::op::dot{}, opa, ops);

        return p;
    };

    {
        auto p             = create_program_float();
        auto op_capture_p  = create_program_op();
        migraphx::target t = migraphx::ref::target{};
        migraphx::capture_arguments(p, t, {"dot", "convolution"});
        EXPECT(p == op_capture_p);
    }
}

319
320
321
322
// quantize_int8_impl on a float dot{2.0f, 1.5f} with three arguments: the
// result (after dead code elimination) must equal a hand-built program that
// quantizes a and b to int8, runs quant_dot{1, 0}, and applies alpha and
// beta * c in float afterwards.
TEST_CASE(dot_float)
{
    // Reference float program.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        mm->add_instruction(migraphx::op::dot{2.0f, 1.5f}, pa, pb, pc);

        return p;
    };

    // Expected quantized program.
    auto create_int8_quantized_prog = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);
        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fa = mm->add_literal(migraphx::literal(sa, vfa));
        auto ma = mm->add_instruction(migraphx::op::mul{}, fa, pa);
        auto ra = mm->add_instruction(migraphx::op::round{}, ma);
        auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa = mm->add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);

        // quantize parameter b to int8 type; these instructions are inserted
        // right after parameter b rather than appended
        auto insert_loc = std::next(pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal(sb, vfb));
        auto mb = mm->insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
        auto rb = mm->insert_instruction(insert_loc, migraphx::op::round{}, mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(
            insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);

        // int8 dot, then rescale in float and add the scaled c
        auto qdot = mm->add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb);
        auto fdot = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
        std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
        auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
        auto alpha_ab  = mm->add_instruction(migraphx::op::mul{}, new_alpha, fdot);
        std::vector<float> v_beta(pc->get_shape().elements(), 1.5f);
        auto beta   = mm->add_literal(migraphx::literal(pc->get_shape(), v_beta));
        auto beta_c = mm->add_instruction(migraphx::op::mul{}, beta, pc);
        mm->add_instruction(migraphx::op::add{}, alpha_ab, beta_c);

        return p;
    };

    auto p = create_program();
    // One (scale, shift) pair per dot argument; held by value rather than as a
    // const reference bound to a temporary.
    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});

    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// quantize_int8_impl on a double dot{2.0f, 1.5f} with two arguments: the
// result must equal a hand-built program that converts a and b to float,
// quantizes them to int8, runs quant_dot{1, 0}, scales in float, and converts
// the result back to double.
TEST_CASE(dot_double_2args)
{
    // Reference double program.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);

        mm->add_instruction(migraphx::op::dot{2.0f, 1.5f}, pa, pb);

        return p;
    };

    // Expected quantized program.
    auto create_int8_quantized_prog = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fpa = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa);
        auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
        auto ma = mm->add_instruction(migraphx::op::mul{}, fa, fpa);
        auto ra = mm->add_instruction(migraphx::op::round{}, ma);
        auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa = mm->add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);

        // quantize parameter b to int8 type; these instructions are inserted
        // right after parameter b rather than appended
        auto insert_loc = std::next(pb);
        auto fpb        = mm->insert_instruction(
            insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
        auto mb = mm->insert_instruction(insert_loc, migraphx::op::mul{}, fb, fpb);
        auto rb = mm->insert_instruction(insert_loc, migraphx::op::round{}, mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(
            insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);

        // int8 dot, rescale in float, then convert back to the original type
        auto qdot = mm->add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb);
        auto fdot = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
        std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
        auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
        auto alpha_ab  = mm->add_instruction(migraphx::op::mul{}, new_alpha, fdot);
        mm->add_instruction(migraphx::op::convert{migraphx::shape::double_type}, alpha_ab);

        return p;
    };

    auto p = create_program();
    // One (scale, shift) pair per dot argument; held by value rather than as a
    // const reference bound to a temporary.
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// quantize_int8_impl on a float dot{20.0f, 50.5f} with three arguments: the
// result must equal a hand-built program in which a (with a shift of 1) and b
// are quantized to int8, c is converted to int32, and a single
// quant_dot{2000, 51} is followed by a convert back to float.
TEST_CASE(dot_large_alpha_beta_float)
{
    // Reference float program.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        mm->add_instruction(migraphx::op::dot{20.0f, 50.5f}, pa, pb, pc);

        return p;
    };

    // Expected quantized program.
    auto create_int8_quantized_prog = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);
        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fa = mm->add_literal(migraphx::literal(sa, vfa));
        auto ma = mm->add_instruction(migraphx::op::mul{}, fa, pa);
        // add the shift
        std::vector<float> vsa(sa.elements(), 1.0f);
        auto sfta = mm->add_literal(migraphx::literal(sa, vsa));
        auto msa  = mm->add_instruction(migraphx::op::add{}, sfta, ma);
        auto ra   = mm->add_instruction(migraphx::op::round{}, msa);
        auto ca   = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa   = mm->add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);

        // quantize parameter b to int8 type; these instructions are inserted
        // right after parameter b rather than appended
        auto insert_loc = std::next(pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal(sb, vfb));
        auto mb = mm->insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
        auto rb = mm->insert_instruction(insert_loc, migraphx::op::round{}, mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(
            insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);

        // quantize parameter c to int32 type
        auto qc = mm->insert_instruction(
            std::next(pc), migraphx::op::convert{migraphx::shape::int32_type}, pc);

        auto qdot = mm->add_instruction(migraphx::op::quant_dot{2000, 51}, qa, qb, qc);
        mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);

        return p;
    };

    auto p = create_program();
    // One (scale, shift) pair per dot argument; held by value rather than as a
    // const reference bound to a temporary.
    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// quantize_int8_impl on an int32 dot{20.0f, 50.0f} with three arguments: the
// result must equal a hand-built program in which a (with a shift of 1) and b
// are converted to float and quantized to int8, and a single
// quant_dot{2000, 50} consumes them together with the unmodified int32 c.
TEST_CASE(dot_large_alpha_beta_int32)
{
    // Reference int32 program.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        mm->add_instruction(migraphx::op::dot{20.0f, 50.0f}, pa, pb, pc);

        return p;
    };

    // Expected quantized program.
    auto create_int8_quantized_prog = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);
        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
        auto conv_a = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa);
        auto ma     = mm->add_instruction(migraphx::op::mul{}, fa, conv_a);

        // add the shift
        std::vector<float> vsa(sa.elements(), 1.0f);
        auto sfta =
            mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
        auto msa = mm->add_instruction(migraphx::op::add{}, sfta, ma);
        auto ra  = mm->add_instruction(migraphx::op::round{}, msa);
        auto ca  = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa  = mm->add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);

        // quantize parameter b to int8 type; these instructions are inserted
        // right after parameter b rather than appended
        auto insert_loc = std::next(pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
        auto conv_b = mm->insert_instruction(
            insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
        auto mb = mm->insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b);
        auto rb = mm->insert_instruction(insert_loc, migraphx::op::round{}, mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(
            insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);

        // c is already int32 and is passed to quant_dot unchanged
        mm->add_instruction(migraphx::op::quant_dot{2000, 50}, qa, qb, pc);

        return p;
    };

    auto p = create_program();
    // One (scale, shift) pair per dot argument; held by value rather than as a
    // const reference bound to a temporary.
    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

583
584
585
586
// quantize_int8_impl on an int32 dot{20.0f, 50.0f} whose two inputs are the
// same parameter: the result must equal a hand-built program that quantizes
// the parameter once, feeds it to both sides of quant_dot{1, 0}, scales the
// result by 20 in float, and converts back to int32.
TEST_CASE(dot_int32_one_arg)
{
    // Reference int32 program: dot(a, a).
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::int32_type, {16, 16}};
        auto pa = mm->add_parameter("a", s);

        mm->add_instruction(migraphx::op::dot{20.0f, 50.0f}, pa, pa);

        return p;
    };

    // Expected quantized program. There is no multiply by a scale literal
    // here, only the shift.
    auto create_int8_quantized_prog = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::int32_type, {16, 16}};
        auto pa = mm->add_parameter("a", s);

        // add the shift
        auto fpa = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa);
        std::vector<float> vsa(s.elements(), 1.0f);
        auto sfta =
            mm->add_literal(migraphx::literal({migraphx::shape::float_type, s.lens()}, vsa));
        auto msa = mm->add_instruction(migraphx::op::add{}, sfta, fpa);
        auto ra  = mm->add_instruction(migraphx::op::round{}, msa);
        auto ca  = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa  = mm->add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);

        // int8 dot of the quantized input with itself, rescale in float, then
        // convert back to the original type
        auto q_dot = mm->add_instruction(migraphx::op::quant_dot{1, 0}, qa, qa);
        auto f_dot = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, q_dot);
        std::vector<float> v_alpha(f_dot->get_shape().elements(), 20.0f);
        auto new_alpha = mm->add_literal(migraphx::literal{f_dot->get_shape(), v_alpha});
        auto alpha_ab  = mm->add_instruction(migraphx::op::mul{}, new_alpha, f_dot);
        mm->add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, alpha_ab);

        return p;
    };

    auto p = create_program();
    // A single (scale, shift) pair for the one distinct dot input; held by
    // value rather than as a const reference bound to a temporary.
    const std::vector<std::pair<float, float>> quant_params{{1.0f, 1.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

630
631
// Verifies int8 quantization of a three-input dot whose operands are all
// int32: the program produced by quantize_int8_impl must equal a hand-built
// reference, both with and without an explicit return instruction.
TEST_CASE(dot_int32)
{
    // Unquantized program: dot{2.0f, 5.5f} over int32 parameters a, b and c.
    auto create_program = [](bool add_return = false) {
        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape a_shape{migraphx::shape::int32_type, {2, 16}};
        const migraphx::shape b_shape{migraphx::shape::int32_type, {16, 8}};
        const migraphx::shape c_shape{migraphx::shape::int32_type, {2, 8}};

        auto a = mod->add_parameter("a", a_shape);
        auto b = mod->add_parameter("b", b_shape);
        auto c = mod->add_parameter("c", c_shape);

        auto dot_out = mod->add_instruction(migraphx::op::dot{2.0f, 5.5f}, a, b, c);
        if(add_return)
        {
            mod->add_return({dot_out});
        }

        return prog;
    };

    // Reference program. The order in which instructions are created is part
    // of what is compared, so the statements below must not be reordered.
    auto create_int8_quantized_prog = [](bool add_return = false) {
        const auto float_t = migraphx::shape::float_type;
        const auto int8_t_ = migraphx::shape::int8_type;
        const auto int32_t_ = migraphx::shape::int32_type;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape a_shape{int32_t_, {2, 16}};
        const migraphx::shape b_shape{int32_t_, {16, 8}};
        const migraphx::shape c_shape{int32_t_, {2, 8}};

        auto a = mod->add_parameter("a", a_shape);
        auto b = mod->add_parameter("b", b_shape);
        auto c = mod->add_parameter("c", c_shape);

        // Operand a: convert to float and multiply by the 0.1 scale ...
        std::vector<float> a_scale_vals(a_shape.elements(), 0.1f);
        auto a_scale =
            mod->add_literal(migraphx::literal({float_t, a_shape.lens()}, a_scale_vals));
        auto a_float  = mod->add_instruction(migraphx::op::convert{float_t}, a);
        auto a_scaled = mod->add_instruction(migraphx::op::mul{}, a_scale, a_float);

        // ... then add the 1.0 shift, round, clip to [-128, 127] and convert to int8.
        std::vector<float> a_shift_vals(a_shape.elements(), 1.0f);
        auto a_shift =
            mod->add_literal(migraphx::literal({float_t, a_shape.lens()}, a_shift_vals));
        auto a_shifted = mod->add_instruction(migraphx::op::add{}, a_shift, a_scaled);
        auto a_rounded = mod->add_instruction(migraphx::op::round{}, a_shifted);
        auto a_clipped = create_clip_op(prog, 127.0f, -128.0f, a_rounded);
        auto a_int8    = mod->add_instruction(migraphx::op::convert{int8_t_}, a_clipped);

        // Operand b: same scale but no shift; its chain is inserted at the
        // position directly following parameter b.
        auto after_b = std::next(b);
        std::vector<float> b_scale_vals(b_shape.elements(), 0.1f);
        auto b_scale =
            mod->add_literal(migraphx::literal({float_t, b_shape.lens()}, b_scale_vals));
        auto b_float   = mod->insert_instruction(after_b, migraphx::op::convert{float_t}, b);
        auto b_scaled  = mod->insert_instruction(after_b, migraphx::op::mul{}, b_scale, b_float);
        auto b_rounded = mod->insert_instruction(after_b, migraphx::op::round{}, b_scaled);
        auto b_clipped = create_clip_op(after_b, prog, 127.0f, -128.0f, b_rounded);
        auto b_int8 = mod->insert_instruction(after_b, migraphx::op::convert{int8_t_}, b_clipped);

        // Integer dot, converted back to float and multiplied by the alpha literal.
        auto q_dot     = mod->add_instruction(migraphx::op::quant_dot{1, 0}, a_int8, b_int8);
        auto dot_float = mod->add_instruction(migraphx::op::convert{float_t}, q_dot);
        std::vector<float> alpha_vals(dot_float->get_shape().elements(), 20.0f);
        auto alpha = mod->add_literal(migraphx::literal(dot_float->get_shape(), alpha_vals));
        auto alpha_ab = mod->add_instruction(migraphx::op::mul{}, alpha, dot_float);

        // beta * c computed in float, summed with the scaled dot, then cast to int32.
        auto c_float = mod->add_instruction(migraphx::op::convert{float_t}, c);
        std::vector<float> beta_vals(c_float->get_shape().elements(), 5.5f);
        auto beta   = mod->add_literal(migraphx::literal(c_float->get_shape(), beta_vals));
        auto beta_c = mod->add_instruction(migraphx::op::mul{}, beta, c_float);
        auto sum    = mod->add_instruction(migraphx::op::add{}, alpha_ab, beta_c);
        auto result = mod->add_instruction(migraphx::op::convert{int32_t_}, sum);
        if(add_return)
        {
            mod->add_return({result});
        }

        return prog;
    };

    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};

    // Variant without an explicit return instruction.
    auto p = create_program();
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();
    EXPECT(p == qp);

    // Variant where the result is passed to add_return.
    auto p_ret = create_program(true);
    migraphx::quantize_int8_impl(p_ret, quant_params, {"dot"});
    auto qp_ret = create_int8_quantized_prog(true);
    EXPECT(p_ret == qp_ret);
}

719
720
721
722
// Verifies int8 quantization of a dot whose first operand is an int8
// parameter converted to float: after dead-code elimination the quantized
// program must equal a reference that feeds the int8 parameter directly into
// quant_dot.
TEST_CASE(dot_float_convert)
{
    // Unquantized program: a (int8) is converted to float and multiplied with b (float).
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape a_shape{migraphx::shape::int8_type, {2, 16}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {16, 8}};

        auto a = mod->add_parameter("a", a_shape);
        auto b = mod->add_parameter("b", b_shape);

        auto a_float = mod->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, a);
        mod->add_instruction(migraphx::op::dot{2.0f, 5.5f}, a_float, b);

        return prog;
    };

    // Reference program; instruction creation order is significant.
    auto create_int8_quantized_prog = [] {
        const auto float_t = migraphx::shape::float_type;
        const auto int8_t_ = migraphx::shape::int8_type;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape a_shape{int8_t_, {2, 16}};
        const migraphx::shape b_shape{float_t, {16, 8}};

        auto a = mod->add_parameter("a", a_shape);
        auto b = mod->add_parameter("b", b_shape);

        // Operand b: scale by 0.1, round, clip to [-128, 127] and convert to
        // int8, inserted at the position directly following parameter b.
        auto after_b = std::next(b);
        std::vector<float> b_scale_vals(b_shape.elements(), 0.1f);
        auto b_scale =
            mod->add_literal(migraphx::literal({float_t, b_shape.lens()}, b_scale_vals));
        auto b_scaled  = mod->insert_instruction(after_b, migraphx::op::mul{}, b_scale, b);
        auto b_rounded = mod->insert_instruction(after_b, migraphx::op::round{}, b_scaled);
        auto b_clipped = create_clip_op(after_b, prog, 127.0f, -128.0f, b_rounded);
        auto b_int8 = mod->insert_instruction(after_b, migraphx::op::convert{int8_t_}, b_clipped);

        // Parameter a is already int8, so it is used by quant_dot as-is.
        auto q_dot     = mod->add_instruction(migraphx::op::quant_dot{1, 0}, a, b_int8);
        auto dot_float = mod->add_instruction(migraphx::op::convert{float_t}, q_dot);
        std::vector<float> alpha_vals(dot_float->get_shape().elements(), 10.0f);
        auto alpha = mod->add_literal(migraphx::literal(dot_float->get_shape(), alpha_vals));
        mod->add_instruction(migraphx::op::mul{}, alpha, dot_float);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    // Dead-code elimination is run on the quantized program before comparing.
    migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

771
772
// Verifies int8 quantization of a float convolution: the program produced by
// quantize_int8_impl must equal a hand-built reference using quant_convolution.
TEST_CASE(conv_float)
{
    // Unquantized program: convolution of float parameters x and w.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape tensor_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto x = mod->add_parameter("x", tensor_shape);
        auto w = mod->add_parameter("w", tensor_shape);
        mod->add_instruction(migraphx::op::convolution{}, x, w);

        return prog;
    };

    // Reference program; instruction creation order is significant.
    auto create_int8_quantized_prog = [] {
        const auto float_t = migraphx::shape::float_type;
        const auto int8_t_ = migraphx::shape::int8_type;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape x_shape{float_t, {4, 3, 3, 3}};
        const migraphx::shape w_shape{float_t, {4, 3, 3, 3}};

        auto x = mod->add_parameter("x", x_shape);
        auto w = mod->add_parameter("w", w_shape);

        // Input x: scale by 0.1, round, clip to [-128, 127], convert to int8.
        std::vector<float> x_scale_vals(x_shape.elements(), 0.1f);
        auto x_scale   = mod->add_literal(migraphx::literal(x_shape, x_scale_vals));
        auto x_scaled  = mod->add_instruction(migraphx::op::mul{}, x_scale, x);
        auto x_rounded = mod->add_instruction(migraphx::op::round{}, x_scaled);
        auto x_clipped = create_clip_op(prog, 127.0f, -128.0f, x_rounded);
        auto x_int8    = mod->add_instruction(migraphx::op::convert{int8_t_}, x_clipped);

        // Weights w: same chain, inserted at the position directly following
        // parameter w.
        auto after_w = std::next(w);
        std::vector<float> w_scale_vals(w_shape.elements(), 0.1f);
        auto w_scale   = mod->add_literal(migraphx::literal(w_shape, w_scale_vals));
        auto w_scaled  = mod->insert_instruction(after_w, migraphx::op::mul{}, w_scale, w);
        auto w_rounded = mod->insert_instruction(after_w, migraphx::op::round{}, w_scaled);
        auto w_clipped = create_clip_op(after_w, prog, 127.0f, -128.0f, w_rounded);
        auto w_int8 = mod->insert_instruction(after_w, migraphx::op::convert{int8_t_}, w_clipped);

        // Integer convolution, converted to float and multiplied by the
        // adjustment literal.
        auto q_conv = mod->add_instruction(migraphx::op::quant_convolution{}, x_int8, w_int8);
        auto conv_float = mod->add_instruction(migraphx::op::convert{float_t}, q_conv);
        std::vector<float> adjust_vals(conv_float->get_shape().elements(), 100.0f);
        auto adjust = mod->add_literal(migraphx::literal(conv_float->get_shape(), adjust_vals));
        mod->add_instruction(migraphx::op::mul{}, adjust, conv_float);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// Verifies int8 quantization of an int32 convolution: the program produced by
// quantize_int8_impl must equal a hand-built reference using quant_convolution.
TEST_CASE(conv_int32)
{
    // Unquantized program: convolution of int32 parameters x and w.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape tensor_shape{migraphx::shape::int32_type, {4, 3, 3, 3}};
        auto x = mod->add_parameter("x", tensor_shape);
        auto w = mod->add_parameter("w", tensor_shape);
        mod->add_instruction(migraphx::op::convolution{}, x, w);

        return prog;
    };

    // Reference program; instruction creation order is significant.
    auto create_int8_quantized_prog = [] {
        const auto float_t = migraphx::shape::float_type;
        const auto int8_t_ = migraphx::shape::int8_type;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape x_shape{migraphx::shape::int32_type, {4, 3, 3, 3}};
        const migraphx::shape w_shape{migraphx::shape::int32_type, {4, 3, 3, 3}};

        auto x = mod->add_parameter("x", x_shape);
        auto w = mod->add_parameter("w", w_shape);

        // Input x: convert to float, scale by 0.1, round, clip to [-128, 127],
        // convert to int8.
        auto x_float = mod->add_instruction(migraphx::op::convert{float_t}, x);
        std::vector<float> x_scale_vals(x_shape.elements(), 0.1f);
        auto x_scale   = mod->add_literal(migraphx::literal(x_float->get_shape(), x_scale_vals));
        auto x_scaled  = mod->add_instruction(migraphx::op::mul{}, x_scale, x_float);
        auto x_rounded = mod->add_instruction(migraphx::op::round{}, x_scaled);
        auto x_clipped = create_clip_op(prog, 127.0f, -128.0f, x_rounded);
        auto x_int8    = mod->add_instruction(migraphx::op::convert{int8_t_}, x_clipped);

        // Weights w: same chain, inserted at the position directly following
        // parameter w.
        auto after_w = std::next(w);
        auto w_float = mod->insert_instruction(after_w, migraphx::op::convert{float_t}, w);
        std::vector<float> w_scale_vals(w_shape.elements(), 0.1f);
        auto w_scale   = mod->add_literal(migraphx::literal(w_float->get_shape(), w_scale_vals));
        auto w_scaled  = mod->insert_instruction(after_w, migraphx::op::mul{}, w_scale, w_float);
        auto w_rounded = mod->insert_instruction(after_w, migraphx::op::round{}, w_scaled);
        auto w_clipped = create_clip_op(after_w, prog, 127.0f, -128.0f, w_rounded);
        auto w_int8 = mod->insert_instruction(after_w, migraphx::op::convert{int8_t_}, w_clipped);

        // The adjustment literal takes the shape of the quant_convolution
        // output directly (no float conversion here), and the convolution
        // result is the first operand of the mul.
        auto q_conv = mod->add_instruction(migraphx::op::quant_convolution{}, x_int8, w_int8);
        std::vector<float> adjust_vals(q_conv->get_shape().elements(), 100.0f);
        auto adjust = mod->add_literal(migraphx::literal(q_conv->get_shape(), adjust_vals));
        mod->add_instruction(migraphx::op::mul{}, q_conv, adjust);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// Verifies int8 quantization of a half-precision convolution: the program
// produced by quantize_int8_impl must equal a hand-built reference whose
// result is converted back to half.
TEST_CASE(conv_half)
{
    // Unquantized program: convolution of half parameters x and w.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape tensor_shape{migraphx::shape::half_type, {4, 3, 3, 3}};
        auto x = mod->add_parameter("x", tensor_shape);
        auto w = mod->add_parameter("w", tensor_shape);
        mod->add_instruction(migraphx::op::convolution{}, x, w);

        return prog;
    };

    // Reference program; instruction creation order is significant.
    auto create_int8_quantized_prog = [] {
        const auto float_t = migraphx::shape::float_type;
        const auto half_t  = migraphx::shape::half_type;
        const auto int8_t_ = migraphx::shape::int8_type;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape x_shape{half_t, {4, 3, 3, 3}};
        const migraphx::shape w_shape{half_t, {4, 3, 3, 3}};

        auto x = mod->add_parameter("x", x_shape);
        auto w = mod->add_parameter("w", w_shape);

        // Input x: convert to float, scale by 0.1, round, clip to [-128, 127],
        // convert to int8.
        auto x_float = mod->add_instruction(migraphx::op::convert{float_t}, x);
        std::vector<float> x_scale_vals(x_shape.elements(), 0.1f);
        auto x_scale   = mod->add_literal(migraphx::literal(x_float->get_shape(), x_scale_vals));
        auto x_scaled  = mod->add_instruction(migraphx::op::mul{}, x_scale, x_float);
        auto x_rounded = mod->add_instruction(migraphx::op::round{}, x_scaled);
        auto x_clipped = create_clip_op(prog, 127.0f, -128.0f, x_rounded);
        auto x_int8    = mod->add_instruction(migraphx::op::convert{int8_t_}, x_clipped);

        // Weights w: same chain, inserted at the position directly following
        // parameter w.
        auto after_w = std::next(w);
        auto w_float = mod->insert_instruction(after_w, migraphx::op::convert{float_t}, w);
        std::vector<float> w_scale_vals(w_shape.elements(), 0.1f);
        auto w_scale   = mod->add_literal(migraphx::literal(w_float->get_shape(), w_scale_vals));
        auto w_scaled  = mod->insert_instruction(after_w, migraphx::op::mul{}, w_scale, w_float);
        auto w_rounded = mod->insert_instruction(after_w, migraphx::op::round{}, w_scaled);
        auto w_clipped = create_clip_op(after_w, prog, 127.0f, -128.0f, w_rounded);
        auto w_int8 = mod->insert_instruction(after_w, migraphx::op::convert{int8_t_}, w_clipped);

        // Integer convolution, converted to float, multiplied by the
        // adjustment literal and finally converted back to half.
        auto q_conv = mod->add_instruction(migraphx::op::quant_convolution{}, x_int8, w_int8);
        auto conv_float = mod->add_instruction(migraphx::op::convert{float_t}, q_conv);
        std::vector<float> adjust_vals(conv_float->get_shape().elements(), 100.0f);
        auto adjust = mod->add_literal(migraphx::literal(conv_float->get_shape(), adjust_vals));
        auto adjusted = mod->add_instruction(migraphx::op::mul{}, adjust, conv_float);
        mod->add_instruction(migraphx::op::convert{half_t}, adjusted);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

948
949
950
951
// Runs the same add program twice on the ref target, moving arguments through
// the target's copy_to / copy_from / allocate, and checks both runs agree.
TEST_CASE(target_copy)
{
    // Compiles `prog` for `tgt`, evaluates it and stores the last output in
    // `out`. Parameters present in `inputs` are copied to the target; every
    // other parameter is allocated by the target.
    auto run_prog = [](migraphx::program prog,
                       const migraphx::target& tgt,
                       migraphx::parameter_map& inputs,
                       std::vector<float>& out) {
        prog.compile(tgt);

        migraphx::parameter_map args;
        for(const auto& param : prog.get_parameter_shapes())
        {
            const auto& name = param.first;
            auto supplied    = inputs.find(name);
            if(supplied != inputs.end())
            {
                args[name] = tgt.copy_to(supplied->second);
            }
            else
            {
                args[name] = tgt.allocate(param.second);
            }
        }

        auto outputs = prog.eval(args);
        auto result  = tgt.copy_from(outputs.back());
        result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    };

    // Element-wise add of two 3x3 float parameters x and y.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape s{migraphx::shape::float_type, {3, 3}};
        auto x = mod->add_parameter("x", s);
        auto y = mod->add_parameter("y", s);
        mod->add_instruction(migraphx::op::add{}, x, y);

        return prog;
    };

    auto p = create_program();

    // Only "x" is supplied; "y" is left for run_prog to allocate.
    migraphx::parameter_map m;
    const migraphx::shape s{migraphx::shape::float_type, {3, 3}};
    m["x"] = migraphx::generate_argument(s);

    migraphx::target ref_t = migraphx::ref::target{};

    std::vector<float> ref_result;
    run_prog(p, ref_t, m, ref_result);

    std::vector<float> orig_result;
    run_prog(p, ref_t, m, orig_result);

    EXPECT(migraphx::verify_range(ref_result, orig_result));
}

999
// End-to-end check: a float dot program evaluated on the ref target after
// quantize_int8 must produce results that verify_range accepts against the
// unquantized run.
TEST_CASE(int8_quantization_dot)
{
    // Optionally quantizes `prog` (using `inputs` as the single calibration
    // sample), compiles it for `tgt`, evaluates it and stores the last output
    // in `out`. Parameters missing from `inputs` are allocated by the target.
    auto run_prog = [](migraphx::program prog,
                       const migraphx::target& tgt,
                       migraphx::parameter_map& inputs,
                       std::vector<float>& out,
                       bool b_quantize = false) {
        if(b_quantize)
        {
            std::vector<migraphx::parameter_map> cali_data(1, inputs);
            migraphx::quantize_int8(prog, tgt, cali_data);
        }
        prog.compile(tgt);

        migraphx::parameter_map args;
        for(const auto& param : prog.get_parameter_shapes())
        {
            const auto& name = param.first;
            auto supplied    = inputs.find(name);
            if(supplied != inputs.end())
            {
                args[name] = tgt.copy_to(supplied->second);
            }
            else
            {
                args[name] = tgt.allocate(param.second);
            }
        }

        auto outputs = prog.eval(args);
        auto result  = tgt.copy_from(outputs.back());
        result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    };

    // dot with default attributes over float parameters a, b and c.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape a_shape{migraphx::shape::float_type, {2, 16}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {16, 8}};
        const migraphx::shape c_shape{migraphx::shape::float_type, {2, 8}};

        auto a = mod->add_parameter("a", a_shape);
        auto b = mod->add_parameter("b", b_shape);
        auto c = mod->add_parameter("c", c_shape);
        mod->add_instruction(migraphx::op::dot{}, a, b, c);

        return prog;
    };

    auto p = create_program();

    // Only "a" and "c" are supplied; "b" is left for run_prog to allocate.
    migraphx::parameter_map m;
    const migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
    const migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
    m["a"] = migraphx::generate_argument(sa);
    m["c"] = migraphx::generate_argument(sc);

    migraphx::target ref_t = migraphx::ref::target{};

    std::vector<float> quant_result;
    run_prog(p, ref_t, m, quant_result, true);

    std::vector<float> no_quant_result;
    run_prog(p, ref_t, m, no_quant_result);

    EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}

1062
1063
1064
1065
1066
1067
1068
1069
// End-to-end check: a convolution over literal inputs evaluated on the ref
// target after quantize_int8 must produce results that verify_range accepts
// against the unquantized run.
TEST_CASE(int8_quantization_conv)
{
    // Optionally quantizes `prog` (with an empty calibration set), compiles it
    // for `tgt`, evaluates it with no parameters and stores the last output
    // in `out`.
    auto run_prog = [](migraphx::program prog,
                       const migraphx::target& tgt,
                       std::vector<float>& out,
                       bool b_quantize = false) {
        if(b_quantize)
        {
            std::vector<migraphx::parameter_map> cali_data;
            migraphx::quantize_int8(prog, tgt, cali_data);
        }
        prog.compile(tgt);

        // The program is built from literals only, so the argument map stays empty.
        migraphx::parameter_map args;

        auto outputs = prog.eval(args);
        auto result  = tgt.copy_from(outputs.back());
        result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    };

    // Convolution of two 4x2x2x2 float literals, every element set to 0.5.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape x_shape{migraphx::shape::float_type, {4, 2, 2, 2}};
        const migraphx::shape w_shape{migraphx::shape::float_type, {4, 2, 2, 2}};
        std::vector<float> values(x_shape.elements(), 0.5f);

        auto input   = mod->add_literal(migraphx::literal(x_shape, values));
        auto weights = mod->add_literal(migraphx::literal(w_shape, values));
        mod->add_instruction(migraphx::op::convolution{}, input, weights);

        return prog;
    };

    auto p = create_program();
    migraphx::target ref_t = migraphx::ref::target{};

    std::vector<float> quant_result;
    run_prog(p, ref_t, quant_result, true);

    std::vector<float> no_quant_result;
    run_prog(p, ref_t, no_quant_result);

    EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}

Shucai Xiao's avatar
Shucai Xiao committed
1106
// Test-binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}