Unverified Commit 7ff7095c authored by xiabo123, committed by GitHub

[Fix] Fix the support for ROCm (#2811)

parent 3269278e
@@ -16,6 +16,7 @@ from typing import Dict, Optional, Tuple, Union
 import torch
 from mmengine.utils import digit_version
+from mmengine.utils.dl_utils.parrots_wrapper import is_rocm_pytorch

 enabled = True
 weight_gradients_disabled = False
@@ -39,7 +40,7 @@ def conv2d(input: torch.Tensor,
            dilation: Union[int, Tuple[int, ...]] = 1,
            groups: int = 1):
     flag = True
-    if torch.__version__ >= '1.10.0':
+    if digit_version(torch.__version__) >= digit_version('1.10.0'):
         warnings.warn('Since '
                       'aten:cudnn_convolution_backward_weight is '
                       f'not supported in torch=={torch.__version__},'
@@ -283,15 +284,24 @@ def _conv2d_gradfix(
                 output_padding=output_padding,
                 output_mask=[0, 1, 0])[1]
         else:
-            # General case => cuDNN.
-            name = ('aten::cudnn_convolution_transpose_backward_weight'
-                    if transpose else
-                    'aten::cudnn_convolution_backward_weight')
-            flags = [
-                torch.backends.cudnn.benchmark,
-                torch.backends.cudnn.deterministic,
-                torch.backends.cudnn.allow_tf32
-            ]
+            if is_rocm_pytorch():
+                name = 'aten::miopen_convolution_transpose_backward_weight'
+                if not transpose:
+                    name = 'aten::miopen_convolution_backward_weight'
+                flags = [
+                    torch.backends.cudnn.benchmark,
+                    torch.backends.cudnn.deterministic
+                ]
+            else:
+                # General case => cuDNN.
+                name = ('aten::cudnn_convolution_transpose_backward_weight'
+                        if transpose else
+                        'aten::cudnn_convolution_backward_weight')
+                flags = [
+                    torch.backends.cudnn.benchmark,
+                    torch.backends.cudnn.deterministic,
+                    torch.backends.cudnn.allow_tf32
+                ]
         return torch._C._jit_get_operation(name)(weight_shape,
                                                  grad_output, input,
                                                  padding, stride,
...
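Note on the Python-side fix above: it replaces lexicographic comparisons of torch.__version__ with digit_version, and it dispatches the weight-gradient computation to the MIOpen backward operators instead of the cuDNN ones when PyTorch is a ROCm build (the ROCm branch also drops the allow_tf32 flag, presumably because the MIOpen operator does not take a TF32 argument). As a rough, illustrative sketch only (not the exact mmengine implementation), a ROCm-build check along the lines of is_rocm_pytorch can be written like this:

    # Illustrative sketch only -- not the actual mmengine implementation.
    # A ROCm build of PyTorch reports a HIP version and ships ROCM_HOME,
    # whereas torch.version.cuda is None on such builds.
    import torch

    def looks_like_rocm_build() -> bool:  # hypothetical helper name
        try:
            from torch.utils.cpp_extension import ROCM_HOME
        except ImportError:
            return False
        return torch.version.hip is not None and ROCM_HOME is not None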
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from mmengine.utils import digit_version
 from torch import Tensor, nn

 _mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
@@ -70,7 +71,8 @@ class CornerPool(nn.Module):
         self.mode = mode

     def forward(self, x: Tensor) -> Tensor:
-        if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+        if (torch.__version__ != 'parrots' and
+                digit_version(torch.__version__) >= digit_version('1.5.0')):
             dim, flip = self.cummax_dim_flip[self.mode]
             if flip:
                 x = x.flip(dim)
...
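The CornerPool change fixes the same version-comparison pitfall: comparing version strings lexicographically goes wrong as soon as a component reaches two digits, so torch 1.10 would be treated as older than 1.5. A quick illustration (assuming mmengine is installed):

    from mmengine.utils import digit_version

    # Character-wise comparison: '1' < '5', so '1.10.0' sorts before '1.5.0'.
    assert not '1.10.0' >= '1.5.0'
    # digit_version compares numeric components and gets the order right.
    assert digit_version('1.10.0') >= digit_version('1.5.0')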
@@ -289,7 +289,13 @@ torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b,
   int blockSize = 4 * 32;
   int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
   void *args[] = {&p};
+#ifdef MMCV_WITH_HIP
+  AT_CUDA_CHECK(hipLaunchKernel(kernel, gridSize, blockSize, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#else
   AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#endif
   return y;
 }
@@ -672,8 +672,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
       // Combine signs.
       uint32_t s = sx + sy + sw + sz;
       s <<= (signX & 3) << 1;
+#ifdef MMCV_WITH_HIP
+      s |= __shfl_xor(s, 1);
+      s |= __shfl_xor(s, 2);
+#else
       s |= __shfl_xor_sync(groupMask, s, 1);
       s |= __shfl_xor_sync(groupMask, s, 2);
+#endif

       // Write signs.
       if ((uint32_t)(signY + 0) < sShapeMaxY) {
@@ -720,8 +725,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
       // Combine signs.
       uint32_t s = sx + sy + sw + sz;
       s <<= (signX & 3) << 1;
+#ifdef MMCV_WITH_HIP
+      s |= __shfl_xor(s, 1);
+      s |= __shfl_xor(s, 2);
+#else
       s |= __shfl_xor_sync(groupMask, s, 1);
       s |= __shfl_xor_sync(groupMask, s, 2);
+#endif

       // Write signs.
       if ((uint32_t)(signY + 0) < sShapeMaxY) {
@@ -852,8 +862,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
       // Combine signs.
       int s = sx + sy;
       s <<= signXo;
+#ifdef MMCV_WITH_HIP
+      s |= __shfl_xor(s, 1);
+      s |= __shfl_xor(s, 2);
+#else
       s |= __shfl_xor_sync(groupMask, s, 1);
       s |= __shfl_xor_sync(groupMask, s, 2);
+#endif

       // Write signs.
       if ((uint32_t)(signY + 0) < sShapeMaxY) {
@@ -882,8 +897,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
       // Combine signs.
       int s = sx + sy;
       s <<= signXo;
+#ifdef MMCV_WITH_HIP
+      s |= __shfl_xor(s, 1);
+      s |= __shfl_xor(s, 2);
+#else
       s |= __shfl_xor_sync(groupMask, s, 1);
       s |= __shfl_xor_sync(groupMask, s, 2);
+#endif

       // Write signs.
       if ((uint32_t)(signY + 0) < sShapeMaxY) {
@@ -1171,9 +1191,14 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
           }
           if ((uint32_t)signXb < p.swLimit &&
               (uint32_t)signY < p.sShape.y && signY >= minY) {
+#ifdef MMCV_WITH_HIP
+            s += __shfl_xor(s, 1);  // Coalesce.
+            s += __shfl_xor(s, 2);  // Coalesce.
+#else
             s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
             s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
+#endif
             p.s[si] = s;  // Write.
           }
         } else {
           // Determine and write sign.
@@ -1189,9 +1214,14 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
             s = signXbit * 2;
             v = InternalType<T>::clamp(v, p.clamp);
           }
+#ifdef MMCV_WITH_HIP
+          s += __shfl_xor(s, 1);  // Coalesce.
+          s += __shfl_xor(s, 2);  // Coalesce.
+#else
           s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
           s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
+#endif
           p.s[si] = s;  // Write.
         } else {
           // Just compute the value.
           if (v < 0.f) v *= p.slope;
@@ -1411,10 +1441,17 @@ static __global__ void filtered_lrelu_act_kernel(
       // Coalesce into threads 0 and 16 of warp.
       uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
       s <<= ((threadIdx.x & 15) << 1);  // Shift into place.
-      s |= __shfl_xor_sync(m, s, 1);    // Distribute.
+#ifdef MMCV_WITH_HIP
+      s |= __shfl_xor(s, 1);  // Distribute.
+      s |= __shfl_xor(s, 2);
+      s |= __shfl_xor(s, 4);
+      s |= __shfl_xor(s, 8);
+#else
+      s |= __shfl_xor_sync(m, s, 1);  // Distribute.
       s |= __shfl_xor_sync(m, s, 2);
       s |= __shfl_xor_sync(m, s, 4);
       s |= __shfl_xor_sync(m, s, 8);
+#endif

       // Write signs if leader and in p.s.
       if (!(threadIdx.x & 15) && x < p.sShape.x)  // y is always in.
@@ -1586,6 +1623,7 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
 #define BUILD_FILTERED_LRELU_OP 1
+#ifndef MMCV_WITH_HIP
 #ifdef __GNUC__
 #if __GNUC__ < 6
 #undef BUILD_FILTERED_LRELU_OP
@@ -1597,6 +1635,7 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
 #undef BUILD_FILTERED_LRELU_OP
 #define BUILD_FILTERED_LRELU_OP 0
 #endif
+#endif

 #if BUILD_FILTERED_LRELU_OP == 1
 std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
@@ -1637,9 +1676,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
   // Figure out how much shared memory is available on the device.
   int maxSharedBytes = 0;
+#ifdef MMCV_WITH_HIP
+  cudaDeviceGetAttribute(&maxSharedBytes,
+                         hipDeviceAttributeSharedMemPerBlockOptin,
+                         x.device().index());
+#else
   AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes,
                                        cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                        x.device().index()));
+#endif
   int sharedKB = maxSharedBytes >> 10;

   // Populate enough launch parameters to check if a CUDA kernel exists.
@@ -1837,10 +1882,14 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
     p.tilesXrep = 0;
     p.tilesXdim = 0;
   }
+#ifdef MMCV_WITH_HIP
+  AT_CUDA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#else
   // Launch filter setup kernel.
   AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#endif

   // Copy kernels to constant memory.
   if (writeSigns && !readSigns)
@@ -1853,9 +1902,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
   // Set cache and shared memory configurations for main kernel.
   AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
   if (spec.dynamicSharedKB)  // Need dynamically allocated shared memory?
+#ifdef MMCV_WITH_HIP
+    AT_CUDA_CHECK(hipFuncSetAttribute(
+        spec.exec, hipFuncAttributeMaxDynamicSharedMemorySize,
+        spec.dynamicSharedKB << 10));
+#else
     AT_CUDA_CHECK(cudaFuncSetAttribute(
         spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize,
         spec.dynamicSharedKB << 10));
+#endif
   AT_CUDA_CHECK(
       cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
@@ -1866,9 +1921,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
   {
     p.blockZofs = zofs;
     int subGz = std::min(maxSubGz, gz - zofs);
+#ifdef MMCV_WITH_HIP
+    AT_CUDA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
+                                  spec.dynamicSharedKB << 10,
+                                  at::cuda::getCurrentCUDAStream()));
+#else
     AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
                                    spec.dynamicSharedKB << 10,
                                    at::cuda::getCurrentCUDAStream()));
+#endif
   }

   // Done.
@@ -1983,7 +2044,13 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
   gz = std::min(gz, gmax);

   // Launch.
+#ifdef MMCV_WITH_HIP
+  AT_CUDA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#else
   AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#endif
   return so;
 }
@@ -734,7 +734,13 @@ torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
   // Launch CUDA kernel.
   void *args[] = {&p};
+#ifdef MMCV_WITH_HIP
+  AT_CUDA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
+                                at::cuda::getCurrentCUDAStream()));
+#else
   AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
                                  at::cuda::getCurrentCUDAStream()));
+#endif
   return y;
 }
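All of the kernel-side edits above are guarded by the MMCV_WITH_HIP preprocessor symbol, so the HIP code paths (hipLaunchKernel, __shfl_xor without the _sync suffix, hipFuncSetAttribute, and the HIP device attributes) are only compiled when the extension is built against a ROCm PyTorch. Purely as a hypothetical sketch of how such a define can be wired into a torch extension build (the extension name and source path below are invented, and this is not MMCV's actual setup.py):

    # Hypothetical build sketch -- names and paths are placeholders.
    import torch
    from torch.utils.cpp_extension import CUDAExtension

    define_macros = []
    if torch.version.hip is not None:  # PyTorch built against ROCm/HIP
        define_macros.append(('MMCV_WITH_HIP', None))

    ext = CUDAExtension(
        name='_hip_demo_ext',                 # placeholder extension name
        sources=['csrc/filtered_lrelu.cu'],   # placeholder source path
        define_macros=define_macros)

On ROCm builds, torch.utils.cpp_extension hipifies the CUDA sources during compilation, which is why the CUDA-flavoured calls outside the #ifdef blocks can still compile there.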
@@ -2,6 +2,7 @@
 import pytest
 import torch
 from mmengine.utils import digit_version
+from mmengine.utils.dl_utils.parrots_wrapper import is_rocm_pytorch

 from mmcv.ops import filtered_lrelu
@@ -115,7 +116,7 @@ class TestFilteredLrelu:
         assert out.shape == (1, 3, 16, 16)

     @pytest.mark.skipif(
-        not torch.cuda.is_available()
+        not torch.cuda.is_available() or is_rocm_pytorch()
         or digit_version(torch.version.cuda) < digit_version('10.2'),
         reason='requires cuda>=10.2')
     def test_filtered_lrelu_cuda(self):
...
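One detail worth noting in the test change: is_rocm_pytorch() is checked before the CUDA-version comparison, and the or short-circuits, because ROCm wheels report torch.version.cuda as None and evaluating digit_version(torch.version.cuda) there would fail. The same guard, factored into a helper purely for illustration (the helper name is made up):

    # Illustrative helper mirroring the skip condition above.
    import torch
    from mmengine.utils import digit_version
    from mmengine.utils.dl_utils.parrots_wrapper import is_rocm_pytorch

    def skip_cuda_only_test(min_cuda: str = '10.2') -> bool:
        if not torch.cuda.is_available():
            return True
        if is_rocm_pytorch():
            # Short-circuit: torch.version.cuda is None on ROCm builds.
            return True
        return digit_version(torch.version.cuda) < digit_version(min_cuda)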