// Tests for program quantization (quantize_fp16, quantize_int8_impl) and capture_arguments.
#include <iostream>
#include <iterator>
#include <numeric>
#include <utility>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/serialize.hpp>

#include "test.hpp"
#include <migraphx/half.hpp>

// Helpers that build a broadcast-bounds + clip sub-graph in a program's main module.
// Adds a `clip` of `input` to the main module of `p` using add_instruction.
//
// The scalar bounds are added as literals, each is expanded with
// `multibroadcast` to the dimensions of `input`, and the `clip` instruction
// receives its arguments in the order (input, min, max).
//
// p:     program whose main module receives the new instructions
// max:   upper clip bound
// min:   lower clip bound
// input: instruction whose output is clipped
// Returns the newly added `clip` instruction.
migraphx::instruction_ref
create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction_ref input)
{
    auto* mm        = p.get_main_module();
    auto input_lens = input->get_shape().lens();

    // Creation order (max literal, min literal, max broadcast, min broadcast, clip)
    // is deliberate: the tests compare whole programs, so it must not change.
    auto max_val = mm->add_literal(max);
    auto min_val = mm->add_literal(min);
    max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                                  max_val);
    min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                                  min_val);

    return mm->add_instruction(migraphx::make_op("clip"), input, min_val, max_val);
}

// Same as the four-argument create_clip_op, except that the two
// `multibroadcast` instructions and the `clip` are placed with
// insert_instruction at `insert_loc` instead of add_instruction.
// The bound literals are still created with add_literal.
//
// insert_loc: position passed to insert_instruction for every new instruction
// p:          program whose main module receives the new instructions
// max:        upper clip bound
// min:        lower clip bound
// input:      instruction whose output is clipped
// Returns the newly inserted `clip` instruction.
migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
                                         migraphx::program& p,
                                         float max,
                                         float min,
                                         migraphx::instruction_ref input)
{
    auto* mm        = p.get_main_module();
    auto input_lens = input->get_shape().lens();

    // Creation order is deliberate: the tests compare whole programs.
    auto max_val = mm->add_literal(max);
    auto min_val = mm->add_literal(min);
    max_val      = mm->insert_instruction(
        insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
    min_val = mm->insert_instruction(
        insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);

    return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val);
}

// quantize_fp16 tests: each builds a float program and the program expected after quantization.
// quantize_fp16 on a single `add` of two float parameters, with and without an
// explicit return instruction, using both the default instruction list and {"add"}.
TEST_CASE(param_add)
{
    // Builds a "convert" operator with the given target type.
    const auto convert_to = [](migraphx::shape::type_t type) {
        return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
    };

    // Input program: x + y in float, optionally followed by a return of the sum.
    auto create_program_float = [](bool add_return = false) {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto p2  = mm->add_parameter("y", s);
        auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2);
        if(add_return)
        {
            mm->add_return({sum});
        }

        return p;
    };

    // Expected program: each parameter is converted to half right after it is
    // declared, the add runs on the converted values and the result is
    // converted back to float.
    auto create_program_half = [&](bool add_return = false) {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        // Target type is written out (the original relied on the operator's
        // default), matching the expected programs in param_add_sub.
        auto hp1 = mm->insert_instruction(
            std::next(p1), convert_to(migraphx::shape::half_type), p1);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), convert_to(migraphx::shape::half_type), p2);
        auto hs  = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
        auto res = mm->add_instruction(convert_to(migraphx::shape::float_type), hs);
        if(add_return)
        {
            mm->add_return({res});
        }

        return p;
    };

    // Same four checks as before, in the same order:
    // (no return, default), (no return, {"add"}), (return, default), (return, {"add"}).
    for(bool add_return : {false, true})
    {
        {
            auto p1 = create_program_float(add_return);
            auto p2 = create_program_half(add_return);

            migraphx::quantize_fp16(p1);
            EXPECT(p1 == p2);
        }

        {
            auto p1 = create_program_float(add_return);
            auto p2 = create_program_half(add_return);

            migraphx::quantize_fp16(p1, {"add"});
            EXPECT(p1 == p2);
        }
    }
}

// quantize_fp16 on the chain (x + y) - y + x, quantizing only "add", only
// "sub", or everything, and comparing against hand-built expected programs.
TEST_CASE(param_add_sub)
{
    // Builds a "convert" operator with the given target type.
    const auto convert_to = [](migraphx::shape::type_t type) {
        return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
    };

    // Input program, entirely in float.
    auto create_program_float = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1   = mm->add_parameter("x", s);
        auto p2   = mm->add_parameter("y", s);
        auto sum  = mm->add_instruction(migraphx::make_op("add"), p1, p2);
        auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
        mm->add_instruction(migraphx::make_op("add"), diff, p1);

        return p;
    };

    // Expected when only "add" is quantized: both adds run in half, the sub stays in float.
    auto create_program_half_add = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto hp1 = mm->insert_instruction(
            std::next(p1), convert_to(migraphx::shape::half_type), p1);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), convert_to(migraphx::shape::half_type), p2);
        auto hsum  = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
        auto sum   = mm->add_instruction(convert_to(migraphx::shape::float_type), hsum);
        auto diff  = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
        auto hdiff = mm->add_instruction(convert_to(migraphx::shape::half_type), diff);
        auto res   = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
        mm->add_instruction(convert_to(migraphx::shape::float_type), res);

        return p;
    };

    // Expected when only "sub" is quantized: the sub runs in half, both adds stay in float.
    auto create_program_half_sub = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), convert_to(migraphx::shape::half_type), p2);
        auto sum   = mm->add_instruction(migraphx::make_op("add"), p1, p2);
        auto hsum  = mm->add_instruction(convert_to(migraphx::shape::half_type), sum);
        auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
        auto diff  = mm->add_instruction(convert_to(migraphx::shape::float_type), hdiff);
        mm->add_instruction(migraphx::make_op("add"), diff, p1);

        return p;
    };

    // Expected when everything is quantized: one convert per parameter and a
    // single convert back to float at the end.
    auto create_program_half_all = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        auto p1  = mm->add_parameter("x", s);
        auto hp1 = mm->insert_instruction(
            std::next(p1), convert_to(migraphx::shape::half_type), p1);
        auto p2  = mm->add_parameter("y", s);
        auto hp2 = mm->insert_instruction(
            std::next(p2), convert_to(migraphx::shape::half_type), p2);
        auto hsum  = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
        auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2);
        auto hres  = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
        mm->add_instruction(convert_to(migraphx::shape::float_type), hres);

        return p;
    };

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half_add();

        migraphx::quantize_fp16(p1, {"add"});
        EXPECT(p1 == p2);
    }

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half_sub();

        migraphx::quantize_fp16(p1, {"sub"});
        EXPECT(p1 == p2);
    }

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half_all();

        migraphx::quantize_fp16(p1);
        // Only the quantized program is cleaned up before the comparison.
        migraphx::run_passes(*p1.get_main_module(), {migraphx::dead_code_elimination{}});

        EXPECT(p1 == p2);
    }
}

// quantize_fp16 on an `add` of two float literals, compared (after constant
// propagation and dead-code elimination on both sides) against a program
// built directly from half literals.
TEST_CASE(literal_add)
{
    // Input program: add of two identical float literals holding 1..6.
    auto create_program_float = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
        std::vector<float> data(2 * 3);
        std::iota(data.begin(), data.end(), 1.0f);
        auto l1 = mm->add_literal(migraphx::literal(s, data));
        auto l2 = mm->add_literal(migraphx::literal(s, data));
        mm->add_instruction(migraphx::make_op("add"), l1, l2);

        return p;
    };

    // Expected program: the same literals stored as half, added in half and
    // converted back to float.
    auto create_program_half = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::half_type, {2, 3}};
        std::vector<migraphx::half> data(2 * 3);
        std::iota(data.begin(), data.end(), 1.0f);
        auto l1 = mm->add_literal(migraphx::literal(s, data));
        auto l2 = mm->add_literal(migraphx::literal(s, data));
        auto hs = mm->add_instruction(migraphx::make_op("add"), l1, l2);
        mm->add_instruction(
            migraphx::make_op("convert",
                              {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
            hs);

        return p;
    };

    // Runs constant propagation followed by dead-code elimination on the main module.
    const auto simplify = [](migraphx::program& prog) {
        migraphx::run_passes(*prog.get_main_module(),
                             {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
    };

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half();

        migraphx::quantize_fp16(p1, {"all"});
        simplify(p1);
        simplify(p2);

        EXPECT(p1 == p2);
    }

    {
        auto p1 = create_program_float();
        auto p2 = create_program_half();

        migraphx::quantize_fp16(p1, {"add"});
        simplify(p1);
        simplify(p2);

        EXPECT(p1 == p2);
    }
}

// capture_arguments test.
// capture_arguments with {"dot", "convolution"}: the result is compared with a
// hand-built program in which every input of each dot goes through a
// `capture` operator.
TEST_CASE(op_capture)
{
    // No-op callback stored in the expected capture operators.
    auto test_func = [](std::size_t, const std::vector<migraphx::argument>&) {};

    // Input program: (x + y) feeds two dots.
    auto create_program_float = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
        migraphx::shape s2{migraphx::shape::float_type, {3, 6}};

        auto p1 = mm->add_parameter("x", s1);
        auto p2 = mm->add_parameter("y", s1);
        auto pb = mm->add_parameter("b", s2);
        auto pc = mm->add_parameter("c", s2);
        auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
        auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
        mm->add_instruction(migraphx::make_op("dot"), pa, ps);

        return p;
    };

    // Expected program. Captures use indices 0 (add result), 1 (b), 2 (c) and
    // 3 (first dot result); the captures of b and c are inserted directly
    // after their parameters.
    auto create_program_op = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
        migraphx::shape s2{migraphx::shape::float_type, {3, 6}};

        auto p1  = mm->add_parameter("x", s1);
        auto p2  = mm->add_parameter("y", s1);
        auto pb  = mm->add_parameter("b", s2);
        auto pc  = mm->add_parameter("c", s2);
        auto pa  = mm->add_instruction(migraphx::make_op("add"), p1, p2);
        auto opb = mm->insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb);
        auto opc = mm->insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc);
        auto opa = mm->add_instruction(migraphx::op::capture{0, test_func}, pa);
        auto ps  = mm->add_instruction(migraphx::make_op("dot"), opa, opb, opc);
        auto ops = mm->add_instruction(migraphx::op::capture{3, test_func}, ps);
        mm->add_instruction(migraphx::make_op("dot"), opa, ops);

        return p;
    };

    auto p             = create_program_float();
    auto op_capture_p  = create_program_op();
    migraphx::target t = migraphx::ref::target{};
    migraphx::capture_arguments(p, t, {"dot", "convolution"});
    EXPECT(p == op_capture_p);
}

// quantize_int8_impl tests on dot: each compares against a hand-built quantized program.
// quantize_int8_impl on a float dot with three inputs (alpha 2.0, beta 1.5),
// followed by dead-code elimination, compared with a hand-built program.
TEST_CASE(dot_float)
{
    // Builds a "convert" operator with the given target type.
    const auto convert_to = [](migraphx::shape::type_t type) {
        return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
    };

    // Input program: dot(a, b, c) with alpha = 2.0 and beta = 1.5.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        mm->add_instruction(
            migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 1.5f}}), pa, pb, pc);

        return p;
    };

    // Expected program. Statement order matters: programs are compared as a whole.
    auto create_int8_quantized_prog = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fa = mm->add_literal(migraphx::literal(sa, vfa));
        auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, pa);
        auto ra = mm->add_instruction(migraphx::make_op("round"), ma);
        auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa = mm->add_instruction(convert_to(migraphx::shape::int8_type), ca);

        // quantize parameter b to int8 type; the chain is inserted right after b
        auto insert_loc = std::next(pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal(sb, vfb));
        auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
        auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(insert_loc, convert_to(migraphx::shape::int8_type), cb);

        auto qdot = mm->add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
        auto fdot = mm->add_instruction(convert_to(migraphx::shape::float_type), qdot);
        // NOTE(review): 200 looks like alpha / (scale_a * scale_b) = 2.0 / (0.1 * 0.1) -- confirm.
        std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
        auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
        auto alpha_ab  = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fdot);
        std::vector<float> v_beta(pc->get_shape().elements(), 1.5f);
        auto beta   = mm->add_literal(migraphx::literal(pc->get_shape(), v_beta));
        auto beta_c = mm->add_instruction(migraphx::make_op("mul"), beta, pc);
        mm->add_instruction(migraphx::make_op("add"), alpha_ab, beta_c);

        return p;
    };

    auto p = create_program();
    // Held by value; the original bound a const reference to a temporary.
    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});

    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// quantize_int8_impl on a double dot with two inputs (alpha 2.0, beta 1.5),
// compared with a hand-built program that converts the inputs to float first
// and the result back to double.
TEST_CASE(dot_double_2args)
{
    // Builds a "convert" operator with the given target type.
    const auto convert_to = [](migraphx::shape::type_t type) {
        return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
    };

    // Input program: dot(a, b) on double parameters.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);

        mm->add_instruction(migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 1.5f}}), pa, pb);

        return p;
    };

    // Expected program. Statement order matters: programs are compared as a whole.
    // (The unused shape `sc` from the original has been dropped.)
    auto create_int8_quantized_prog = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::double_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);

        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fpa = mm->add_instruction(convert_to(migraphx::shape::float_type), pa);
        auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
        auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, fpa);
        auto ra = mm->add_instruction(migraphx::make_op("round"), ma);
        auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa = mm->add_instruction(convert_to(migraphx::shape::int8_type), ca);

        // quantize parameter b to int8 type
        // insert_loc must be taken here, after the `a` chain has been added: it
        // is whatever instruction follows parameter b at this point.
        auto insert_loc = std::next(pb);
        auto fpb =
            mm->insert_instruction(insert_loc, convert_to(migraphx::shape::float_type), pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
        auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, fpb);
        auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(insert_loc, convert_to(migraphx::shape::int8_type), cb);

        auto qdot = mm->add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
        auto fdot = mm->add_instruction(convert_to(migraphx::shape::float_type), qdot);
        // NOTE(review): 200 looks like alpha / (scale_a * scale_b) = 2.0 / (0.1 * 0.1) -- confirm.
        std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
        auto new_alpha = mm->add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
        auto alpha_ab  = mm->add_instruction(migraphx::make_op("mul"), new_alpha, fdot);
        mm->add_instruction(convert_to(migraphx::shape::double_type), alpha_ab);

        return p;
    };

    auto p = create_program();
    // Held by value; the original bound a const reference to a temporary.
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// quantize_int8_impl on a float dot with large alpha/beta (20.0, 50.5): the
// expected program keeps alpha and beta inside quant_dot and converts c to int32.
TEST_CASE(dot_large_alpha_beta_float)
{
    // Builds a "convert" operator with the given target type.
    const auto convert_to = [](migraphx::shape::type_t type) {
        return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
    };

    // Input program: dot(a, b, c) with alpha = 20.0 and beta = 50.5.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        mm->add_instruction(
            migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.5f}}), pa, pb, pc);

        return p;
    };

    // Expected program. Statement order matters: programs are compared as a whole.
    auto create_int8_quantized_prog = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fa = mm->add_literal(migraphx::literal(sa, vfa));
        auto ma = mm->add_instruction(migraphx::make_op("mul"), fa, pa);

        // add the shift
        std::vector<float> vsa(sa.elements(), 1.0f);
        auto sfta = mm->add_literal(migraphx::literal(sa, vsa));
        auto msa  = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
        auto ra   = mm->add_instruction(migraphx::make_op("round"), msa);
        auto ca   = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa   = mm->add_instruction(convert_to(migraphx::shape::int8_type), ca);

        // quantize parameter b to int8 type; the chain is inserted right after b
        auto insert_loc = std::next(pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal(sb, vfb));
        auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, pb);
        auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(insert_loc, convert_to(migraphx::shape::int8_type), cb);

        // quantize parameter c to int32 type
        auto qc =
            mm->insert_instruction(std::next(pc), convert_to(migraphx::shape::int32_type), pc);

        // NOTE(review): 2000 looks like 20.0 / (0.1 * 0.1) and 51 like 50.5 rounded -- confirm.
        auto qdot = mm->add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 2000}, {"beta", 51}}), qa, qb, qc);
        mm->add_instruction(convert_to(migraphx::shape::float_type), qdot);

        return p;
    };

    auto p = create_program();
    // Held by value; the original bound a const reference to a temporary.
    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

// quantize_int8_impl on an int32 dot with large alpha/beta (20.0, 50.0): the
// expected program converts a and b to float before scaling, passes c to
// quant_dot unchanged and has no trailing convert.
TEST_CASE(dot_large_alpha_beta_int32)
{
    // Builds a "convert" operator with the given target type.
    const auto convert_to = [](migraphx::shape::type_t type) {
        return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
    };

    // Input program: dot(a, b, c) on int32 parameters with alpha = 20.0 and beta = 50.0.
    auto create_program = [] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        mm->add_instruction(
            migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.0f}}), pa, pb, pc);

        return p;
    };

    // Expected program. Statement order matters: programs are compared as a whole.
    auto create_int8_quantized_prog = [&] {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape sa{migraphx::shape::int32_type, {2, 16}};
        migraphx::shape sb{migraphx::shape::int32_type, {16, 8}};
        migraphx::shape sc{migraphx::shape::int32_type, {2, 8}};
        auto pa = mm->add_parameter("a", sa);
        auto pb = mm->add_parameter("b", sb);
        auto pc = mm->add_parameter("c", sc);

        // quantize parameter a to int8 type, multiply the scale
        std::vector<float> vfa(sa.elements(), 0.1f);
        auto fa = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
        auto conv_a = mm->add_instruction(convert_to(migraphx::shape::float_type), pa);
        auto ma     = mm->add_instruction(migraphx::make_op("mul"), fa, conv_a);

        // add the shift
        std::vector<float> vsa(sa.elements(), 1.0f);
        auto sfta =
            mm->add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
        auto msa = mm->add_instruction(migraphx::make_op("add"), sfta, ma);
        auto ra  = mm->add_instruction(migraphx::make_op("round"), msa);
        auto ca  = create_clip_op(p, 127.0f, -128.0f, ra);
        auto qa  = mm->add_instruction(convert_to(migraphx::shape::int8_type), ca);

        // quantize parameter b to int8 type; the chain is inserted right after b
        auto insert_loc = std::next(pb);
        std::vector<float> vfb(sb.elements(), 0.1f);
        auto fb = mm->add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
        auto conv_b =
            mm->insert_instruction(insert_loc, convert_to(migraphx::shape::float_type), pb);
        auto mb = mm->insert_instruction(insert_loc, migraphx::make_op("mul"), fb, conv_b);
        auto rb = mm->insert_instruction(insert_loc, migraphx::make_op("round"), mb);
        auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
        auto qb = mm->insert_instruction(insert_loc, convert_to(migraphx::shape::int8_type), cb);

        // NOTE(review): 2000 looks like 20.0 / (0.1 * 0.1) -- confirm.
        mm->add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 2000}, {"beta", 50}}), qa, qb, pc);

        return p;
    };

    auto p = create_program();
    // Held by value; the original bound a const reference to a temporary.
    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

687
688
689
690
TEST_CASE(dot_int32_one_arg)
{
    // Input program: a single int32 parameter used as both operands of a dot.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape in_shape{migraphx::shape::int32_type, {16, 16}};
        auto a = mod->add_parameter("a", in_shape);

        mod->add_instruction(migraphx::make_op("dot", {{"alpha", 20.0f}, {"beta", 50.0f}}), a, a);

        return prog;
    };

    // Program the quantization pass is expected to produce for the input above.
    auto create_int8_quantized_prog = [] {
        // Builds a "convert" operator targeting the given shape type.
        auto convert_to = [](auto type) {
            return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
        };

        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape in_shape{migraphx::shape::int32_type, {16, 16}};
        auto a = mod->add_parameter("a", in_shape);

        // int32 -> float, add the shift, round, clip to [-128, 127], then -> int8
        auto a_float = mod->add_instruction(convert_to(migraphx::shape::float_type), a);
        const std::vector<float> shift_vals(in_shape.elements(), 1.0f);
        auto shift = mod->add_literal(migraphx::literal(
            migraphx::shape{migraphx::shape::float_type, in_shape.lens()}, shift_vals));
        auto shifted = mod->add_instruction(migraphx::make_op("add"), shift, a_float);
        auto rounded = mod->add_instruction(migraphx::make_op("round"), shifted);
        auto clipped = create_clip_op(prog, 127.0f, -128.0f, rounded);
        auto a_int8  = mod->add_instruction(convert_to(migraphx::shape::int8_type), clipped);

        // int8 dot, converted to float and multiplied by a literal filled with 20
        auto int8_dot = mod->add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), a_int8, a_int8);
        auto dot_float = mod->add_instruction(convert_to(migraphx::shape::float_type), int8_dot);
        const std::vector<float> alpha_vals(dot_float->get_shape().elements(), 20.0f);
        auto alpha  = mod->add_literal(migraphx::literal{dot_float->get_shape(), alpha_vals});
        auto scaled = mod->add_instruction(migraphx::make_op("mul"), alpha, dot_float);

        // convert the result back to int32
        mod->add_instruction(convert_to(migraphx::shape::int32_type), scaled);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{1.0f, 1.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

747
748
TEST_CASE(dot_int32)
{
    // Input program: int32 dot over parameters a, b and c; optionally ends with
    // an explicit return of the dot result.
    auto create_program = [](bool add_return = false) {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_a{migraphx::shape::int32_type, {2, 16}};
        const migraphx::shape shape_b{migraphx::shape::int32_type, {16, 8}};
        const migraphx::shape shape_c{migraphx::shape::int32_type, {2, 8}};
        auto a = mod->add_parameter("a", shape_a);
        auto b = mod->add_parameter("b", shape_b);
        auto c = mod->add_parameter("c", shape_c);

        auto dot_res = mod->add_instruction(
            migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 5.5f}}), a, b, c);
        if(add_return)
        {
            mod->add_return({dot_res});
        }

        return prog;
    };

    // Program the quantization pass is expected to produce for the input above.
    auto create_int8_quantized_prog = [](bool add_return = false) {
        // Builds a "convert" operator targeting the given shape type.
        auto convert_to = [](auto type) {
            return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
        };

        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_a{migraphx::shape::int32_type, {2, 16}};
        const migraphx::shape shape_b{migraphx::shape::int32_type, {16, 8}};
        const migraphx::shape shape_c{migraphx::shape::int32_type, {2, 8}};
        auto a = mod->add_parameter("a", shape_a);
        auto b = mod->add_parameter("b", shape_b);
        auto c = mod->add_parameter("c", shape_c);

        // a: convert to float and multiply by the scale literal (0.1)
        const std::vector<float> scale_a_vals(shape_a.elements(), 0.1f);
        auto scale_a = mod->add_literal(migraphx::literal(
            migraphx::shape{migraphx::shape::float_type, shape_a.lens()}, scale_a_vals));
        auto a_float  = mod->add_instruction(convert_to(migraphx::shape::float_type), a);
        auto a_scaled = mod->add_instruction(migraphx::make_op("mul"), scale_a, a_float);

        // a: add the shift literal (1.0), round, clip to [-128, 127], convert to int8
        const std::vector<float> shift_a_vals(shape_a.elements(), 1.0f);
        auto shift_a = mod->add_literal(migraphx::literal(
            migraphx::shape{migraphx::shape::float_type, shape_a.lens()}, shift_a_vals));
        auto a_shifted = mod->add_instruction(migraphx::make_op("add"), shift_a, a_scaled);
        auto a_rounded = mod->add_instruction(migraphx::make_op("round"), a_shifted);
        auto a_clipped = create_clip_op(prog, 127.0f, -128.0f, a_rounded);
        auto a_int8    = mod->add_instruction(convert_to(migraphx::shape::int8_type), a_clipped);

        // b: same chain without a shift, inserted at the position right after b
        auto after_b = std::next(b);
        const std::vector<float> scale_b_vals(shape_b.elements(), 0.1f);
        auto scale_b = mod->add_literal(migraphx::literal(
            migraphx::shape{migraphx::shape::float_type, shape_b.lens()}, scale_b_vals));
        auto b_float =
            mod->insert_instruction(after_b, convert_to(migraphx::shape::float_type), b);
        auto b_scaled =
            mod->insert_instruction(after_b, migraphx::make_op("mul"), scale_b, b_float);
        auto b_rounded = mod->insert_instruction(after_b, migraphx::make_op("round"), b_scaled);
        auto b_clipped = create_clip_op(after_b, prog, 127.0f, -128.0f, b_rounded);
        auto b_int8 =
            mod->insert_instruction(after_b, convert_to(migraphx::shape::int8_type), b_clipped);

        // int8 dot, converted to float and multiplied by a literal filled with 20
        auto int8_dot = mod->add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), a_int8, b_int8);
        auto dot_float = mod->add_instruction(convert_to(migraphx::shape::float_type), int8_dot);
        const std::vector<float> alpha_vals(dot_float->get_shape().elements(), 20.0f);
        auto alpha     = mod->add_literal(migraphx::literal(dot_float->get_shape(), alpha_vals));
        auto alpha_dot = mod->add_instruction(migraphx::make_op("mul"), alpha, dot_float);

        // c: convert to float and multiply by a literal filled with 5.5
        auto c_float = mod->add_instruction(convert_to(migraphx::shape::float_type), c);
        const std::vector<float> beta_vals(c_float->get_shape().elements(), 5.5f);
        auto beta   = mod->add_literal(migraphx::literal(c_float->get_shape(), beta_vals));
        auto beta_c = mod->add_instruction(migraphx::make_op("mul"), beta, c_float);

        // sum both terms and convert back to int32
        auto sum = mod->add_instruction(migraphx::make_op("add"), alpha_dot, beta_c);
        auto out = mod->add_instruction(convert_to(migraphx::shape::int32_type), sum);
        if(add_return)
        {
            mod->add_return({out});
        }

        return prog;
    };

    const std::vector<std::pair<float, float>> quant_params{
        {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};

    // program without an explicit return instruction
    auto p = create_program();
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    auto qp = create_int8_quantized_prog();
    EXPECT(p == qp);

    // program ending in an explicit return instruction
    auto p_ret = create_program(true);
    migraphx::quantize_int8_impl(p_ret, quant_params, {"dot"});
    auto qp_ret = create_int8_quantized_prog(true);
    EXPECT(p_ret == qp_ret);
}

859
860
861
862
TEST_CASE(dot_float_convert)
{
    // Input program: int8 parameter a is converted to float and fed, together
    // with float parameter b, into a dot.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_a{migraphx::shape::int8_type, {2, 16}};
        const migraphx::shape shape_b{migraphx::shape::float_type, {16, 8}};
        auto a = mod->add_parameter("a", shape_a);
        auto b = mod->add_parameter("b", shape_b);

        auto a_float = mod->add_instruction(
            migraphx::make_op("convert",
                              {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
            a);
        mod->add_instruction(
            migraphx::make_op("dot", {{"alpha", 2.0f}, {"beta", 5.5f}}), a_float, b);

        return prog;
    };

    // Expected program after quantization followed by dead code elimination:
    // a feeds quant_dot directly, only b goes through a quantization chain.
    auto create_int8_quantized_prog = [] {
        // Builds a "convert" operator targeting the given shape type.
        auto convert_to = [](auto type) {
            return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
        };

        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_a{migraphx::shape::int8_type, {2, 16}};
        const migraphx::shape shape_b{migraphx::shape::float_type, {16, 8}};
        auto a = mod->add_parameter("a", shape_a);
        auto b = mod->add_parameter("b", shape_b);

        // b: scale by 0.1, round, clip to [-128, 127], convert to int8; all
        // inserted at the position right after b
        auto after_b = std::next(b);
        const std::vector<float> scale_b_vals(shape_b.elements(), 0.1f);
        auto scale_b = mod->add_literal(migraphx::literal(
            migraphx::shape{migraphx::shape::float_type, shape_b.lens()}, scale_b_vals));
        auto b_scaled  = mod->insert_instruction(after_b, migraphx::make_op("mul"), scale_b, b);
        auto b_rounded = mod->insert_instruction(after_b, migraphx::make_op("round"), b_scaled);
        auto b_clipped = create_clip_op(after_b, prog, 127.0f, -128.0f, b_rounded);
        auto b_int8 =
            mod->insert_instruction(after_b, convert_to(migraphx::shape::int8_type), b_clipped);

        // int8 dot, converted to float and multiplied by a literal filled with 10
        auto int8_dot = mod->add_instruction(
            migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), a, b_int8);
        auto dot_float = mod->add_instruction(convert_to(migraphx::shape::float_type), int8_dot);
        const std::vector<float> alpha_vals(dot_float->get_shape().elements(), 10.0f);
        auto alpha = mod->add_literal(migraphx::literal(dot_float->get_shape(), alpha_vals));
        mod->add_instruction(migraphx::make_op("mul"), alpha, dot_float);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"dot"});
    // run dead code elimination on the quantized program before comparing
    migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

921
922
TEST_CASE(conv_float)
{
    // Input program: float convolution of parameters x and w.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x =
            mod->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
        auto w =
            mod->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
        mod->add_instruction(migraphx::make_op("convolution"), x, w);

        return prog;
    };

    // Program the quantization pass is expected to produce for the input above.
    auto create_int8_quantized_prog = [] {
        // Builds a "convert" operator targeting the given shape type.
        auto convert_to = [](auto type) {
            return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
        };

        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_x{migraphx::shape::float_type, {4, 3, 3, 3}};
        const migraphx::shape shape_w{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto x = mod->add_parameter("x", shape_x);
        auto w = mod->add_parameter("w", shape_w);

        // x: scale by 0.1, round, clip to [-128, 127], convert to int8
        const std::vector<float> scale_x_vals(shape_x.elements(), 0.1f);
        auto scale_x   = mod->add_literal(migraphx::literal(shape_x, scale_x_vals));
        auto x_scaled  = mod->add_instruction(migraphx::make_op("mul"), scale_x, x);
        auto x_rounded = mod->add_instruction(migraphx::make_op("round"), x_scaled);
        auto x_clipped = create_clip_op(prog, 127.0f, -128.0f, x_rounded);
        auto x_int8    = mod->add_instruction(convert_to(migraphx::shape::int8_type), x_clipped);

        // w: same chain, inserted at the position right after w
        auto after_w = std::next(w);
        const std::vector<float> scale_w_vals(shape_w.elements(), 0.1f);
        auto scale_w   = mod->add_literal(migraphx::literal(shape_w, scale_w_vals));
        auto w_scaled  = mod->insert_instruction(after_w, migraphx::make_op("mul"), scale_w, w);
        auto w_rounded = mod->insert_instruction(after_w, migraphx::make_op("round"), w_scaled);
        auto w_clipped = create_clip_op(after_w, prog, 127.0f, -128.0f, w_rounded);
        auto w_int8 =
            mod->insert_instruction(after_w, convert_to(migraphx::shape::int8_type), w_clipped);

        // int8 convolution, converted to float and multiplied by a literal filled with 100
        auto int8_conv =
            mod->add_instruction(migraphx::make_op("quant_convolution"), x_int8, w_int8);
        auto conv_float =
            mod->add_instruction(convert_to(migraphx::shape::float_type), int8_conv);
        const std::vector<float> adjust_vals(conv_float->get_shape().elements(), 100.0f);
        auto adjust = mod->add_literal(migraphx::literal(conv_float->get_shape(), adjust_vals));
        mod->add_instruction(migraphx::make_op("mul"), adjust, conv_float);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

TEST_CASE(conv_int32)
{
    // Input program: int32 convolution of parameters x and w.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x =
            mod->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}});
        auto w =
            mod->add_parameter("w", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}});
        mod->add_instruction(migraphx::make_op("convolution"), x, w);

        return prog;
    };

    // Program the quantization pass is expected to produce for the input above.
    auto create_int8_quantized_prog = [] {
        // Builds a "convert" operator targeting the given shape type.
        auto convert_to = [](auto type) {
            return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
        };

        migraphx::program prog;

        auto* mod = prog.get_main_module();
        const migraphx::shape shape_x{migraphx::shape::int32_type, {4, 3, 3, 3}};
        const migraphx::shape shape_w{migraphx::shape::int32_type, {4, 3, 3, 3}};
        auto x = mod->add_parameter("x", shape_x);
        auto w = mod->add_parameter("w", shape_w);

        // x: convert to float, scale by 0.1, round, clip to [-128, 127], convert to int8
        auto x_float = mod->add_instruction(convert_to(migraphx::shape::float_type), x);
        const std::vector<float> scale_x_vals(shape_x.elements(), 0.1f);
        auto scale_x   = mod->add_literal(migraphx::literal(x_float->get_shape(), scale_x_vals));
        auto x_scaled  = mod->add_instruction(migraphx::make_op("mul"), scale_x, x_float);
        auto x_rounded = mod->add_instruction(migraphx::make_op("round"), x_scaled);
        auto x_clipped = create_clip_op(prog, 127.0f, -128.0f, x_rounded);
        auto x_int8    = mod->add_instruction(convert_to(migraphx::shape::int8_type), x_clipped);

        // w: same chain, inserted at the position right after w
        auto after_w = std::next(w);
        auto w_float =
            mod->insert_instruction(after_w, convert_to(migraphx::shape::float_type), w);
        const std::vector<float> scale_w_vals(shape_w.elements(), 0.1f);
        auto scale_w = mod->add_literal(migraphx::literal(w_float->get_shape(), scale_w_vals));
        auto w_scaled =
            mod->insert_instruction(after_w, migraphx::make_op("mul"), scale_w, w_float);
        auto w_rounded = mod->insert_instruction(after_w, migraphx::make_op("round"), w_scaled);
        auto w_clipped = create_clip_op(after_w, prog, 127.0f, -128.0f, w_rounded);
        auto w_int8 =
            mod->insert_instruction(after_w, convert_to(migraphx::shape::int8_type), w_clipped);

        // int8 convolution; its output is multiplied directly by a literal of
        // the same shape filled with 100 (no convert around the multiply here)
        auto int8_conv =
            mod->add_instruction(migraphx::make_op("quant_convolution"), x_int8, w_int8);
        const std::vector<float> adjust_vals(int8_conv->get_shape().elements(), 100.0f);
        auto adjust = mod->add_literal(migraphx::literal(int8_conv->get_shape(), adjust_vals));
        mod->add_instruction(migraphx::make_op("mul"), int8_conv, adjust);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

TEST_CASE(conv_half)
{
    // Input program: half-precision convolution of parameters x and w.
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        auto x =
            mod->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
        auto w =
            mod->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
        mod->add_instruction(migraphx::make_op("convolution"), x, w);

        return prog;
    };

    // Program the quantization pass is expected to produce for the input above.
    auto create_int8_quantized_prog = [] {
        // Builds a "convert" operator targeting the given shape type.
        auto convert_to = [](auto type) {
            return migraphx::make_op("convert", {{"target_type", migraphx::to_value(type)}});
        };

        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_x{migraphx::shape::half_type, {4, 3, 3, 3}};
        const migraphx::shape shape_w{migraphx::shape::half_type, {4, 3, 3, 3}};
        auto x = mod->add_parameter("x", shape_x);
        auto w = mod->add_parameter("w", shape_w);

        // x: convert to float, scale by 0.1, round, clip to [-128, 127], convert to int8
        auto x_float = mod->add_instruction(convert_to(migraphx::shape::float_type), x);
        const std::vector<float> scale_x_vals(shape_x.elements(), 0.1f);
        auto scale_x   = mod->add_literal(migraphx::literal(x_float->get_shape(), scale_x_vals));
        auto x_scaled  = mod->add_instruction(migraphx::make_op("mul"), scale_x, x_float);
        auto x_rounded = mod->add_instruction(migraphx::make_op("round"), x_scaled);
        auto x_clipped = create_clip_op(prog, 127.0f, -128.0f, x_rounded);
        auto x_int8    = mod->add_instruction(convert_to(migraphx::shape::int8_type), x_clipped);

        // w: same chain, inserted at the position right after w
        auto after_w = std::next(w);
        auto w_float =
            mod->insert_instruction(after_w, convert_to(migraphx::shape::float_type), w);
        const std::vector<float> scale_w_vals(shape_w.elements(), 0.1f);
        auto scale_w = mod->add_literal(migraphx::literal(w_float->get_shape(), scale_w_vals));
        auto w_scaled =
            mod->insert_instruction(after_w, migraphx::make_op("mul"), scale_w, w_float);
        auto w_rounded = mod->insert_instruction(after_w, migraphx::make_op("round"), w_scaled);
        auto w_clipped = create_clip_op(after_w, prog, 127.0f, -128.0f, w_rounded);
        auto w_int8 =
            mod->insert_instruction(after_w, convert_to(migraphx::shape::int8_type), w_clipped);

        // int8 convolution, converted to float and multiplied by a literal filled with 100
        auto int8_conv =
            mod->add_instruction(migraphx::make_op("quant_convolution"), x_int8, w_int8);
        auto conv_float =
            mod->add_instruction(convert_to(migraphx::shape::float_type), int8_conv);
        const std::vector<float> adjust_vals(conv_float->get_shape().elements(), 100.0f);
        auto adjust = mod->add_literal(migraphx::literal(conv_float->get_shape(), adjust_vals));
        auto scaled = mod->add_instruction(migraphx::make_op("mul"), adjust, conv_float);

        // convert the result back to half
        mod->add_instruction(convert_to(migraphx::shape::half_type), scaled);

        return prog;
    };

    auto p = create_program();
    const std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
    migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
    auto qp = create_int8_quantized_prog();

    EXPECT(p == qp);
}

1135
1136
1137
1138
TEST_CASE(target_copy)
{
    // Compiles a copy of `prog` for `tgt`, evaluates it and stores the last
    // output in `out` as floats. Parameters found in `inputs` are copied to the
    // target; every other parameter is allocated on the target from its shape.
    auto run_prog = [](migraphx::program prog,
                       const migraphx::target& tgt,
                       migraphx::parameter_map& inputs,
                       std::vector<float>& out) {
        prog.compile(tgt);
        migraphx::parameter_map params;
        for(auto&& entry : prog.get_parameter_shapes())
        {
            const auto& name = entry.first;
            auto found       = inputs.find(name);
            params[name] =
                (found != inputs.end()) ? tgt.copy_to(found->second) : tgt.allocate(entry.second);
        }

        auto last_output = tgt.copy_from(prog.eval(params).back());
        last_output.visit([&](auto view) { out.assign(view.begin(), view.end()); });
    };

    // x + y on two 3x3 float parameters
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape in_shape{migraphx::shape::float_type, {3, 3}};
        auto x = mod->add_parameter("x", in_shape);
        auto y = mod->add_parameter("y", in_shape);
        mod->add_instruction(migraphx::make_op("add"), x, y);

        return prog;
    };

    auto p = create_program();

    // only "x" is supplied; "y" takes the allocation path in run_prog
    migraphx::parameter_map inputs;
    const migraphx::shape in_shape{migraphx::shape::float_type, {3, 3}};
    inputs["x"] = migraphx::generate_argument(in_shape);

    migraphx::target ref_t = migraphx::ref::target{};

    std::vector<float> ref_result;
    run_prog(p, ref_t, inputs, ref_result);

    std::vector<float> orig_result;
    run_prog(p, ref_t, inputs, orig_result);

    EXPECT(migraphx::verify_range(ref_result, orig_result));
}

1186
TEST_CASE(int8_quantization_dot)
{
    // Optionally int8-quantizes a copy of `prog` (using `inputs` as the only
    // calibration sample), compiles it for `tgt`, evaluates it and stores the
    // last output in `out` as floats. Parameters found in `inputs` are copied
    // to the target; every other parameter is allocated from its shape.
    auto run_prog = [](migraphx::program prog,
                       const migraphx::target& tgt,
                       migraphx::parameter_map& inputs,
                       std::vector<float>& out,
                       bool b_quantize = false) {
        if(b_quantize)
        {
            std::vector<migraphx::parameter_map> calibration;
            calibration.push_back(inputs);
            migraphx::quantize_int8(prog, tgt, calibration);
        }
        prog.compile(tgt);

        migraphx::parameter_map params;
        for(auto&& entry : prog.get_parameter_shapes())
        {
            const auto& name = entry.first;
            auto found       = inputs.find(name);
            params[name] =
                (found != inputs.end()) ? tgt.copy_to(found->second) : tgt.allocate(entry.second);
        }

        auto last_output = tgt.copy_from(prog.eval(params).back());
        last_output.visit([&](auto view) { out.assign(view.begin(), view.end()); });
    };

    // float dot over parameters a, b and c with the operator's default attributes
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_a{migraphx::shape::float_type, {2, 16}};
        const migraphx::shape shape_b{migraphx::shape::float_type, {16, 8}};
        const migraphx::shape shape_c{migraphx::shape::float_type, {2, 8}};
        auto a = mod->add_parameter("a", shape_a);
        auto b = mod->add_parameter("b", shape_b);
        auto c = mod->add_parameter("c", shape_c);
        mod->add_instruction(migraphx::make_op("dot"), a, b, c);

        return prog;
    };

    auto p = create_program();

    // only "a" and "c" are supplied; "b" takes the allocation path in run_prog
    migraphx::parameter_map inputs;
    const migraphx::shape shape_a{migraphx::shape::float_type, {2, 16}};
    const migraphx::shape shape_c{migraphx::shape::float_type, {2, 8}};
    inputs["a"] = migraphx::generate_argument(shape_a);
    inputs["c"] = migraphx::generate_argument(shape_c);

    std::vector<float> quant_result;
    migraphx::target ref_t = migraphx::ref::target{};
    run_prog(p, ref_t, inputs, quant_result, true);

    std::vector<float> no_quant_result;
    run_prog(p, ref_t, inputs, no_quant_result);

    EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}

1249
1250
1251
1252
1253
1254
1255
1256
TEST_CASE(int8_quantization_conv)
{
    // Optionally int8-quantizes a copy of `prog` (with an empty calibration
    // set), compiles it for `tgt`, evaluates it with an empty parameter map and
    // stores the last output in `out` as floats.
    auto run_prog = [](migraphx::program prog,
                       const migraphx::target& tgt,
                       std::vector<float>& out,
                       bool b_quantize = false) {
        if(b_quantize)
        {
            std::vector<migraphx::parameter_map> calibration;
            migraphx::quantize_int8(prog, tgt, calibration);
        }
        prog.compile(tgt);

        migraphx::parameter_map params;
        auto last_output = tgt.copy_from(prog.eval(params).back());
        last_output.visit([&](auto view) { out.assign(view.begin(), view.end()); });
    };

    // convolution of two literals, every element set to 0.5
    auto create_program = [] {
        migraphx::program prog;
        auto* mod = prog.get_main_module();
        const migraphx::shape shape_x{migraphx::shape::float_type, {4, 2, 2, 2}};
        const migraphx::shape shape_w{migraphx::shape::float_type, {4, 2, 2, 2}};
        const std::vector<float> half_vals(shape_x.elements(), 0.5f);
        auto x = mod->add_literal(migraphx::literal(shape_x, half_vals));
        auto w = mod->add_literal(migraphx::literal(shape_w, half_vals));
        mod->add_instruction(migraphx::make_op("convolution"), x, w);

        return prog;
    };

    auto p = create_program();

    std::vector<float> quant_result;
    migraphx::target ref_t = migraphx::ref::target{};
    run_prog(p, ref_t, quant_result, true);

    std::vector<float> no_quant_result;
    run_prog(p, ref_t, no_quant_result);

    EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}

Shucai Xiao's avatar
Shucai Xiao committed
1293
// Test driver entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}