Commit bd635011 authored by xiabo

ROCm environment adaptation

parent 9ba29737
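
Every hunk in this commit applies the same guard: keep the CUDA path when MMCV_WITH_HIP is undefined, and fall back to the HIP equivalent otherwise. As a minimal sketch of what that repeated pattern amounts to (the GPU_* macro names are illustrative, not part of the commit; only MMCV_WITH_HIP comes from mmcv):

// Hypothetical compatibility shim, for illustration only.
#ifndef MMCV_WITH_HIP
#include <cuda_runtime.h>
// Recent CUDA requires the *_sync warp intrinsics, which take an explicit
// mask of participating lanes.
#define GPU_LAUNCH_KERNEL cudaLaunchKernel
#define GPU_SHFL_XOR(mask, var, lane) __shfl_xor_sync((mask), (var), (lane))
#else
#include <hip/hip_runtime.h>
// HIP exposes hipLaunchKernel with the same argument list as
// cudaLaunchKernel, but only the unmasked __shfl_xor shuffle.
#define GPU_LAUNCH_KERNEL hipLaunchKernel
#define GPU_SHFL_XOR(mask, var, lane) __shfl_xor((var), (lane))
#endif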
@@ -289,7 +289,12 @@ torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b,
   int blockSize = 4 * 32;
   int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
   void *args[] = {&p};
+#ifndef MMCV_WITH_HIP
   AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#else
+  AT_CUDA_CHECK(hipLaunchKernel(kernel, gridSize, blockSize, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#endif
   return y;
 }
@@ -672,8 +672,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
         // Combine signs.
         uint32_t s = sx + sy + sw + sz;
         s <<= (signX & 3) << 1;
+#ifndef MMCV_WITH_HIP
         s |= __shfl_xor_sync(groupMask, s, 1);
         s |= __shfl_xor_sync(groupMask, s, 2);
+#else
+        s |= __shfl_xor(s, 1);
+        s |= __shfl_xor(s, 2);
+#endif

         // Write signs.
         if ((uint32_t)(signY + 0) < sShapeMaxY) {
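
This four-line guard recurs at every sign-coalescing site in filtered_lrelu_kernel below. A device-side sketch of the butterfly reduction it wraps (the helper name group_or4 is illustrative; groupMask and the XOR lane offsets follow the hunk above):

// Illustrative helper, not part of the commit: OR-combines the sign bits
// held by a 4-lane group so that every lane ends up with the full result.
static __device__ uint32_t group_or4(uint32_t s, uint32_t groupMask) {
#ifndef MMCV_WITH_HIP
  s |= __shfl_xor_sync(groupMask, s, 1); // exchange with lane ^ 1
  s |= __shfl_xor_sync(groupMask, s, 2); // exchange with lane ^ 2
#else
  // HIP has no *_sync shuffles; AMD wavefronts execute in lockstep, so the
  // unmasked __shfl_xor is the available equivalent (groupMask goes unused).
  s |= __shfl_xor(s, 1);
  s |= __shfl_xor(s, 2);
#endif
  return s;
}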
@@ -720,9 +725,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
         // Combine signs.
         uint32_t s = sx + sy + sw + sz;
         s <<= (signX & 3) << 1;
+#ifndef MMCV_WITH_HIP
         s |= __shfl_xor_sync(groupMask, s, 1);
         s |= __shfl_xor_sync(groupMask, s, 2);
+#else
+        s |= __shfl_xor(s, 1);
+        s |= __shfl_xor(s, 2);
+#endif

         // Write signs.
         if ((uint32_t)(signY + 0) < sShapeMaxY) {
           p.s[si0] = (unsigned char)(s >> 0);
@@ -852,9 +861,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
         // Combine signs.
         int s = sx + sy;
         s <<= signXo;
+#ifndef MMCV_WITH_HIP
         s |= __shfl_xor_sync(groupMask, s, 1);
         s |= __shfl_xor_sync(groupMask, s, 2);
+#else
+        s |= __shfl_xor(s, 1);
+        s |= __shfl_xor(s, 2);
+#endif

         // Write signs.
         if ((uint32_t)(signY + 0) < sShapeMaxY) {
           p.s[si0] = (unsigned char)(s >> 0);
@@ -882,9 +895,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
         // Combine signs.
         int s = sx + sy;
         s <<= signXo;
+#ifndef MMCV_WITH_HIP
         s |= __shfl_xor_sync(groupMask, s, 1);
         s |= __shfl_xor_sync(groupMask, s, 2);
+#else
+        s |= __shfl_xor(s, 1);
+        s |= __shfl_xor(s, 2);
+#endif

         // Write signs.
         if ((uint32_t)(signY + 0) < sShapeMaxY) {
           p.s[si0] = (unsigned char)(s >> 0);
@@ -1171,8 +1188,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
         }
         if ((uint32_t)signXb < p.swLimit &&
             (uint32_t)signY < p.sShape.y && signY >= minY) {
+#ifndef MMCV_WITH_HIP
           s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
           s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+#else
+          s += __shfl_xor(s, 1); // Coalesce.
+          s += __shfl_xor(s, 2); // Coalesce.
+#endif
           p.s[si] = s; // Write.
         }
       } else {
@@ -1189,8 +1211,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
             s = signXbit * 2;
             v = InternalType<T>::clamp(v, p.clamp);
           }
+#ifndef MMCV_WITH_HIP
           s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
           s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+#else
+          s += __shfl_xor(s, 1); // Coalesce.
+          s += __shfl_xor(s, 2); // Coalesce.
+#endif
           p.s[si] = s; // Write.
         } else {
           // Just compute the value.
@@ -1411,11 +1438,17 @@ static __global__ void filtered_lrelu_act_kernel(
       // Coalesce into threads 0 and 16 of warp.
       uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
       s <<= ((threadIdx.x & 15) << 1); // Shift into place.
+#ifndef MMCV_WITH_HIP
       s |= __shfl_xor_sync(m, s, 1); // Distribute.
       s |= __shfl_xor_sync(m, s, 2);
       s |= __shfl_xor_sync(m, s, 4);
       s |= __shfl_xor_sync(m, s, 8);
+#else
+      s |= __shfl_xor(s, 1); // Distribute.
+      s |= __shfl_xor(s, 2);
+      s |= __shfl_xor(s, 4);
+      s |= __shfl_xor(s, 8);
+#endif

       // Write signs if leader and in p.s.
       if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
       {
@@ -1839,9 +1872,13 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
   }

   // Launch filter setup kernel.
+#ifndef MMCV_WITH_HIP
   AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#else
+  AT_CUDA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#endif

   // Copy kernels to constant memory.
   if (writeSigns && !readSigns)
     AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
@@ -1866,9 +1903,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
   {
     p.blockZofs = zofs;
     int subGz = std::min(maxSubGz, gz - zofs);
+#ifndef MMCV_WITH_HIP
     AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
                                    spec.dynamicSharedKB << 10,
                                    at::cuda::getCurrentCUDAStream()));
+#else
+    AT_CUDA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
+                                  spec.dynamicSharedKB << 10,
+                                  at::cuda::getCurrentCUDAStream()));
+#endif
   }

   // Done.
@@ -1983,7 +2026,12 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
   gz = std::min(gz, gmax);

   // Launch.
+#ifndef MMCV_WITH_HIP
   AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#else
+  AT_CUDA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#endif
   return so;
 }
@@ -734,7 +734,12 @@ torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
   // Launch CUDA kernel.
   void *args[] = {&p};
+#ifndef MMCV_WITH_HIP
   AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#else
+  AT_CUDA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#endif

   return y;
 }
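
For reference, a minimal standalone HIP program (scale_kernel and everything else in it is assumed for illustration; nothing is taken from mmcv) showing that hipLaunchKernel accepts the same argument list as cudaLaunchKernel: kernel pointer, grid, block, packed argument array, dynamic shared memory bytes, and stream. That is why each #else branch above only swaps the function name; AT_CUDA_CHECK presumably still applies because ROCm builds of PyTorch hipify it to check hipError_t.

#include <hip/hip_runtime.h>

__global__ void scale_kernel(float *x, float a, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= a;
}

int main() {
  int n = 1024;
  float *x;
  hipMalloc(&x, n * sizeof(float));
  float a = 2.0f;
  void *args[] = {&x, &a, &n}; // packed exactly like the args[] in the hunks
  hipLaunchKernel((const void *)scale_kernel, dim3(n / 256), dim3(256), args,
                  0 /* sharedMemBytes */, 0 /* default stream */);
  hipDeviceSynchronize();
  hipFree(x);
  return 0;
}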