// pack_int8_args.cpp
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/gpu/hip.hpp>
#include <algorithm>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

/**
 * Gives every int8 `gpu::quant_gemm` instruction extra device buffers for its
 * A and/or B operands.
 *
 * For each instruction named "gpu::quant_gemm" whose first input has type
 * `shape::int8_type`:
 *  - if A (input 0) has a transposed shape, a `hip_allocate` of A's shape is
 *    inserted before the gemm and placed in argument slot 0; the original A
 *    is moved to the end of the argument list;
 *  - if B (input 1) does NOT have a transposed shape, the same is done for
 *    slot 1 with a `hip_allocate` of B's shape, and the original B is moved
 *    to the end.
 * When both branches fire, the argument list becomes
 * [alloc_a, alloc_b, <other original args>..., A, B].
 * The gemm's operator and output shape are preserved; only its arguments
 * change.
 *
 * NOTE(review): presumably the quant_gemm kernel packs A/B into these buffers
 * to meet the int8 GEMM layout requirements of the backend. TODO: confirm
 * against the gpu::quant_gemm implementation.
 *
 * @param p program that is rewritten in place.
 */
void pack_int8_args::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(ins->name() != "gpu::quant_gemm")
            continue;

        // Working copy of the arguments; written back by instruction::replace.
        auto inputs  = ins->inputs();
        auto shape_a = inputs.at(0)->get_shape();
        // Only int8 gemms are rewritten. The element type is checked on A only.
        if(shape_a.type() != shape::int8_type)
            continue;

        if(shape_a.transposed())
        {
            // Put the new buffer in slot 0 and move the original A to the end.
            auto pack_a = p.insert_instruction(ins, hip_allocate{shape_a});
            inputs.push_back(pack_a);
            swap(inputs.at(0), inputs.back());
        }

        // The A branch only swaps slot 0 with the newly appended element, so
        // slot 1 still holds the original B here.
        auto shape_b = inputs.at(1)->get_shape();
        if(!shape_b.transposed())
        {
            // Put the new buffer in slot 1 and move the original B to the end.
            auto pack_b = p.insert_instruction(ins, hip_allocate{shape_b});
            inputs.push_back(pack_b);
            swap(inputs.at(1), inputs.back());
        }
        instruction::replace(ins, ins->get_operator(), ins->get_shape(), inputs);
    }
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx