contiguous.cpp 3.43 KB
Newer Older
1
2

#include <hip/hip_runtime.h>
3
#include <migraph/gpu/device/contiguous.hpp>
4
5

namespace migraph {
Paul's avatar
Paul committed
6
namespace gpu {
7
namespace device {
8

Paul's avatar
Paul committed
9
10
11
12
13
14
15
// Position of the calling device thread, handed to kernel bodies by
// `launcher`:
//   global - flattened thread id across the grid (blockIdx.x * blockDim.x + threadIdx.x)
//   local  - thread id within its block (threadIdx.x)
//   group  - block id within the grid (blockIdx.x)
// Member order matters: `launcher` aggregate-initializes this as
// {global, local, group}.
struct index
{
    std::size_t global, local, group;
};

Paul's avatar
Paul committed
16
template <class F>
Paul's avatar
Paul committed
17
18
19
20
21
22
23
24
__global__ void launcher(F f)
{
    index idx{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x};
    f(idx);
}

auto launch(std::size_t global, std::size_t local)
{
Paul's avatar
Paul committed
25
    return [=](auto f) {
Paul's avatar
Paul committed
26
27
28
29
30
        assert(local > 0);
        assert(global > 0);
        using f_type = decltype(f);
        dim3 nblocks(global / local);
        dim3 nthreads(local);
Paul's avatar
Paul committed
31
        hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, nullptr, f);
Paul's avatar
Paul committed
32
33
34
    };
}

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Dispatches a runtime rank `n` to a compile-time constant: calls `f` with
// std::integral_constant<std::size_t, n>{} for n in 1..5, so `f` can use the
// rank as a template argument. Any return value of `f` is discarded.
//
// Throws std::runtime_error("Unknown tensor size") for any other `n`.
template <class F>
void visit_tensor_size(std::size_t n, F f)
{
    if(n == 1)
        f(std::integral_constant<std::size_t, 1>{});
    else if(n == 2)
        f(std::integral_constant<std::size_t, 2>{});
    else if(n == 3)
        f(std::integral_constant<std::size_t, 3>{});
    else if(n == 4)
        f(std::integral_constant<std::size_t, 4>{});
    else if(n == 5)
        f(std::integral_constant<std::size_t, 5>{});
    else
        throw std::runtime_error("Unknown tensor size");
}

wsttiger's avatar
wsttiger committed
69
70
71
72
// Fixed-size multi-dimensional index holding NDim coordinates, usable from
// both host and device code. Kept an aggregate so it can be brace-initialized
// (`hip_index<NDim> r{};` zero-fills every coordinate).
template <size_t NDim>
struct hip_index
{
    size_t d[NDim];

    // Read-only access to coordinate `pos`; not bounds-checked.
    __device__ __host__ size_t operator[](size_t pos) const { return d[pos]; }

    // Writable access to coordinate `pos`; not bounds-checked.
    __device__ __host__ size_t& operator[](size_t pos) { return d[pos]; }
};

77
78
79
template <size_t NDim>
struct hip_tensor_descriptor
{
Paul's avatar
Paul committed
80
    __device__ __host__ hip_tensor_descriptor() = default;
81
    template <typename T, typename V>
Paul's avatar
Paul committed
82
    __device__ __host__ hip_tensor_descriptor(const T& lens_ext, const V& strides_ext)
83
84
    {
        for(size_t i = 0; i < NDim; i++)
wsttiger's avatar
wsttiger committed
85
            lens[i] = lens_ext[i];
86
        for(size_t i = 0; i < NDim; i++)
wsttiger's avatar
wsttiger committed
87
            strides[i] = strides_ext[i];
88
    }
Paul's avatar
Paul committed
89
    __device__ __host__ hip_index<NDim> multi(size_t idx)
90
    {
wsttiger's avatar
wsttiger committed
91
92
93
94
95
96
97
98
        hip_index<NDim> result{};
        size_t tidx = idx;
        for(size_t is = 0; is < NDim; is++)
        {
            result[is] = tidx / strides[is];
            tidx       = tidx % strides[is];
        }
        return result;
99
    }
Paul's avatar
Paul committed
100
    __device__ __host__ size_t linear(hip_index<NDim> s)
wsttiger's avatar
wsttiger committed
101
102
103
104
105
106
107
108
109
    {
        size_t idx = 0;
        for(size_t i = 0; i < NDim; i++)
            idx += s[i] * strides[i];
        return idx;
    }
    size_t lens[NDim]    = {};
    size_t strides[NDim] = {};
};
110

111
// Copies `arg` into `result` on the device.
//
// For every linear position i in [0, arg.get_shape().elements()), the kernel
// turns i into a multi-index with `output_shape`'s strides, then into a
// source offset with `arg`'s own strides, and stores result[i] = arg[offset].
// NOTE(review): this reorders correctly only when `output_shape` is a packed
// (standard) layout, because hip_tensor_descriptor::multi assumes one. The
// element counts of `arg` and `output_shape` are assumed to match.
//
// Throws std::runtime_error (via visit_tensor_size) for ranks outside 1..5.
void contiguous(shape output_shape, argument arg, argument result)
{
    visit_all(result, arg)([&](auto output, auto input) {
        visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
            const auto& s = arg.get_shape();
            // Both descriptors use the output's rank; a mismatch would read
            // past the end of s.lens()/s.strides().
            assert(s.lens().size() == output_shape.lens().size());
            hip_tensor_descriptor<ndim> a_desc(s.lens(), s.strides());
            hip_tensor_descriptor<ndim> at_desc(output_shape.lens(), output_shape.strides());
            const auto* a               = input.data();
            auto* at                    = output.data();
            const std::size_t nelements = s.elements();
            // Nothing to copy; also avoids a zero-sized launch.
            if(nelements == 0)
                return;

            // Size the grid to the tensor, capped at max_groups blocks. nglobal
            // stays a multiple of nlocal, so the threads launched are exactly
            // the nglobal threads the loop stride below assumes.
            const std::size_t nlocal     = 512;
            const std::size_t max_groups = 512;
            std::size_t ngroups          = (nelements + nlocal - 1) / nlocal;
            if(ngroups > max_groups)
                ngroups = max_groups;
            const std::size_t nglobal = ngroups * nlocal;

            // `mutable` lets the captured descriptors call multi()/linear()
            // even if those are not const-qualified; nothing else is modified.
            launch(nglobal, nlocal)([=](auto idx) mutable {
                // Each thread handles i = idx.global, idx.global + nglobal, ...
                for(std::size_t i = idx.global; i < nelements; i += nglobal)
                {
                    const std::size_t lidx = a_desc.linear(at_desc.multi(i));
                    at[i]                  = a[lidx];
                }
            });
        });
    });
}
134
} // namespace device
Paul's avatar
Paul committed
135
} // namespace gpu
136
} // namespace migraph