Commit bd635011 authored by xiabo

ROCm environment adaptation

parent 9ba29737
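The hunks below bracket CUDA-only runtime calls with an MMCV_WITH_HIP guard so the same sources also compile for ROCm. A minimal sketch of the kernel-launch half of the pattern, using a hypothetical launch_1d helper that is not part of this patch:

    #include <ATen/cuda/CUDAContext.h>
    #include <ATen/cuda/Exceptions.h>

    // Launch a 1-D grid on the current ATen stream. Only the runtime entry
    // point differs between the builds; cudaLaunchKernel and hipLaunchKernel
    // take the same argument layout.
    static void launch_1d(const void *kernel, int gridSize, int blockSize,
                          void **args) {
    #ifndef MMCV_WITH_HIP
      // NVIDIA build: plain CUDA runtime launch.
      AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0,
                                     at::cuda::getCurrentCUDAStream()));
    #else
      // ROCm build: swap in the HIP runtime call, everything else unchanged.
      AT_CUDA_CHECK(hipLaunchKernel(kernel, gridSize, blockSize, args, 0,
                                    at::cuda::getCurrentCUDAStream()));
    #endif
    }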
@@ -289,7 +289,12 @@ torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b,
int blockSize = 4 * 32;
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
void *args[] = {&p};
#ifndef MMCV_WITH_HIP
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0,
at::cuda::getCurrentCUDAStream()));
#else
AT_CUDA_CHECK(hipLaunchKernel(kernel, gridSize, blockSize, args, 0,
at::cuda::getCurrentCUDAStream()));
#endif
return y;
}
@@ -672,8 +672,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
#ifndef MMCV_WITH_HIP
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
#else
s |= __shfl_xor(s, 1);
s |= __shfl_xor(s, 2);
#endif
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
@@ -720,9 +725,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
#ifndef MMCV_WITH_HIP
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
#else
s |= __shfl_xor(s, 1);
s |= __shfl_xor(s, 2);
#endif
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
p.s[si0] = (unsigned char)(s >> 0);
@@ -852,9 +861,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
int s = sx + sy;
s <<= signXo;
#ifndef MMCV_WITH_HIP
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
#else
s |= __shfl_xor(s, 1);
s |= __shfl_xor(s, 2);
#endif
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
p.s[si0] = (unsigned char)(s >> 0);
@@ -882,9 +895,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
int s = sx + sy;
s <<= signXo;
#ifndef MMCV_WITH_HIP
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
#else
s |= __shfl_xor(s, 1);
s |= __shfl_xor(s, 2);
#endif
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) {
p.s[si0] = (unsigned char)(s >> 0);
@@ -1171,8 +1188,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
}
if ((uint32_t)signXb < p.swLimit &&
(uint32_t)signY < p.sShape.y && signY >= minY) {
#ifndef MMCV_WITH_HIP
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
#else
s += __shfl_xor(s, 1); // Coalesce.
s += __shfl_xor(s, 2); // Coalesce.
#endif
p.s[si] = s; // Write.
}
} else {
@@ -1189,8 +1211,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
s = signXbit * 2;
v = InternalType<T>::clamp(v, p.clamp);
}
#ifndef MMCV_WITH_HIP
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
#else
s += __shfl_xor(s, 1); // Coalesce.
s += __shfl_xor(s, 2); // Coalesce.
#endif
p.s[si] = s; // Write.
} else {
// Just compute the value.
@@ -1411,11 +1438,17 @@ static __global__ void filtered_lrelu_act_kernel(
// Coalesce into threads 0 and 16 of warp.
uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
s <<= ((threadIdx.x & 15) << 1); // Shift into place.
#ifndef MMCV_WITH_HIP
s |= __shfl_xor_sync(m, s, 1); // Distribute.
s |= __shfl_xor_sync(m, s, 2);
s |= __shfl_xor_sync(m, s, 4);
s |= __shfl_xor_sync(m, s, 8);
#else
s |= __shfl_xor(s, 1); // Distribute.
s |= __shfl_xor(s, 2);
s |= __shfl_xor(s, 4);
s |= __shfl_xor(s, 8);
#endif
// Write signs if leader and in p.s.
if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
{
@@ -1839,9 +1872,13 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
}
// Launch filter setup kernel.
#ifndef MMCV_WITH_HIP
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0,
at::cuda::getCurrentCUDAStream()));
#else
AT_CUDA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0,
at::cuda::getCurrentCUDAStream()));
#endif
// Copy kernels to constant memory.
if (writeSigns && !readSigns)
AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
@@ -1866,9 +1903,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
{
p.blockZofs = zofs;
int subGz = std::min(maxSubGz, gz - zofs);
#ifndef MMCV_WITH_HIP
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
spec.dynamicSharedKB << 10,
at::cuda::getCurrentCUDAStream()));
#else
AT_CUDA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
spec.dynamicSharedKB << 10,
at::cuda::getCurrentCUDAStream()));
#endif
}
// Done.
@@ -1983,7 +2026,12 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
gz = std::min(gz, gmax);
// Launch.
#ifndef MMCV_WITH_HIP
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
at::cuda::getCurrentCUDAStream()));
#else
AT_CUDA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
at::cuda::getCurrentCUDAStream()));
#endif
return so;
}
@@ -734,7 +734,12 @@ torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
// Launch CUDA kernel.
void *args[] = {&p};
#ifndef MMCV_WITH_HIP
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
at::cuda::getCurrentCUDAStream()));
#else
AT_CUDA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
at::cuda::getCurrentCUDAStream()));
#endif
return y;
}
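Note on the warp-shuffle hunks in filtered_lrelu: the masked __shfl_xor_sync intrinsics are kept for NVIDIA builds, while the ROCm branch uses the unmasked __shfl_xor that HIP supports. A minimal device-side sketch of that fallback, using a hypothetical combine_signs helper that is not part of this patch:

    // Fold the sign bits of a 4-thread group into every lane of the group,
    // mirroring the "Combine signs." blocks in the hunks above.
    static __device__ uint32_t combine_signs(uint32_t groupMask, uint32_t s) {
    #ifndef MMCV_WITH_HIP
      // NVIDIA path: explicit participation mask, as required by the
      // *_sync shuffle variants.
      s |= __shfl_xor_sync(groupMask, s, 1);
      s |= __shfl_xor_sync(groupMask, s, 2);
    #else
      // ROCm path: legacy unmasked shuffle; the mask argument is unused here.
      (void)groupMask;
      s |= __shfl_xor(s, 1);
      s |= __shfl_xor(s, 2);
    #endif
      return s;
    }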