Unverified Commit ba07b221 authored by Aaron Enye Shi's avatar Aaron Enye Shi Committed by GitHub
Browse files

Fix HIP-Clang GPU build issues (#384)



* Fix HIP-Clang GPU build issues

Add missing device attributes for GPU functions. GPU functions must be annotated with __device__ in HIP.

* Use HIP device function max and min

* Fix clang-format-5.0 issues

* Undo change that breaks on HIP-HCC
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent a023ec19
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void acos(hipStream_t stream, const argument& result, const argument& arg) void acos(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::acos(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::acos(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -8,7 +8,7 @@ namespace device { ...@@ -8,7 +8,7 @@ namespace device {
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x + y; }); nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
} }
void add(hipStream_t stream, void add(hipStream_t stream,
...@@ -17,7 +17,8 @@ void add(hipStream_t stream, ...@@ -17,7 +17,8 @@ void add(hipStream_t stream,
const argument& arg2, const argument& arg2,
const argument& arg3) const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; }); nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z)
__device__ { return x + y + z; });
} }
} // namespace device } // namespace device
......
...@@ -13,8 +13,8 @@ void add_clip(hipStream_t stream, ...@@ -13,8 +13,8 @@ void add_clip(hipStream_t stream,
const float max, const float max,
const float min) const float min)
{ {
nary(stream, result, arg1, arg2)([max, min](auto x, auto y) { nary(stream, result, arg1, arg2)([max, min](auto x, auto y) __device__ {
return std::min<decltype(x + y)>(std::max<decltype(x)>(min, x + y), max); return ::min<decltype(x + y)>(::max<decltype(x)>(min, x + y), max);
}); });
} }
...@@ -26,8 +26,8 @@ void add_clip(hipStream_t stream, ...@@ -26,8 +26,8 @@ void add_clip(hipStream_t stream,
const float max, const float max,
const float min) const float min)
{ {
nary(stream, result, arg1, arg2, arg3)([max, min](auto x, auto y, auto z) { nary(stream, result, arg1, arg2, arg3)([max, min](auto x, auto y, auto z) __device__ {
return std::min<decltype(x + y + z)>(std::max<decltype(x)>(min, x + y + z), max); return ::min<decltype(x + y + z)>(::max<decltype(x)>(min, x + y + z), max);
}); });
} }
......
...@@ -11,8 +11,8 @@ void add_relu(hipStream_t stream, ...@@ -11,8 +11,8 @@ void add_relu(hipStream_t stream,
const argument& arg1, const argument& arg1,
const argument& arg2) const argument& arg2)
{ {
nary(stream, result, arg1, arg2)( nary(stream, result, arg1, arg2)([](auto x, auto y)
[](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); }); __device__ { return ::max<decltype(x + y)>(0, x + y); });
} }
void add_relu(hipStream_t stream, void add_relu(hipStream_t stream,
...@@ -22,7 +22,7 @@ void add_relu(hipStream_t stream, ...@@ -22,7 +22,7 @@ void add_relu(hipStream_t stream,
const argument& arg3) const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)( nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto y, auto z) { return std::max<decltype(x + y + z)>(0, x + y + z); }); [](auto x, auto y, auto z) __device__ { return ::max<decltype(x + y + z)>(0, x + y + z); });
} }
} // namespace device } // namespace device
......
...@@ -12,7 +12,7 @@ void add_sigmoid(hipStream_t stream, ...@@ -12,7 +12,7 @@ void add_sigmoid(hipStream_t stream,
const argument& arg2) const argument& arg2)
{ {
nary(stream, result, arg1, arg2)( nary(stream, result, arg1, arg2)(
[](auto x, auto y) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y)))); }); [](auto x, auto y) __device__ { return 1.f / (1.f + ::exp(to_hip_type(-(x + y)))); });
} }
void add_sigmoid(hipStream_t stream, void add_sigmoid(hipStream_t stream,
...@@ -21,8 +21,9 @@ void add_sigmoid(hipStream_t stream, ...@@ -21,8 +21,9 @@ void add_sigmoid(hipStream_t stream,
const argument& arg2, const argument& arg2,
const argument& arg3) const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)( nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) __device__ {
[](auto x, auto y, auto z) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y + z)))); }); return 1.f / (1.f + ::exp(to_hip_type(-(x + y + z))));
});
} }
} // namespace device } // namespace device
......
...@@ -11,7 +11,8 @@ void add_tanh(hipStream_t stream, ...@@ -11,7 +11,8 @@ void add_tanh(hipStream_t stream,
const argument& arg1, const argument& arg1,
const argument& arg2) const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) { return ::tanh(to_hip_type(x + y)); }); nary(stream, result, arg1, arg2)([](auto x, auto y)
__device__ { return ::tanh(to_hip_type(x + y)); });
} }
void add_tanh(hipStream_t stream, void add_tanh(hipStream_t stream,
...@@ -21,7 +22,7 @@ void add_tanh(hipStream_t stream, ...@@ -21,7 +22,7 @@ void add_tanh(hipStream_t stream,
const argument& arg3) const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)( nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto y, auto z) { return ::tanh(to_hip_type(x + y + z)); }); [](auto x, auto y, auto z) __device__ { return ::tanh(to_hip_type(x + y + z)); });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void asin(hipStream_t stream, const argument& result, const argument& arg) void asin(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::asin(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::asin(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void atan(hipStream_t stream, const argument& result, const argument& arg) void atan(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::atan(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::atan(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void ceil(hipStream_t stream, const argument& result, const argument& arg) void ceil(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::ceil(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::ceil(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -12,8 +12,9 @@ void clip(hipStream_t stream, ...@@ -12,8 +12,9 @@ void clip(hipStream_t stream,
const float max, const float max,
const float min) const float min)
{ {
nary(stream, result, arg1)( nary(stream, result, arg1)([max, min](auto x) __device__ {
[max, min](auto x) { return std::min<decltype(x)>(std::max<decltype(x)>(min, x), max); }); return ::min<decltype(x)>(::max<decltype(x)>(min, x), max);
});
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void contiguous(hipStream_t stream, argument result, argument arg) void contiguous(hipStream_t stream, argument result, argument arg)
{ {
nary(stream, std::move(result), std::move(arg))([](auto x) { return x; }); nary(stream, std::move(result), std::move(arg))([](auto x) __device__ { return x; });
} }
} // namespace device } // namespace device
......
...@@ -12,8 +12,8 @@ void convert(hipStream_t stream, const argument& result, const argument& arg) ...@@ -12,8 +12,8 @@ void convert(hipStream_t stream, const argument& result, const argument& arg)
arg.visit([&](auto input) { arg.visit([&](auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
gs_launch(stream, gs_launch(stream, result.get_shape().elements())(
result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; }); [=](auto i) __device__ { output_ptr[i] = input_ptr[i]; });
}); });
}); });
} }
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void cos(hipStream_t stream, const argument& result, const argument& arg) void cos(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::cos(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::cos(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void cosh(hipStream_t stream, const argument& result, const argument& arg) void cosh(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::cosh(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::cosh(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -8,7 +8,7 @@ namespace device { ...@@ -8,7 +8,7 @@ namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x / y; }); nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x / y; });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void erf(hipStream_t stream, const argument& result, const argument& arg) void erf(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::erf(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::erf(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void exp(hipStream_t stream, const argument& result, const argument& arg) void exp(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::exp(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::exp(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void floor(hipStream_t stream, const argument& result, const argument& arg) void floor(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::floor(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::floor(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
...@@ -25,7 +25,7 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg ...@@ -25,7 +25,7 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
arg2.visit([&](auto indices) { arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements, 256)([=](auto i) { gs_launch(stream, nelements, 256)([=](auto i) __device__ {
auto idx = out_comp.multi(i); auto idx = out_comp.multi(i);
auto in_index = indices_ptr[idx[axis_index]]; auto in_index = indices_ptr[idx[axis_index]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
......
...@@ -78,8 +78,9 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) ...@@ -78,8 +78,9 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
index_int nglobal = std::min<index_int>(256, groups) * local; index_int nglobal = std::min<index_int>(256, groups) * local;
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)( launch(stream, nglobal, local)([=](auto idx) __device__ {
[=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); }); idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); });
});
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment