// hip_contiguous.cpp

#include <hip/hip_runtime.h>
#include <migraph/operators.hpp>

#include <string>
#include <type_traits>

namespace migraph {
namespace miopen {

/// Calls `f` with the runtime rank `n` lifted to a compile-time constant.
///
/// `f` receives `std::integral_constant<std::size_t, n>{}` for `n` in [0, 5],
/// so it can use the rank as a template argument. Any other `n` throws
/// `std::runtime_error("Unknown tensor size")` without calling `f`.
template <class F>
void visit_tensor_size(std::size_t n, F f)
{
    if(n == 0)
        f(std::integral_constant<std::size_t, 0>{});
    else if(n == 1)
        f(std::integral_constant<std::size_t, 1>{});
    else if(n == 2)
        f(std::integral_constant<std::size_t, 2>{});
    else if(n == 3)
        f(std::integral_constant<std::size_t, 3>{});
    else if(n == 4)
        f(std::integral_constant<std::size_t, 4>{});
    else if(n == 5)
        f(std::integral_constant<std::size_t, 5>{});
    else
        throw std::runtime_error("Unknown tensor size");
}

/// Fixed-rank multi-dimensional index holding one coordinate per dimension.
///
/// The accessors are `__device__ __host__` because they are used from device
/// code (via hip_tensor_descriptor::multi / linear inside contiguous_gpu);
/// HIP-Clang rejects calls from device code to host-only functions.
template <size_t NDim>
struct hip_index
{
    size_t d[NDim];
    __device__ __host__ size_t& operator[](size_t i) { return d[i]; }
    __device__ __host__ size_t operator[](size_t i) const { return d[i]; }
};

/// Rank-`NDim` tensor layout (per-dimension lengths and strides) stored in
/// fixed-size arrays so it can be passed by value to a kernel.
///
/// `multi` and `linear` are `__device__ __host__` because contiguous_gpu
/// calls them from device code. The converting constructor stays host-only;
/// it is used on the host to copy from shape containers.
template <size_t NDim>
struct hip_tensor_descriptor
{
    hip_tensor_descriptor() = default;

    /// Copies the first NDim entries of `lens_ext` into `lens` and of
    /// `strides_ext` into `strides`. Both arguments must support operator[]
    /// and hold at least NDim entries; this is not checked.
    template <typename T, typename V>
    hip_tensor_descriptor(const T& lens_ext, const V& strides_ext)
    {
        for(size_t i = 0; i < NDim; i++)
            lens[i] = lens_ext[i];
        for(size_t i = 0; i < NDim; i++)
            strides[i] = strides_ext[i];
    }

    /// Splits the offset `idx` into a multi-index by successive division and
    /// remainder with `strides[0]`, `strides[1]`, ...
    ///
    /// This inverts `linear` only when the strides are packed and decreasing
    /// (e.g. a standard row-major layout). Every stride must be non-zero, since
    /// each one is used as a divisor.
    __device__ __host__ hip_index<NDim> multi(size_t idx) const
    {
        hip_index<NDim> result{};
        size_t tidx = idx;
        for(size_t is = 0; is < NDim; is++)
        {
            result[is] = tidx / strides[is];
            tidx       = tidx % strides[is];
        }
        return result;
    }

    /// Returns the offset sum over i of s[i] * strides[i].
    __device__ __host__ size_t linear(hip_index<NDim> s) const
    {
        size_t idx = 0;
        for(size_t i = 0; i < NDim; i++)
            idx += s[i] * strides[i];
        return idx;
    }

    size_t lens[NDim]    = {};
    size_t strides[NDim] = {};
};
/// Copies tensor `a` (layout `a_desc`) into `at` (layout `at_desc`).
///
/// For every i in [0, nelements), the kernel sets
/// `at[i] = a[a_desc.linear(at_desc.multi(i))]`. This means `at_desc.multi`
/// must invert `at_desc`'s strides (packed, decreasing, non-zero strides).
///
/// Works with any 1-D grid/block shape: each thread visits i, i + S, i + 2S, ...
/// where S = blockDim.x * gridDim.x (a grid-stride loop). No shared memory is
/// used.
template <typename T, size_t NDim>
__global__ void contiguous_gpu(const T* a,
                               hip_tensor_descriptor<NDim> a_desc,
                               T* at,
                               hip_tensor_descriptor<NDim> at_desc,
                               size_t nelements)
{
    // Widen to size_t before multiplying. Otherwise these products are
    // evaluated in 32-bit unsigned arithmetic and can wrap for large grids.
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for(size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < nelements;
        i += stride)
    {
        hip_index<NDim> s = at_desc.multi(i);
        size_t lidx       = a_desc.linear(s);
        at[i]             = a[lidx];
    }
}

/// Copies `arg` into `result` on the GPU using contiguous_gpu.
///
/// The source layout is taken from `arg.get_shape()`; the destination layout
/// is `output_shape`. Only 4-D tensors are supported; any other rank throws.
///
/// The kernel is launched on the null stream (`nullptr`) and this function
/// does not synchronize, so errors raised while the kernel runs are not
/// observed here. Launch errors reported by hipGetLastError are thrown.
void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph::argument result)
{
    size_t ndim = output_shape.lens().size();
    visit_all(result, arg)([&](auto output, auto input) {
        if(ndim != 4)
            MIGRAPH_THROW("contiguous is only valid for 4D tensors");

        // Instantiate the kernel for the tensor's actual element type rather
        // than a fixed `int`. Otherwise non-int tensors are copied with the
        // wrong element size, or the launch fails to compile.
        using value_type = std::remove_cv_t<std::remove_pointer_t<decltype(output.data())>>;

        const auto& s = arg.get_shape();
        hip_tensor_descriptor<4> a_desc(s.lens(), s.strides());
        hip_tensor_descriptor<4> at_desc(output_shape.lens(), output_shape.strides());

        // Fixed launch size; the kernel's loop covers all elements regardless.
        dim3 nblocks(512);
        dim3 nthreads(512);
        hipLaunchKernelGGL((contiguous_gpu<value_type, 4>),
                           nblocks,
                           nthreads,
                           0,
                           nullptr,
                           input.data(),
                           a_desc,
                           output.data(),
                           at_desc,
                           s.elements());

        // Kernel launches do not return a status; query it explicitly.
        hipError_t status = hipGetLastError();
        if(status != hipSuccess)
            MIGRAPH_THROW("contiguous kernel launch failed: " +
                          std::string(hipGetErrorString(status)));
    });
}
} // namespace miopen
} // namespace migraph