"docs/en_US/QuickStart.md" did not exist on "a656bba5161b32080d2dc71c2ff331f34e183485"
Commit 702412b1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine contiguous gpu implementation

parent 562724bf
#include <migraphx/gpu/device/contiguous.hpp> #include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/nary.hpp> #include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_fp16.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -21,11 +22,31 @@ void contiguous_nonstandard(hipStream_t stream, const argument& result, const ar ...@@ -21,11 +22,31 @@ void contiguous_nonstandard(hipStream_t stream, const argument& result, const ar
/// Copies every element of `arg` into `result` in linear order, using
/// `nelements = result.get_shape().elements()` as the element count.
///
/// For `shape::half_type` the copy is done two elements at a time through
/// `__half2`; all other types are copied one element per work item.
///
/// @param stream  HIP stream the copy kernel is launched on.
/// @param result  Destination buffer; its shape determines the element count and type.
/// @param arg     Source buffer.
///
/// NOTE(review): the linear `output[i] = input[i]` copy presumably relies on both
/// buffers sharing the same packed layout -- confirm against the caller.
void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
{
    index_int nelements = result.get_shape().elements();
    auto type           = result.get_shape().type();
    if (type == shape::half_type)
    {
        visit_all(result, arg)([&](auto output_v, auto input_v) {
            const auto* input = device_cast(input_v.data());
            auto* output      = device_cast(output_v.data());
            // Reinterpret the buffers as packed pairs of halves. The cast targets a
            // const-qualified type so it stays well-formed whether or not data()
            // returns a pointer-to-const.
            // NOTE(review): __half2 access presumably needs both pointers to be
            // 4-byte aligned -- confirm this holds for offset/sub-buffer arguments.
            const __half2* input2 = reinterpret_cast<const __half2*>(input_v.data());
            __half2* output2      = reinterpret_cast<__half2*>(output_v.data());

            // One work item per full pair, plus one extra work item for the odd
            // trailing element. Sizing the launch this way (instead of piggybacking
            // the tail on work item 0 of an `nelements / 2` launch) keeps the tail
            // copy alive when nelements == 1, where there are no full pairs at all.
            const index_int npairs = nelements / 2;
            const index_int nitems = npairs + (nelements % 2);
            gs_launch(stream, nitems)([=](auto i) __device__ {
                if (i < npairs)
                {
                    output2[i] = input2[i];
                }
                else
                {
                    // Only reachable when nelements is odd: copy the last element.
                    output[nelements - 1] = input[nelements - 1];
                }
            });
        });
    }
    else
    {
        visit_all(result, arg)([&](auto output_v, auto input_v) {
            const auto* input = device_cast(input_v.data());
            auto* output      = device_cast(output_v.data());
            // Plain element-wise copy, one element per work item.
            gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = input[i]; });
        });
    }
}
void contiguous(hipStream_t stream, const argument& result, const argument& arg) void contiguous(hipStream_t stream, const argument& result, const argument& arg)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment