OpenDAS / MMCV · Commits

Commit 06efa79d
Authored May 04, 2023 by xiabo

ROCm environment adaptation

Parent: bd635011
Showing 4 changed files with 40 additions and 15 deletions:
mmcv/ops/conv2d_gradfix.py                    +25  -7
mmcv/ops/corner_pool.py                        +2  -2
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu  +11  -5
tests/test_ops/test_filtered_lrelu.py          +2  -1
mmcv/ops/conv2d_gradfix.py

@@ -39,7 +39,7 @@ def conv2d(input: torch.Tensor,
            dilation: Union[int, Tuple[int, ...]] = 1,
            groups: int = 1):
     flag = True
-    if torch.__version__ >= '1.10.0':
+    if digit_version(torch.__version__) >= digit_version('1.10.0'):
         warnings.warn(
             'Since '
             'aten:cudnn_convolution_backward_weight is '
             f'not supported in torch=={torch.__version__},'
...
@@ -282,6 +282,24 @@ def _conv2d_gradfix(
                     groups=groups,
                     output_padding=output_padding,
                     output_mask=[0, 1, 0])[1]
             else:
+                is_rocm_pytorch = False
+                try:
+                    from torch.utils.cpp_extension import ROCM_HOME
+                    is_rocm_pytorch = True if ((torch.version.hip is not None)
+                                               and (ROCM_HOME is not None)) else False
+                except ImportError:
+                    pass
+                name = ''
+                flags = []
+                if is_rocm_pytorch:
+                    name = ('aten::miopen_convolution_transpose_backward_weight'
+                            if transpose else
+                            'aten::miopen_convolution_backward_weight')
+                    flags = [
+                        torch.backends.cudnn.benchmark,
+                        torch.backends.cudnn.deterministic
+                    ]
+                else:
                 # General case => cuDNN.
                 name = ('aten::cudnn_convolution_transpose_backward_weight'
...
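Note: the block added above selects MIOpen backward-weight op names when PyTorch was built for ROCm. A minimal standalone sketch of the same detection pattern, runnable outside MMCV (the helper names below are illustrative, not MMCV APIs):

# Hedged sketch of the ROCm-detection pattern used in the diff above.
# `is_rocm_pytorch` / `pick_backward_weight_op` are illustrative helpers,
# not part of MMCV.
import torch


def is_rocm_pytorch() -> bool:
    """Return True when this PyTorch build targets ROCm/HIP."""
    try:
        from torch.utils.cpp_extension import ROCM_HOME
    except ImportError:
        return False
    # Both must hold: the wheel was built for HIP and a ROCm install was found.
    return (torch.version.hip is not None) and (ROCM_HOME is not None)


def pick_backward_weight_op(transpose: bool) -> str:
    """Pick the MIOpen op name on ROCm, the cuDNN op name otherwise."""
    if is_rocm_pytorch():
        return ('aten::miopen_convolution_transpose_backward_weight'
                if transpose else 'aten::miopen_convolution_backward_weight')
    return ('aten::cudnn_convolution_transpose_backward_weight'
            if transpose else 'aten::cudnn_convolution_backward_weight')


if __name__ == '__main__':
    print(pick_backward_weight_op(transpose=False))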
mmcv/ops/corner_pool.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor, nn

from mmengine.utils import digit_version

_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
...
@@ -70,7 +70,7 @@ class CornerPool(nn.Module):
         self.mode = mode

     def forward(self, x: Tensor) -> Tensor:
-        if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+        if torch.__version__ != 'parrots' and digit_version(
+                torch.__version__) >= digit_version('1.5.0'):
             dim, flip = self.cummax_dim_flip[self.mode]
             if flip:
                 x = x.flip(dim)
...
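The reason for swapping the raw string comparison for digit_version is that lexicographic comparison of version strings misorders multi-digit components. A small hedged illustration (assuming mmengine is installed):

# Why the string comparison was replaced: '1.10.0' sorts before '1.5.0'
# lexicographically, so the old guard silently misfires on newer torch.
from mmengine.utils import digit_version

print('1.10.0' >= '1.5.0')                                # False (wrong)
print(digit_version('1.10.0') >= digit_version('1.5.0'))  # True  (correct)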
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu

...
@@ -1619,6 +1619,7 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
 #define BUILD_FILTERED_LRELU_OP 1
+#ifndef MMCV_WITH_HIP
 #ifdef __GNUC__
 #if __GNUC__ < 6
 #undef BUILD_FILTERED_LRELU_OP
...
@@ -1626,10 +1627,12 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
 #endif
 #endif
 #if CUDA_VERSION < 10020
 #undef BUILD_FILTERED_LRELU_OP
 #define BUILD_FILTERED_LRELU_OP 0
 #endif
 #endif
 #if BUILD_FILTERED_LRELU_OP == 1
 std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
...
@@ -1670,9 +1673,10 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
   // Figure out how much shared memory is available on the device.
   int maxSharedBytes = 0;
-  AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes,
-      cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+  int result = cudaDeviceGetAttribute(&maxSharedBytes,
+      // cudaDevAttrMaxSharedMemoryPerBlockOptin,
+      hipDeviceAttributeSharedMemPerBlockOptin, x.device().index());
   int sharedKB = maxSharedBytes >> 10;

   // Populate enough launch parameters to check if a CUDA kernel exists.
...
@@ -1890,8 +1894,10 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
   // Set cache and shared memory configurations for main kernel.
   AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
   if (spec.dynamicSharedKB)  // Need dynamically allocated shared memory?
-    AT_CUDA_CHECK(cudaFuncSetAttribute(
-        spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize,
+    // AT_CUDA_CHECK(cudaFuncSetAttribute(
+    AT_CUDA_CHECK(hipFuncSetAttribute(
+        // spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize,
+        spec.exec, hipFuncAttributeMaxDynamicSharedMemorySize,
         spec.dynamicSharedKB << 10));
   AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
...
tests/test_ops/test_filtered_lrelu.py

...
@@ -116,7 +116,8 @@ class TestFilteredLrelu:
     @pytest.mark.skipif(
         not torch.cuda.is_available()
-        or digit_version(torch.version.cuda) < digit_version('10.2'),
+        # or digit_version(torch.version.cuda) < digit_version('10.2'),
+        or False,
         reason='requires cuda>=10.2')
     def test_filtered_lrelu_cuda(self):
         out = filtered_lrelu(self.input_tensor.cuda(), bias=self.bias.cuda())
...
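The CUDA-version clause is commented out here, presumably because torch.version.cuda is None on ROCm wheels, so the digit_version comparison would raise before the skip could apply. A hedged sketch of an alternative guard that keeps the check on NVIDIA builds while tolerating ROCm (illustrative only, not the change made in this commit):

# Hedged sketch: skip condition that preserves the cuda>=10.2 check on NVIDIA
# builds and bypasses the version check on ROCm builds, where
# torch.version.cuda is None and torch.version.hip is set.
import pytest
import torch
from mmengine.utils import digit_version

IS_ROCM = torch.version.hip is not None
SHOULD_SKIP = (
    not torch.cuda.is_available()
    or (not IS_ROCM
        and digit_version(torch.version.cuda) < digit_version('10.2')))


@pytest.mark.skipif(SHOULD_SKIP, reason='requires cuda>=10.2 or a ROCm build')
def test_filtered_lrelu_cuda_sketch():
    # Placeholder body for the sketch; the real test calls filtered_lrelu.
    pass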