Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
71ee2a61
"docs/vscode:/vscode.git/clone" did not exist on "a2b11de487269f81c6cdbe17ac9fa4c8c585da1b"
Unverified
Commit
71ee2a61
authored
Jan 13, 2023
by
mengpenghui
Committed by
GitHub
Jan 13, 2023
Browse files
[Enhance] Add AMP support for MLU_DCNv2 (#2548)
parent
c310d28c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
8 deletions
+16
-8
mmcv/ops/modulated_deform_conv.py
mmcv/ops/modulated_deform_conv.py
+4
-1
tests/test_ops/test_modulated_deform_conv.py
tests/test_ops/test_modulated_deform_conv.py
+12
-7
No files found.
mmcv/ops/modulated_deform_conv.py
View file @
71ee2a61
...
...
@@ -406,10 +406,13 @@ if IS_MLU_AVAILABLE:
o1
,
o2
,
mask
=
torch
.
chunk
(
out
,
3
,
dim
=
1
)
offset
=
torch
.
cat
((
o1
,
o2
),
dim
=
1
)
mask
=
torch
.
sigmoid
(
mask
)
x
=
x
.
type_as
(
offset
)
weight
=
self
.
weight
.
type_as
(
x
)
mask
=
mask
.
type_as
(
x
)
return
tv_deform_conv2d
(
x
,
offset
,
self
.
weight
,
weight
,
bias
=
self
.
bias
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
...
...
tests/test_ops/test_modulated_deform_conv.py
View file @
71ee2a61
...
...
@@ -74,7 +74,7 @@ class TestMdconv:
assert
numpy
.
allclose
(
dcn
.
conv_offset
.
bias
.
grad
.
cpu
().
detach
().
numpy
(),
dcn_offset_b_grad
,
1e-2
)
def
_test_amp_mdconv
(
self
,
input_dtype
=
torch
.
float
):
def
_test_amp_mdconv
(
self
,
input_dtype
=
torch
.
float
,
device
=
'cuda'
):
"""The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half,
...
...
@@ -84,10 +84,15 @@ class TestMdconv:
Args:
input_dtype: torch.float or torch.half.
"""
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
()
and
device
==
'cuda'
:
return
from
mmcv.ops
import
ModulatedDeformConv2dPack
input
=
torch
.
tensor
(
input_t
).
cuda
().
type
(
input_dtype
)
if
device
==
'mlu'
:
from
mmcv.ops
import
\
ModulatedDeformConv2dPack_MLU
as
ModulatedDeformConv2dPack
else
:
from
mmcv.ops
import
ModulatedDeformConv2dPack
input
=
torch
.
tensor
(
input_t
).
to
(
device
).
type
(
input_dtype
)
input
.
requires_grad
=
True
dcn
=
ModulatedDeformConv2dPack
(
...
...
@@ -97,7 +102,7 @@ class TestMdconv:
stride
=
1
,
padding
=
1
,
deform_groups
=
1
,
bias
=
False
).
cuda
(
)
bias
=
False
).
to
(
device
)
dcn
.
weight
.
data
.
fill_
(
1.
)
output
=
dcn
(
input
)
output
.
sum
().
backward
()
...
...
@@ -126,5 +131,5 @@ class TestMdconv:
if
(
TORCH_VERSION
!=
'parrots'
and
digit_version
(
TORCH_VERSION
)
>=
digit_version
(
'1.6.0'
)):
with
autocast
(
enabled
=
True
):
self
.
_test_amp_mdconv
(
torch
.
float
)
self
.
_test_amp_mdconv
(
torch
.
half
)
self
.
_test_amp_mdconv
(
torch
.
float
,
device
=
device
)
self
.
_test_amp_mdconv
(
torch
.
half
,
device
=
device
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment