Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
7ff7095c
Unverified
Commit
7ff7095c
authored
May 31, 2023
by
xiabo123
Committed by
GitHub
May 31, 2023
Browse files
[Fix] Fix the support for ROCm (#2811)
parent
3269278e
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
108 additions
and
16 deletions
+108
-16
mmcv/ops/conv2d_gradfix.py
mmcv/ops/conv2d_gradfix.py
+20
-10
mmcv/ops/corner_pool.py
mmcv/ops/corner_pool.py
+3
-1
mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu
mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu
+6
-0
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu
+71
-4
mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu
mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu
+6
-0
tests/test_ops/test_filtered_lrelu.py
tests/test_ops/test_filtered_lrelu.py
+2
-1
No files found.
mmcv/ops/conv2d_gradfix.py
View file @
7ff7095c
...
@@ -16,6 +16,7 @@ from typing import Dict, Optional, Tuple, Union
...
@@ -16,6 +16,7 @@ from typing import Dict, Optional, Tuple, Union
import
torch
import
torch
from
mmengine.utils
import
digit_version
from
mmengine.utils
import
digit_version
from
mmengine.utils.dl_utils.parrots_wrapper
import
is_rocm_pytorch
enabled
=
True
enabled
=
True
weight_gradients_disabled
=
False
weight_gradients_disabled
=
False
...
@@ -39,7 +40,7 @@ def conv2d(input: torch.Tensor,
...
@@ -39,7 +40,7 @@ def conv2d(input: torch.Tensor,
dilation
:
Union
[
int
,
Tuple
[
int
,
...]]
=
1
,
dilation
:
Union
[
int
,
Tuple
[
int
,
...]]
=
1
,
groups
:
int
=
1
):
groups
:
int
=
1
):
flag
=
True
flag
=
True
if
torch
.
__version__
>=
'1.10.0'
:
if
digit_version
(
torch
.
__version__
)
>=
digit_version
(
'1.10.0'
)
:
warnings
.
warn
(
'Since '
warnings
.
warn
(
'Since '
'aten:cudnn_convolution_backward_weight is '
'aten:cudnn_convolution_backward_weight is '
f
'not supported in torch==
{
torch
.
__version__
}
,'
f
'not supported in torch==
{
torch
.
__version__
}
,'
...
@@ -282,6 +283,15 @@ def _conv2d_gradfix(
...
@@ -282,6 +283,15 @@ def _conv2d_gradfix(
groups
=
groups
,
groups
=
groups
,
output_padding
=
output_padding
,
output_padding
=
output_padding
,
output_mask
=
[
0
,
1
,
0
])[
1
]
output_mask
=
[
0
,
1
,
0
])[
1
]
else
:
if
is_rocm_pytorch
():
name
=
'aten::miopen_convolution_transpose_backward_weight'
if
not
transpose
:
name
=
'aten::miopen_convolution_backward_weight'
flags
=
[
torch
.
backends
.
cudnn
.
benchmark
,
torch
.
backends
.
cudnn
.
deterministic
]
else
:
else
:
# General case => cuDNN.
# General case => cuDNN.
name
=
(
'aten::cudnn_convolution_transpose_backward_weight'
name
=
(
'aten::cudnn_convolution_transpose_backward_weight'
...
...
mmcv/ops/corner_pool.py
View file @
7ff7095c
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch
from
mmengine.utils
import
digit_version
from
torch
import
Tensor
,
nn
from
torch
import
Tensor
,
nn
_mode_dict
=
{
'top'
:
0
,
'bottom'
:
1
,
'left'
:
2
,
'right'
:
3
}
_mode_dict
=
{
'top'
:
0
,
'bottom'
:
1
,
'left'
:
2
,
'right'
:
3
}
...
@@ -70,7 +71,8 @@ class CornerPool(nn.Module):
...
@@ -70,7 +71,8 @@ class CornerPool(nn.Module):
self
.
mode
=
mode
self
.
mode
=
mode
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
if
torch
.
__version__
!=
'parrots'
and
torch
.
__version__
>=
'1.5.0'
:
if
(
torch
.
__version__
!=
'parrots'
and
digit_version
(
torch
.
__version__
)
>=
digit_version
(
'1.5.0'
)):
dim
,
flip
=
self
.
cummax_dim_flip
[
self
.
mode
]
dim
,
flip
=
self
.
cummax_dim_flip
[
self
.
mode
]
if
flip
:
if
flip
:
x
=
x
.
flip
(
dim
)
x
=
x
.
flip
(
dim
)
...
...
mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu
View file @
7ff7095c
...
@@ -289,7 +289,13 @@ torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b,
...
@@ -289,7 +289,13 @@ torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b,
int
blockSize
=
4
*
32
;
int
blockSize
=
4
*
32
;
int
gridSize
=
(
p
.
sizeX
-
1
)
/
(
p
.
loopX
*
blockSize
)
+
1
;
int
gridSize
=
(
p
.
sizeX
-
1
)
/
(
p
.
loopX
*
blockSize
)
+
1
;
void
*
args
[]
=
{
&
p
};
void
*
args
[]
=
{
&
p
};
#ifdef MMCV_WITH_HIP
AT_CUDA_CHECK
(
hipLaunchKernel
(
kernel
,
gridSize
,
blockSize
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
#else
AT_CUDA_CHECK
(
cudaLaunchKernel
(
kernel
,
gridSize
,
blockSize
,
args
,
0
,
AT_CUDA_CHECK
(
cudaLaunchKernel
(
kernel
,
gridSize
,
blockSize
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
at
::
cuda
::
getCurrentCUDAStream
()));
#endif
return
y
;
return
y
;
}
}
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu
View file @
7ff7095c
...
@@ -672,8 +672,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
...
@@ -672,8 +672,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
// Combine signs.
uint32_t
s
=
sx
+
sy
+
sw
+
sz
;
uint32_t
s
=
sx
+
sy
+
sw
+
sz
;
s
<<=
(
signX
&
3
)
<<
1
;
s
<<=
(
signX
&
3
)
<<
1
;
#ifdef MMCV_WITH_HIP
s
|=
__shfl_xor
(
s
,
1
);
s
|=
__shfl_xor
(
s
,
2
);
#else
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
#endif
// Write signs.
// Write signs.
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
...
@@ -720,8 +725,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
...
@@ -720,8 +725,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
// Combine signs.
uint32_t
s
=
sx
+
sy
+
sw
+
sz
;
uint32_t
s
=
sx
+
sy
+
sw
+
sz
;
s
<<=
(
signX
&
3
)
<<
1
;
s
<<=
(
signX
&
3
)
<<
1
;
#ifdef MMCV_WITH_HIP
s
|=
__shfl_xor
(
s
,
1
);
s
|=
__shfl_xor
(
s
,
2
);
#else
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
#endif
// Write signs.
// Write signs.
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
...
@@ -852,8 +862,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
...
@@ -852,8 +862,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
// Combine signs.
int
s
=
sx
+
sy
;
int
s
=
sx
+
sy
;
s
<<=
signXo
;
s
<<=
signXo
;
#ifdef MMCV_WITH_HIP
s
|=
__shfl_xor
(
s
,
1
);
s
|=
__shfl_xor
(
s
,
2
);
#else
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
#endif
// Write signs.
// Write signs.
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
...
@@ -882,8 +897,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
...
@@ -882,8 +897,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
// Combine signs.
// Combine signs.
int
s
=
sx
+
sy
;
int
s
=
sx
+
sy
;
s
<<=
signXo
;
s
<<=
signXo
;
#ifdef MMCV_WITH_HIP
s
|=
__shfl_xor
(
s
,
1
);
s
|=
__shfl_xor
(
s
,
2
);
#else
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
s
|=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
#endif
// Write signs.
// Write signs.
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
if
((
uint32_t
)(
signY
+
0
)
<
sShapeMaxY
)
{
...
@@ -1171,8 +1191,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
...
@@ -1171,8 +1191,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
}
}
if
((
uint32_t
)
signXb
<
p
.
swLimit
&&
if
((
uint32_t
)
signXb
<
p
.
swLimit
&&
(
uint32_t
)
signY
<
p
.
sShape
.
y
&&
signY
>=
minY
)
{
(
uint32_t
)
signY
<
p
.
sShape
.
y
&&
signY
>=
minY
)
{
#ifdef MMCV_WITH_HIP
s
+=
__shfl_xor
(
s
,
1
);
// Coalesce.
s
+=
__shfl_xor
(
s
,
2
);
// Coalesce.
#else
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
// Coalesce.
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
// Coalesce.
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
// Coalesce.
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
// Coalesce.
#endif
p
.
s
[
si
]
=
s
;
// Write.
p
.
s
[
si
]
=
s
;
// Write.
}
}
}
else
{
}
else
{
...
@@ -1189,8 +1214,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
...
@@ -1189,8 +1214,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
s
=
signXbit
*
2
;
s
=
signXbit
*
2
;
v
=
InternalType
<
T
>::
clamp
(
v
,
p
.
clamp
);
v
=
InternalType
<
T
>::
clamp
(
v
,
p
.
clamp
);
}
}
#ifdef MMCV_WITH_HIP
s
+=
__shfl_xor
(
s
,
1
);
// Coalesce.
s
+=
__shfl_xor
(
s
,
2
);
// Coalesce.
#else
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
// Coalesce.
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
1
);
// Coalesce.
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
// Coalesce.
s
+=
__shfl_xor_sync
(
groupMask
,
s
,
2
);
// Coalesce.
#endif
p
.
s
[
si
]
=
s
;
// Write.
p
.
s
[
si
]
=
s
;
// Write.
}
else
{
}
else
{
// Just compute the value.
// Just compute the value.
...
@@ -1411,10 +1441,17 @@ static __global__ void filtered_lrelu_act_kernel(
...
@@ -1411,10 +1441,17 @@ static __global__ void filtered_lrelu_act_kernel(
// Coalesce into threads 0 and 16 of warp.
// Coalesce into threads 0 and 16 of warp.
uint32_t
m
=
(
threadIdx
.
x
&
16
)
?
0xffff0000u
:
0x0000ffffu
;
uint32_t
m
=
(
threadIdx
.
x
&
16
)
?
0xffff0000u
:
0x0000ffffu
;
s
<<=
((
threadIdx
.
x
&
15
)
<<
1
);
// Shift into place.
s
<<=
((
threadIdx
.
x
&
15
)
<<
1
);
// Shift into place.
#ifdef MMCV_WITH_HIP
s
|=
__shfl_xor
(
s
,
1
);
// Distribute.
s
|=
__shfl_xor
(
s
,
2
);
s
|=
__shfl_xor
(
s
,
4
);
s
|=
__shfl_xor
(
s
,
8
);
#else
s
|=
__shfl_xor_sync
(
m
,
s
,
1
);
// Distribute.
s
|=
__shfl_xor_sync
(
m
,
s
,
1
);
// Distribute.
s
|=
__shfl_xor_sync
(
m
,
s
,
2
);
s
|=
__shfl_xor_sync
(
m
,
s
,
2
);
s
|=
__shfl_xor_sync
(
m
,
s
,
4
);
s
|=
__shfl_xor_sync
(
m
,
s
,
4
);
s
|=
__shfl_xor_sync
(
m
,
s
,
8
);
s
|=
__shfl_xor_sync
(
m
,
s
,
8
);
#endif
// Write signs if leader and in p.s.
// Write signs if leader and in p.s.
if
(
!
(
threadIdx
.
x
&
15
)
&&
x
<
p
.
sShape
.
x
)
// y is always in.
if
(
!
(
threadIdx
.
x
&
15
)
&&
x
<
p
.
sShape
.
x
)
// y is always in.
...
@@ -1586,6 +1623,7 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
...
@@ -1586,6 +1623,7 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
#define BUILD_FILTERED_LRELU_OP 1
#define BUILD_FILTERED_LRELU_OP 1
#ifndef MMCV_WITH_HIP
#ifdef __GNUC__
#ifdef __GNUC__
#if __GNUC__ < 6
#if __GNUC__ < 6
#undef BUILD_FILTERED_LRELU_OP
#undef BUILD_FILTERED_LRELU_OP
...
@@ -1597,6 +1635,7 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
...
@@ -1597,6 +1635,7 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
#undef BUILD_FILTERED_LRELU_OP
#undef BUILD_FILTERED_LRELU_OP
#define BUILD_FILTERED_LRELU_OP 0
#define BUILD_FILTERED_LRELU_OP 0
#endif
#endif
#endif
#if BUILD_FILTERED_LRELU_OP == 1
#if BUILD_FILTERED_LRELU_OP == 1
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
int
>
filtered_lrelu_op
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
int
>
filtered_lrelu_op
(
...
@@ -1637,9 +1676,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
...
@@ -1637,9 +1676,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
// Figure out how much shared memory is available on the device.
// Figure out how much shared memory is available on the device.
int
maxSharedBytes
=
0
;
int
maxSharedBytes
=
0
;
#ifdef MMCV_WITH_HIP
cudaDeviceGetAttribute
(
&
maxSharedBytes
,
hipDeviceAttributeSharedMemPerBlockOptin
,
x
.
device
().
index
());
#else
AT_CUDA_CHECK
(
cudaDeviceGetAttribute
(
&
maxSharedBytes
,
AT_CUDA_CHECK
(
cudaDeviceGetAttribute
(
&
maxSharedBytes
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
x
.
device
().
index
()));
x
.
device
().
index
()));
#endif
int
sharedKB
=
maxSharedBytes
>>
10
;
int
sharedKB
=
maxSharedBytes
>>
10
;
// Populate enough launch parameters to check if a CUDA kernel exists.
// Populate enough launch parameters to check if a CUDA kernel exists.
...
@@ -1837,10 +1882,14 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
...
@@ -1837,10 +1882,14 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
p
.
tilesXrep
=
0
;
p
.
tilesXrep
=
0
;
p
.
tilesXdim
=
0
;
p
.
tilesXdim
=
0
;
}
}
#ifdef MMCV_WITH_HIP
AT_CUDA_CHECK
(
hipLaunchKernel
(
spec
.
setup
,
1
,
1024
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
#else
// Launch filter setup kernel.
// Launch filter setup kernel.
AT_CUDA_CHECK
(
cudaLaunchKernel
(
spec
.
setup
,
1
,
1024
,
args
,
0
,
AT_CUDA_CHECK
(
cudaLaunchKernel
(
spec
.
setup
,
1
,
1024
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
at
::
cuda
::
getCurrentCUDAStream
()));
#endif
// Copy kernels to constant memory.
// Copy kernels to constant memory.
if
(
writeSigns
&&
!
readSigns
)
if
(
writeSigns
&&
!
readSigns
)
...
@@ -1853,9 +1902,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
...
@@ -1853,9 +1902,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
// Set cache and shared memory configurations for main kernel.
// Set cache and shared memory configurations for main kernel.
AT_CUDA_CHECK
(
cudaFuncSetCacheConfig
(
spec
.
exec
,
cudaFuncCachePreferShared
));
AT_CUDA_CHECK
(
cudaFuncSetCacheConfig
(
spec
.
exec
,
cudaFuncCachePreferShared
));
if
(
spec
.
dynamicSharedKB
)
// Need dynamically allocated shared memory?
if
(
spec
.
dynamicSharedKB
)
// Need dynamically allocated shared memory?
#ifdef MMCV_WITH_HIP
AT_CUDA_CHECK
(
hipFuncSetAttribute
(
spec
.
exec
,
hipFuncAttributeMaxDynamicSharedMemorySize
,
spec
.
dynamicSharedKB
<<
10
));
#else
AT_CUDA_CHECK
(
cudaFuncSetAttribute
(
AT_CUDA_CHECK
(
cudaFuncSetAttribute
(
spec
.
exec
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
spec
.
exec
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
spec
.
dynamicSharedKB
<<
10
));
spec
.
dynamicSharedKB
<<
10
));
#endif
AT_CUDA_CHECK
(
AT_CUDA_CHECK
(
cudaFuncSetSharedMemConfig
(
spec
.
exec
,
cudaSharedMemBankSizeFourByte
));
cudaFuncSetSharedMemConfig
(
spec
.
exec
,
cudaSharedMemBankSizeFourByte
));
...
@@ -1866,9 +1921,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
...
@@ -1866,9 +1921,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
{
{
p
.
blockZofs
=
zofs
;
p
.
blockZofs
=
zofs
;
int
subGz
=
std
::
min
(
maxSubGz
,
gz
-
zofs
);
int
subGz
=
std
::
min
(
maxSubGz
,
gz
-
zofs
);
#ifdef MMCV_WITH_HIP
AT_CUDA_CHECK
(
hipLaunchKernel
(
spec
.
exec
,
dim3
(
gx
,
gy
,
subGz
),
bx
,
args
,
spec
.
dynamicSharedKB
<<
10
,
at
::
cuda
::
getCurrentCUDAStream
()));
#else
AT_CUDA_CHECK
(
cudaLaunchKernel
(
spec
.
exec
,
dim3
(
gx
,
gy
,
subGz
),
bx
,
args
,
AT_CUDA_CHECK
(
cudaLaunchKernel
(
spec
.
exec
,
dim3
(
gx
,
gy
,
subGz
),
bx
,
args
,
spec
.
dynamicSharedKB
<<
10
,
spec
.
dynamicSharedKB
<<
10
,
at
::
cuda
::
getCurrentCUDAStream
()));
at
::
cuda
::
getCurrentCUDAStream
()));
#endif
}
}
// Done.
// Done.
...
@@ -1983,7 +2044,13 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
...
@@ -1983,7 +2044,13 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
gz
=
std
::
min
(
gz
,
gmax
);
gz
=
std
::
min
(
gz
,
gmax
);
// Launch.
// Launch.
#ifdef MMCV_WITH_HIP
AT_CUDA_CHECK
(
hipLaunchKernel
(
func
,
dim3
(
gx
,
gy
,
gz
),
bx
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
#else
AT_CUDA_CHECK
(
cudaLaunchKernel
(
func
,
dim3
(
gx
,
gy
,
gz
),
bx
,
args
,
0
,
AT_CUDA_CHECK
(
cudaLaunchKernel
(
func
,
dim3
(
gx
,
gy
,
gz
),
bx
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
at
::
cuda
::
getCurrentCUDAStream
()));
#endif
return
so
;
return
so
;
}
}
mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu
View file @
7ff7095c
...
@@ -734,7 +734,13 @@ torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
...
@@ -734,7 +734,13 @@ torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
// Launch CUDA kernel.
// Launch CUDA kernel.
void
*
args
[]
=
{
&
p
};
void
*
args
[]
=
{
&
p
};
#ifdef MMCV_WITH_HIP
AT_CUDA_CHECK
(
hipLaunchKernel
(
spec
.
kernel
,
gridSize
,
blockSize
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
#else
AT_CUDA_CHECK
(
cudaLaunchKernel
(
spec
.
kernel
,
gridSize
,
blockSize
,
args
,
0
,
AT_CUDA_CHECK
(
cudaLaunchKernel
(
spec
.
kernel
,
gridSize
,
blockSize
,
args
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()));
at
::
cuda
::
getCurrentCUDAStream
()));
#endif
return
y
;
return
y
;
}
}
tests/test_ops/test_filtered_lrelu.py
View file @
7ff7095c
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
pytest
import
pytest
import
torch
import
torch
from
mmengine.utils
import
digit_version
from
mmengine.utils
import
digit_version
from
mmengine.utils.dl_utils.parrots_wrapper
import
is_rocm_pytorch
from
mmcv.ops
import
filtered_lrelu
from
mmcv.ops
import
filtered_lrelu
...
@@ -115,7 +116,7 @@ class TestFilteredLrelu:
...
@@ -115,7 +116,7 @@ class TestFilteredLrelu:
assert
out
.
shape
==
(
1
,
3
,
16
,
16
)
assert
out
.
shape
==
(
1
,
3
,
16
,
16
)
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
not
torch
.
cuda
.
is_available
()
or
is_rocm_pytorch
()
or
digit_version
(
torch
.
version
.
cuda
)
<
digit_version
(
'10.2'
),
or
digit_version
(
torch
.
version
.
cuda
)
<
digit_version
(
'10.2'
),
reason
=
'requires cuda>=10.2'
)
reason
=
'requires cuda>=10.2'
)
def
test_filtered_lrelu_cuda
(
self
):
def
test_filtered_lrelu_cuda
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment