Commit 05031d74 authored by wsttiger's avatar wsttiger
Browse files

transpose and contiguous seem to be working on GPU with test

parent 04ca2e74
...@@ -6,20 +6,17 @@ if(NOT TARGET MIOpen) ...@@ -6,20 +6,17 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
add_library(migraph_device
kernels.cpp
)
rocm_clang_tidy_check(migraph_device)
target_link_libraries(migraph_device migraph hip::device)
target_include_directories(migraph_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
add_library(migraph_miopen add_library(migraph_miopen
hip.cpp hip.cpp
miopen_target.cpp miopen_target.cpp
) )
rocm_clang_tidy_check(migraph_miopen) rocm_clang_tidy_check(migraph_miopen)
target_link_libraries(migraph_miopen migraph MIOpen) target_link_libraries(migraph_miopen migraph MIOpen migraph_device)
target_include_directories(migraph_miopen PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraph_miopen PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
add_library(migraph_device
kernels.cu
)
rocm_clang_tidy_check(migraph_device)
target_link_libraries(migraph_device migraph hip::device)
target_include_directories(migraph_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
migraph::argument hip_contiguous(migraph::argument arg, migraph::shape output_shape); void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph::argument result);
} // namespace miopen } // namespace miopen
......
...@@ -42,22 +42,24 @@ __global__ void contiguous_gpu(const T* A, ...@@ -42,22 +42,24 @@ __global__ void contiguous_gpu(const T* A,
} }
} }
migraph::argument hip_contiguous(migraph::argument arg, migraph::shape output_shape) void hip_contiguous(migraph::shape output_shape, migraph::argument arg, migraph::argument result)
{ {
migraph::argument result{output_shape};
size_t ndim = output_shape.lens().size(); size_t ndim = output_shape.lens().size();
visit_all(result, arg)([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
if(ndim == 4) if(ndim == 4)
{ {
HIPTensorDescriptor<4> td_a, td_at; HIPTensorDescriptor<4> td_a, td_at;
auto s = arg.get_shape(); auto s = arg.get_shape();
for(int i = 0; i < output_shape.lens().size(); i++) for(int i = 0; i < ndim; i++)
{ {
td_a.strides[i] = s.strides().at(i); td_a.strides[i] = s.strides().at(i);
td_at.strides[i] = output_shape.strides().at(i); td_at.strides[i] = output_shape.strides().at(i);
} }
dim3 nblocks(512); dim3 nblocks(512);
dim3 nthreads(512); dim3 nthreads(512);
// std::cout << "nelements: " << s.elements() << std::endl;
// std::cout << "A ptr: " << input.data() << std::endl;
// std::cout << "At ptr: " << output.data() << std::endl;
hipLaunchKernelGGL((contiguous_gpu<int, 4>), hipLaunchKernelGGL((contiguous_gpu<int, 4>),
nblocks, nblocks,
nthreads, nthreads,
...@@ -74,7 +76,6 @@ migraph::argument hip_contiguous(migraph::argument arg, migraph::shape output_sh ...@@ -74,7 +76,6 @@ migraph::argument hip_contiguous(migraph::argument arg, migraph::shape output_sh
MIGRAPH_THROW("contiguous is only valid for 4D tensors"); MIGRAPH_THROW("contiguous is only valid for 4D tensors");
} }
}); });
return result;
} }
} // namespace miopen } // namespace miopen
} // namespace migraph } // namespace migraph
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraph/miopen/miopen.hpp> #include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/hip.hpp> #include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp> #include <migraph/dfor.hpp>
#include <migraph/miopen/kernels.hpp>
namespace migraph { namespace migraph {
namespace miopen { namespace miopen {
...@@ -200,13 +201,15 @@ struct miopen_contiguous ...@@ -200,13 +201,15 @@ struct miopen_contiguous
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; // argument result{output_shape};
visit_all(result, from_gpu(args[0]))([&](auto output, auto input) { // visit_all(result, from_gpu(args[0]))([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) { // shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()); // output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
}); // });
}); // });
return to_gpu(result); // return to_gpu(result);
hip_contiguous(output_shape, args.at(0), args.at(1));
return args.at(1);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment