kernels.cpp 2.29 KB
Newer Older
1
2
3
4
5
6
7
8

#include <hip/hip_runtime.h>
#include <migraph/operators.hpp>

#include <type_traits>

namespace migraph {
namespace miopen {

// Plain-old-data description of an NDIM-dimensional tensor layout, small
// enough to be passed by value as a kernel argument.
//   lens[d]    - extent of dimension d
//   strides[d] - step (in elements of the tensor's type) between consecutive
//                indices along dimension d
// NOTE(review): dimension 0 is presumably the outermost one (row-major
// ordering, as produced by migraph::shape) — confirm against callers.
template <int NDIM>
struct hip_tensor_descriptor
{
    size_t lens[NDIM];    // per-dimension extent
    size_t strides[NDIM]; // per-dimension element stride
};

// Splits the flat offset `idx` into one coordinate per dimension by repeated
// division, starting with strides[0]. Writes NDIM coordinates into `result`.
// The decomposition is exact when `strides` are packed and decreasing
// (e.g. {12, 4, 1} for lens {.., 3, 4}); every stride must be non-zero.
template <int NDIM>
__host__ __device__ void multiindex(size_t (&strides)[NDIM], size_t idx, size_t* result)
{
    size_t remaining = idx;
    for(int dim = 0; dim < NDIM; ++dim)
    {
        const size_t step  = strides[dim];
        const size_t coord = remaining / step;
        result[dim]        = coord;
        // Same as remaining % step for unsigned integers.
        remaining -= coord * step;
    }
}

// Gathers `nelements` values from the strided tensor `a` into `at`.
// For every flat output index i, the coordinates are recovered from
// at_desc.strides via multiindex, then mapped through a_desc.strides to find
// the source element a[lidx].
//
// Launch: any 1-D grid/block configuration. A grid-stride loop covers all
// outputs, so correctness does not depend on the launch size. No shared
// memory is used.
// Preconditions: `at` is written at flat index i, so at_desc.strides must be
// the packed, decreasing strides matching that flat order (this is what makes
// the multiindex decomposition exact). `a` and `at` must both be device
// pointers large enough for the offsets reached.
template <typename T, int NDIM>
__global__ void contiguous_gpu(const T* a,
                               hip_tensor_descriptor<NDIM> a_desc,
                               T* at,
                               hip_tensor_descriptor<NDIM> at_desc,
                               size_t nelements)
{
    // Widen before multiplying: blockIdx.x * blockDim.x and
    // blockDim.x * gridDim.x are otherwise computed in 32-bit unsigned
    // arithmetic and can wrap on very large launches.
    const size_t first  = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

    for(size_t i = first; i < nelements; i += stride)
    {
        size_t s[NDIM];
        multiindex<NDIM>(at_desc.strides, i, s);

        // Source offset = dot(coordinates, source strides).
        size_t lidx = 0;
        for(int j = 0; j < NDIM; j++)
            lidx += s[j] * a_desc.strides[j];
        at[i] = a[lidx];
    }
}

// Copies `arg` into `result` so that `result` is laid out according to
// `output_shape`. Each element of `arg` is located through arg's own shape
// strides.
//
// output_shape  Shape of `result`. Its rank selects the kernel
//               instantiation, and its strides drive the index decomposition.
// arg           Source tensor; its shape's lens/strides describe where each
//               element lives.
// result        Destination tensor, written through output.data().
//
// Only rank-4 tensors are supported; any other rank throws via MIGRAPH_THROW.
// The kernel is launched on the null stream (nullptr) and is not synchronized
// here. Launch-time errors reported by hipGetLastError() are turned into an
// exception.
void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph::argument result)
{
    size_t ndim = output_shape.lens().size();
    visit_all(result, arg)([&](auto output, auto input) {
        constexpr int supported_ndim = 4;
        if(ndim == supported_ndim)
        {
            // Instantiate the kernel for the element type actually being
            // visited, so the kernel's T* parameters match input.data() and
            // output.data() for every tensor type, not just one.
            using value_type =
                std::remove_cv_t<std::remove_pointer_t<decltype(output.data())>>;

            hip_tensor_descriptor<supported_ndim> a_desc{};
            hip_tensor_descriptor<supported_ndim> at_desc{};
            const auto& s = arg.get_shape();
            for(size_t i = 0; i < ndim; i++)
            {
                a_desc.lens[i]     = s.lens().at(i);
                a_desc.strides[i]  = s.strides().at(i);
                at_desc.lens[i]    = output_shape.lens().at(i);
                at_desc.strides[i] = output_shape.strides().at(i);
            }

            // Fixed launch size. The kernel loops with a stride of
            // blockDim.x * gridDim.x, so every element is visited regardless.
            dim3 nblocks(512);
            dim3 nthreads(512);
            hipLaunchKernelGGL((contiguous_gpu<value_type, supported_ndim>),
                               nblocks,
                               nthreads,
                               0,
                               nullptr,
                               input.data(),
                               a_desc,
                               output.data(),
                               at_desc,
                               s.elements());

            // Kernel launches are asynchronous. This catches launch-time
            // failures (e.g. invalid configuration), not faults during
            // execution.
            if(hipGetLastError() != hipSuccess)
                MIGRAPH_THROW("contiguous kernel launch failed");
        }
        else
        {
            MIGRAPH_THROW("contiguous is only valid for 4D tensors");
        }
    });
}
} // namespace miopen
} // namespace migraph