kernels.cpp 2.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

#include <hip/hip_runtime.h>
#include <migraph/operators.hpp>

namespace migraph {
namespace miopen {

/// Fixed-rank tensor layout that can be passed by value as a kernel argument.
///
/// NDIM is the number of dimensions. Both arrays are indexed by dimension.
/// The members are zero-initialized by default so that a descriptor whose
/// fields are only partially filled in never carries indeterminate values
/// when it is copied to the device.
template <int NDIM>
struct HIPTensorDescriptor
{
    // Extent of each dimension.
    size_t lens[NDIM] = {};
    // Element stride of each dimension.
    size_t strides[NDIM] = {};
};

/// Splits a flat offset into one coordinate per dimension.
///
/// Walking the dimensions in order, result[d] receives the quotient of the
/// remaining offset by strides[d], and the remainder is carried on to the
/// next dimension.
///
/// @param strides  Per-dimension strides; only read. Every entry must be
///                 non-zero because it is used as a divisor.
/// @param idx      Flat offset to decompose.
/// @param result   Output array with room for at least NDIM entries.
template <int NDIM>
__host__ __device__ void multiindex(const size_t (&strides)[NDIM], size_t idx, size_t* result)
{
    size_t remaining = idx;
    // The index has the same type as NDIM, avoiding a signed/unsigned comparison.
    for(int dim = 0; dim < NDIM; dim++)
    {
        result[dim] = remaining / strides[dim];
        remaining   = remaining % strides[dim];
    }
}

/// Gathers elements of A into At.
///
/// For every output offset i in [0, nelements), i is decomposed into
/// per-dimension coordinates using td_at.strides; those coordinates are then
/// combined with td_a.strides to find the source offset in A.
///
/// Launch layout: one-dimensional grid and block (only the .x components are
/// used). The loop is grid-stride, so any grid/block size covers all
/// nelements elements.
///
/// @param A          Source buffer, read at offsets computed from td_a.strides.
/// @param td_a       Descriptor of the source; only strides are read here.
/// @param At         Destination buffer, written at offsets 0..nelements-1.
/// @param td_at      Descriptor of the destination; only strides are read
///                   here and every stride must be non-zero (used as divisor).
/// @param nelements  Number of elements to write.
template <typename T, int NDIM>
__global__ void contiguous_gpu(const T* A,
                               HIPTensorDescriptor<NDIM> td_a,
                               T* At,
                               HIPTensorDescriptor<NDIM> td_at,
                               size_t nelements)
{
    // Widen before multiplying: the built-in index components are 32-bit, so
    // the products could wrap before being converted to size_t.
    const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t step  = static_cast<size_t>(blockDim.x) * gridDim.x;

    for(size_t i = first; i < nelements; i += step)
    {
        // Coordinates of output element i.
        size_t s[NDIM];
        multiindex<NDIM>(td_at.strides, i, s);

        // Offset of the same coordinates in the source layout.
        size_t lidx = 0;
        for(int j = 0; j < NDIM; j++)
            lidx += s[j] * td_a.strides[j];

        At[i] = A[lidx];
    }
}

45
void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph::argument result)
46
47
48
49
50
51
52
{
    size_t ndim = output_shape.lens().size();
    visit_all(result, arg)([&](auto output, auto input) {
        if(ndim == 4)
        {
            HIPTensorDescriptor<4> td_a, td_at;
            auto s = arg.get_shape();
53
            for(int i = 0; i < ndim; i++)
54
55
56
57
58
59
            {
                td_a.strides[i]  = s.strides().at(i);
                td_at.strides[i] = output_shape.strides().at(i);
            }
            dim3 nblocks(512);
            dim3 nthreads(512);
60
61
62
            // std::cout << "nelements: " << s.elements() << std::endl;
            // std::cout << "A ptr: " << input.data() << std::endl;
            // std::cout << "At ptr: " << output.data() << std::endl;
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
            hipLaunchKernelGGL((contiguous_gpu<int, 4>),
                               nblocks,
                               nthreads,
                               0,
                               0,
                               input.data(),
                               td_a,
                               output.data(),
                               td_at,
                               s.elements());
        }
        else
        {
            MIGRAPH_THROW("contiguous is only valid for 4D tensors");
        }
    });
}
} // namespace miopen
} // namespace migraph