int8_gemm_pack.cpp 2.95 KB
Newer Older
1
2
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
3
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
4
5
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
6
#include <migraphx/gpu/device/tensor.hpp>
7
8
9
10
11
12

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

13
void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg)
14
{
Shucai Xiao's avatar
Shucai Xiao committed
15
    auto comp_shape    = arg.get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
16
    auto out_lens      = comp_shape.lens();
Shucai Xiao's avatar
Shucai Xiao committed
17
18
    auto dim_0         = out_lens.size() - 2;
    auto dim_1         = out_lens.size() - 1;
Shucai Xiao's avatar
Shucai Xiao committed
19
    std::size_t lda    = comp_shape.strides()[dim_0];
20
    std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
Shucai Xiao's avatar
Shucai Xiao committed
21
    visit_all(result, arg)([&](auto output, auto input) {
Shucai Xiao's avatar
Shucai Xiao committed
22
        std::size_t nelements = comp_shape.elements();
Shucai Xiao's avatar
Shucai Xiao committed
23
24
        auto* out_ptr         = device_cast(output.data());
        auto* in_ptr          = device_cast(input.data());
25
        visit_tensor_size(out_lens.size(), [&](auto out_dim) {
Shucai Xiao's avatar
Shucai Xiao committed
26
            hip_tensor_descriptor<out_dim> desc(comp_shape);
Shucai Xiao's avatar
Shucai Xiao committed
27
            gs_launch(stream, nelements, 256)([=](auto ii) {
Shucai Xiao's avatar
Shucai Xiao committed
28
29
30
31
                const size_t nb    = 4;
                auto idx           = desc.multi(ii);
                std::size_t i_m    = idx[dim_1];
                std::size_t i_k    = idx[dim_0];
32
                std::size_t offset = ii / m_size * m_size;
Shucai Xiao's avatar
Shucai Xiao committed
33
34
                out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] =
                    in_ptr[i_m + i_k * lda + offset];
35
36
37
38
39
            });
        });
    });
}

40
void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
41
{
Shucai Xiao's avatar
Shucai Xiao committed
42
    auto trans_shape = arg.get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
43
44
45
46
    auto out_lens    = trans_shape.lens();
    auto dim_0       = trans_shape.lens().size() - 2;
    auto dim_1       = trans_shape.lens().size() - 1;
    std::size_t ldb  = trans_shape.strides()[dim_1];
47
48
49

    auto wrap_lens = out_lens;
    std::swap(wrap_lens[dim_0], wrap_lens[dim_1]);
Shucai Xiao's avatar
Shucai Xiao committed
50
    shape comp_shape{trans_shape.type(), wrap_lens};
51
    std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
Shucai Xiao's avatar
Shucai Xiao committed
52
    visit_all(result, arg)([&](auto output, auto input) {
Shucai Xiao's avatar
Shucai Xiao committed
53
        std::size_t nelements = comp_shape.elements();
Shucai Xiao's avatar
Shucai Xiao committed
54
55
        auto* out_ptr         = device_cast(output.data());
        auto* in_ptr          = device_cast(input.data());
56
        visit_tensor_size(out_lens.size(), [&](auto out_dim) {
Shucai Xiao's avatar
Shucai Xiao committed
57
            hip_tensor_descriptor<out_dim> desc(comp_shape);
Shucai Xiao's avatar
Shucai Xiao committed
58
            gs_launch(stream, nelements, 256)([=](auto ii) {
Shucai Xiao's avatar
Shucai Xiao committed
59
60
                const size_t nb    = 4;
                auto idx           = desc.multi(ii);
61
62
                std::size_t i_n    = idx[dim_1];
                std::size_t i_k    = idx[dim_0];
63
                std::size_t offset = ii / m_size * m_size;
Shucai Xiao's avatar
Shucai Xiao committed
64
65
                out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
                    in_ptr[i_n + i_k * ldb + offset];
66
67
68
69
70
71
72
73
74
            });
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx