Unverified Commit dbdc7267 authored by Ashish Farmer's avatar Ashish Farmer Committed by GitHub
Browse files

fix GET_THREADS() for ROCm (#2997)

parent 8e878f0f
...@@ -81,6 +81,9 @@ ...@@ -81,6 +81,9 @@
const int kMaxParallelImgs = 32; const int kMaxParallelImgs = 32;
inline unsigned int GET_THREADS() { inline unsigned int GET_THREADS() {
#ifdef __HIP_PLATFORM_HCC__
return 256;
#endif
if (at::cuda::getCurrentDeviceProperties()->major >= 6) { if (at::cuda::getCurrentDeviceProperties()->major >= 6) {
return 1024; return 1024;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment