#include <hip/hip_runtime.h> // dim3, hipLaunchKernelGGL, __global__/__host__/__device__
// NOTE: a second include directive was present here but its header name was lost;
// it presumably declared migraph::shape, migraph::argument, visit_all and MIGRAPH_THROW.

namespace migraph {
namespace miopen {

// Plain-old-data tensor descriptor that can be passed by value to a HIP kernel.
template <size_t NDIM>
struct HIPTensorDescriptor
{
    size_t lens[NDIM];
    size_t strides[NDIM];
};

// Convert a linear element index into a multi-dimensional index, given the
// strides of a packed tensor (strides in decreasing order).
template <size_t NDIM>
__host__ __device__ void multiindex(size_t (&strides)[NDIM], size_t idx, size_t* result)
{
    size_t tidx = idx;
    for(size_t is = 0; is < NDIM; is++)
    {
        result[is] = tidx / strides[is];
        tidx       = tidx % strides[is];
    }
}

// Grid-stride loop: for each output element, recover its multi-index from the
// packed output strides, then gather from the input through the input strides.
template <typename T, size_t NDIM>
__global__ void contiguous_gpu(const T* A,
                               HIPTensorDescriptor<NDIM> td_a,
                               T* At,
                               HIPTensorDescriptor<NDIM> td_at,
                               size_t nelements)
{
    for(size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < nelements;
        i += blockDim.x * gridDim.x)
    {
        size_t s[NDIM];
        multiindex(td_at.strides, i, s);
        size_t lidx = 0;
        for(size_t j = 0; j < NDIM; j++)
            lidx += s[j] * td_a.strides[j];
        At[i] = A[lidx];
    }
}

void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph::argument result)
{
    size_t ndim = output_shape.lens().size();
    visit_all(result, arg)([&](auto output, auto input) {
        if(ndim == 4)
        {
            HIPTensorDescriptor<4> td_a, td_at;
            auto s = arg.get_shape();
            for(size_t i = 0; i < ndim; i++)
            {
                // Only the strides are used by the kernel; lens stays unset.
                td_a.strides[i]  = s.strides().at(i);
                td_at.strides[i] = output_shape.strides().at(i);
            }
            dim3 nblocks(512);
            dim3 nthreads(512);
            hipLaunchKernelGGL((contiguous_gpu),
                               nblocks,
                               nthreads,
                               0,
                               0,
                               input.data(),
                               td_a,
                               output.data(),
                               td_at,
                               s.elements());
        }
        else
        {
            MIGRAPH_THROW("contiguous is only valid for 4D tensors");
        }
    });
}

} // namespace miopen
} // namespace migraph
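
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file): the host-side loop below
// mirrors what contiguous_gpu does per element, to illustrate the stride
// arithmetic. It views a 2x3 row-major matrix through transposed strides
// {1, 3} and gathers it into a packed 3x2 buffer with strides {2, 1}.
// Everything here (the HIP_CONTIGUOUS_EXAMPLE guard, the sample data) is
// hypothetical; compile with -DHIP_CONTIGUOUS_EXAMPLE to try it standalone.
#ifdef HIP_CONTIGUOUS_EXAMPLE
#include <cassert>
#include <vector>

int main()
{
    using namespace migraph::miopen;

    // data holds the 2x3 matrix [[0,1,2],[3,4,5]] row-major; through strides
    // {1, 3} it reads as the 3x2 transpose.
    std::vector<size_t> data = {0, 1, 2, 3, 4, 5};

    HIPTensorDescriptor<2> view;   // non-packed input view (the transpose)
    view.strides[0] = 1;
    view.strides[1] = 3;

    HIPTensorDescriptor<2> packed; // packed 3x2 output
    packed.strides[0] = 2;
    packed.strides[1] = 1;

    std::vector<size_t> out(data.size());
    for(size_t i = 0; i < out.size(); i++)
    {
        size_t s[2];
        multiindex(packed.strides, i, s);      // multi-index of output element i
        out[i] = data[s[0] * view.strides[0] + // gather through the view's strides
                      s[1] * view.strides[1]];
    }

    // The packed result is the transpose laid out row-major.
    std::vector<size_t> expected = {0, 3, 1, 4, 2, 5};
    assert(out == expected);
    return 0;
}
#endif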