quant_convolution.cpp 6.4 KB
Newer Older
1
2
3
4
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
5
6
7

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
8
namespace gpu {
9

10
/// Computes the output shape of the GPU quantized convolution.
///
/// The input list is validated with `check_shapes` (`has(5)` followed by
/// `standard()`); only the first two entries are forwarded to the wrapped
/// operator's `compute_shape`. The remaining three entries are consumed by
/// `compute()` (workspace, intermediate output, final output).
shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(5).standard();

    // Only the data and weight shapes take part in the shape computation.
    const std::vector<shape> conv_inputs{inputs.at(0), inputs.at(1)};
    return op.compute_shape(conv_inputs);
}
/// Executes the quantized convolution through MIOpen.
///
/// Argument slots, as they are used below:
///   args[0] - input tensor
///   args[1] - weight tensor
///   args[2] - workspace handed to miopenConvolutionForward
///   args[3] - intermediate output, described to MIOpen as float
///   args[4] - final output, filled by device::convert with int32 as target type
///
/// The input and weights are first transformed into the vec4-packed scratch
/// buffers `arg_vec4_x` / `arg_vec4_w`, which are assigned in `compile()`.
///
/// @return args[4]
/// @throws via MIGRAPHX_THROW when any MIOpen call does not return miopenStatusSuccess.
argument miopen_quant_convolution::compute(context& ctx,
                                           const shape& output_shape,
                                           const std::vector<argument>& args) const
{
    const argument& input     = args[0];
    const argument& weights   = args[1];
    const argument& workspace = args[2];
    const argument& float_out = args[3];
    const argument& int_out   = args[4];

    // Scaling factors shared by the tensor transforms and the convolution.
    const float alpha = 1;
    const float beta  = 0;

    auto miopen_handle = ctx.get_stream().get_miopen();

    // Plain and vec4-packed descriptors for the input and the weights.
    auto x_desc      = make_tensor(input.get_shape());
    auto x_desc_vec4 = make_tensor(input.get_shape(), true);
    auto w_desc      = make_tensor(weights.get_shape());
    auto w_desc_vec4 = make_tensor(weights.get_shape(), true);

    // MIOpen writes the convolution result as float; it is converted afterwards.
    const shape tmp_output_shape{shape::float_type, output_shape.lens()};
    auto y_desc = make_tensor(tmp_output_shape);

    // pack input to vec4 format
    const auto input_status = miopenTransformTensor(miopen_handle,
                                                    &alpha,
                                                    x_desc.get(),
                                                    input.implicit(),
                                                    &beta,
                                                    x_desc_vec4.get(),
                                                    arg_vec4_x.implicit());
    if(input_status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed");
    }

    // pack weights to vec4 format
    const auto weight_status = miopenTransformTensor(miopen_handle,
                                                     &alpha,
                                                     w_desc.get(),
                                                     weights.implicit(),
                                                     &beta,
                                                     w_desc_vec4.get(),
                                                     arg_vec4_w.implicit());
    if(weight_status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
    }

    // Run the convolution on the packed buffers with the algorithm chosen in compile().
    const auto conv_status = miopenConvolutionForward(miopen_handle,
                                                      &alpha,
                                                      x_desc_vec4.get(),
                                                      arg_vec4_x.implicit(),
                                                      w_desc_vec4.get(),
                                                      arg_vec4_w.implicit(),
                                                      cd.get(),
                                                      algo,
                                                      &beta,
                                                      y_desc.get(),
                                                      float_out.implicit(),
                                                      workspace.implicit(),
                                                      workspace.get_shape().bytes());
    if(conv_status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
    }

    // Convert the float result to int32_t (scale 1, shift 0) into the final output.
    device::convert(ctx.get_stream().get(), int_out, float_out, 1.0f, 0.0f, shape::int32_type);

    return int_out;
}

/// Selects the MIOpen forward algorithm for this convolution and reports the
/// workspace it needs.
///
/// Steps:
///   1. Build vec4-packed descriptors for inputs[0] (data) and inputs[1] (weights)
///      and a float descriptor for the output.
///   2. Query the workspace size MIOpen wants for the search.
///   3. Allocate the packed scratch buffers `arg_vec4_x` / `arg_vec4_w` (kept as
///      members and reused by `compute()`), plus a temporary output and workspace.
///   4. Run miopenFindConvolutionForwardAlgorithm requesting a single result,
///      with the exhaustive-search flag set to false.
///   5. Remember the MIOpen handle and the selected algorithm.
///
/// @param ctx          GPU context providing the MIOpen handle.
/// @param output_shape Shape of the operator output; only its lens are used here.
/// @param inputs       Input shapes; inputs[0] and inputs[1] must be int8
///                     (enforced by pack_int8_shape).
/// @return An int8 shape with `perf.memory` elements, i.e. the workspace size in
///         bytes reported for the selected algorithm.
/// @throws via MIGRAPHX_THROW when the workspace query or the algorithm search fails.
shape miopen_quant_convolution::compile(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    auto x_desc = make_tensor(inputs[0], true);
    auto w_desc = make_tensor(inputs[1], true);
    shape tmp_output_shape{shape::float_type, output_shape.lens()};
    auto y_desc = make_tensor(tmp_output_shape);

    // Note the MIOpen parameter order here: weight descriptor first, then input.
    std::size_t workspace_size = 0;
    auto status = miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
                                                           w_desc.get(),
                                                           x_desc.get(),
                                                           cd.get(),
                                                           y_desc.get(),
                                                           &workspace_size);
    // Previously this status was dropped, so a failed query silently continued
    // with a zero-sized workspace.
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: get workspace size failed");
    }
    const shape workspace_shape{shape::int8_type, {workspace_size}};

    // Scratch buffers in the padded/packed layout; filled by generate_argument and
    // moved to the device with to_gpu so the search has data to run on.
    arg_vec4_x     = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
    arg_vec4_w     = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
    auto y         = allocate_gpu(tmp_output_shape);
    auto workspace = allocate_gpu(workspace_shape);

    int algo_count = 1;
    // Value-initialise so no indeterminate fields are ever read from the result.
    miopenConvAlgoPerf_t perf{};
    status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
                                                   x_desc.get(),
                                                   arg_vec4_x.implicit(),
                                                   w_desc.get(),
                                                   arg_vec4_w.implicit(),
                                                   cd.get(),
                                                   y_desc.get(),
                                                   y.implicit(),
                                                   1,
                                                   &algo_count,
                                                   &perf,
                                                   workspace.implicit(),
                                                   workspace_size,
                                                   false);
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("QUANT_CONVOLUTION: find convolution failed");
    }

    // Record which handle this algorithm was tuned for; finalize() compares against it.
    handle = ctx.get_stream().get_miopen();
    algo   = perf.fwd_algo;
    return shape{shape::int8_type, {perf.memory}};
}
127

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/// Re-runs `compile()` when the context's MIOpen handle differs from the one
/// recorded by the last compilation, and verifies that the workspace already
/// described by inputs[2] is still large enough.
///
/// @throws via MIGRAPHX_THROW if the recompiled operator needs more workspace
///         bytes than inputs[2] provides.
void miopen_quant_convolution::finalize(context& ctx,
                                        const shape& output_shape,
                                        std::vector<shape> inputs)
{
    // Already compiled against this handle: nothing to do.
    if(handle == ctx.get_stream().get_miopen())
        return;

    // Read the available workspace size before `inputs` is moved into compile().
    const auto available_bytes = inputs.at(2).bytes();
    const auto required_ws     = compile(ctx, output_shape, std::move(inputs));

    // Check that workspace hasn't changed
    if(required_ws.bytes() > available_bytes)
        MIGRAPHX_THROW("Workspace has changed during finalization.");
}

/// Builds the shape of the vec4-packed counterpart of an int8 tensor.
///
/// Dimension 1 is rounded up to the next multiple of four and the stride of
/// dimension 0 is recomputed from it; all other lens and strides are copied
/// unchanged.
///
/// NOTE(review): indexes lens[1] / strides[1] unconditionally, so this assumes
/// a shape with at least two dimensions where axis 1 is the channel axis —
/// confirm against callers.
///
/// @throws via MIGRAPHX_THROW when `s` is not of int8 type.
shape miopen_quant_convolution::pack_int8_shape(shape& s)
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
    }

    constexpr std::size_t vec_width = 4;

    auto packed_lens    = s.lens();
    auto packed_strides = s.strides();

    // Round dimension 1 up to a whole number of 4-element groups.
    const auto group_count = (packed_lens[1] + (vec_width - 1)) / vec_width;
    packed_lens[1]         = group_count * vec_width;

    // The outermost stride must span the padded dimension 1.
    packed_strides[0] = packed_strides[1] * packed_lens[1];

    return shape{s.type(), packed_lens, packed_strides};
}

} // namespace gpu
157
158
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx