"include/vscode:/vscode.git/clone" did not exist on "fd7eee0d552e8e33d0d10a6425cfa2105bf03654"
Commit dc0c4810 authored by wsttiger's avatar wsttiger
Browse files

Renamed kernels.cpp -> hip_contiguous.cpp and fixed up hip_tensor_descriptor

parent 0aeeb4bb
...@@ -7,7 +7,7 @@ if(NOT TARGET MIOpen) ...@@ -7,7 +7,7 @@ if(NOT TARGET MIOpen)
endif() endif()
add_library(migraph_device add_library(migraph_device
kernels.cpp hip_contiguous.cpp
) )
rocm_clang_tidy_check(migraph_device) rocm_clang_tidy_check(migraph_device)
target_link_libraries(migraph_device migraph hip::device) target_link_libraries(migraph_device migraph hip::device)
......
...@@ -5,38 +5,86 @@ ...@@ -5,38 +5,86 @@
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
// Bridge a runtime dimension count to a compile-time constant: invokes f
// with std::integral_constant<std::size_t, n> for n in [0, 5].
//
// f: callable accepting a std::integral_constant<std::size_t, N> argument.
// Throws std::runtime_error for any n outside the supported range.
template <class F>
void visit_tensor_size(std::size_t n, F f)
{
    switch(n)
    {
    case 0: f(std::integral_constant<std::size_t, 0>{}); break;
    case 1: f(std::integral_constant<std::size_t, 1>{}); break;
    case 2: f(std::integral_constant<std::size_t, 2>{}); break;
    case 3: f(std::integral_constant<std::size_t, 3>{}); break;
    case 4: f(std::integral_constant<std::size_t, 4>{}); break;
    case 5: f(std::integral_constant<std::size_t, 5>{}); break;
    default: throw std::runtime_error("Unknown tensor size");
    }
}
// Trivially-copyable tensor shape descriptor (dimension lengths + strides)
// so it can be passed by value to a HIP kernel.
template <size_t NDim>
struct hip_tensor_descriptor
{
    hip_tensor_descriptor() = default;
    // Copies the first NDim entries of lens_ and strides_ element-wise.
    // Both arguments only need operator[]; they must hold at least NDim
    // entries (unchecked).
    template <typename T, typename V>
    hip_tensor_descriptor(const T& lens_, const V& strides_)
    {
        for(size_t i = 0; i < NDim; i++)
        {
            lens[i]    = lens_[i];
            strides[i] = strides_[i];
        }
    }
    size_t lens[NDim];
    size_t strides[NDim];
};
// Decomposes a flat element index into an NDim-dimensional coordinate.
//
// strides: per-dimension strides; the repeated divide/modulo only recovers a
//          valid coordinate when these are packed row-major strides
//          (decreasing, innermost == 1) — as produced for a contiguous
//          output shape.
// idx:     flat element index to decompose.
// result:  must point to at least NDim writable elements; receives the
//          coordinate along each dimension.
//
// Fix over previous revision: strides is taken by const reference — the
// array is only read, and const-correctness lets callers pass const data.
template <size_t NDim>
__host__ __device__ void multiindex(const size_t (&strides)[NDim], size_t idx, size_t* result)
{
    size_t tidx = idx;
    for(size_t is = 0; is < NDim; is++)
    {
        result[is] = tidx / strides[is]; // coordinate along dimension `is`
        tidx       = tidx % strides[is]; // remainder handled by inner dims
    }
}
// Kernel: materializes a contiguous copy of tensor `a` into `at`.
//
// a / a_desc:   source data and its descriptor (arbitrary strides).
// at / at_desc: destination data and its descriptor. The divide/modulo in
//               multiindex assumes at_desc.strides are the packed strides of
//               the contiguous output — confirm against the host-side caller.
// nelements:    total number of elements to copy.
//
// Uses a grid-stride loop over a 1-D launch, so any grid/block configuration
// covers all nelements; each output element is written by exactly one thread.
template <typename T, size_t NDim>
__global__ void contiguous_gpu(const T* a,
                               hip_tensor_descriptor<NDim> a_desc,
                               T* at,
                               hip_tensor_descriptor<NDim> at_desc,
                               size_t nelements)
{
    for(size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < nelements;
        i += blockDim.x * gridDim.x)
    {
        // Recover the multi-dimensional coordinate of output element i from
        // the output strides...
        size_t s[NDim];
        multiindex<NDim>(at_desc.strides, i, s);
        // ...then map that coordinate through the input strides to locate
        // the corresponding source element.
        size_t lidx = 0;
        for(size_t j = 0; j < NDim; j++)
            lidx += s[j] * a_desc.strides[j];
        at[i] = a[lidx];
    }
}
...@@ -48,14 +96,9 @@ void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph: ...@@ -48,14 +96,9 @@ void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph:
visit_all(result, arg)([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
if(ndim == 4) if(ndim == 4)
{ {
hip_tensor_descriptor<4> a_desc{};
hip_tensor_descriptor<4> at_desc{};
const auto& s = arg.get_shape(); const auto& s = arg.get_shape();
for(int i = 0; i < ndim; i++) hip_tensor_descriptor<4> a_desc(s.lens(), s.strides());
{ hip_tensor_descriptor<4> at_desc(output_shape.lens(), output_shape.strides());
a_desc.strides[i] = s.strides().at(i);
at_desc.strides[i] = output_shape.strides().at(i);
}
dim3 nblocks(512); dim3 nblocks(512);
dim3 nthreads(512); dim3 nthreads(512);
hipLaunchKernelGGL((contiguous_gpu<int, 4>), hipLaunchKernelGGL((contiguous_gpu<int, 4>),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment