Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
a6d7dde7
Unverified
Commit
a6d7dde7
authored
Mar 24, 2023
by
q.yao
Committed by
GitHub
Mar 24, 2023
Browse files
[Fix] Fix torch2.0 dcn/mdcn symbolic (#2695)
* fix * fix lint
parent
5a45fac9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
6 deletions
+11
-6
mmcv/ops/deform_conv.py
mmcv/ops/deform_conv.py
+4
-2
mmcv/ops/modulated_deform_conv.py
mmcv/ops/modulated_deform_conv.py
+4
-2
tests/test_ops/test_onnx.py
tests/test_ops/test_onnx.py
+3
-2
No files found.
mmcv/ops/deform_conv.py
View file @
a6d7dde7
...
...
@@ -108,8 +108,10 @@ class DeformConv2dFunction(Function):
return
output
ctx
.
save_for_backward
(
input
,
offset
,
weight
)
output
=
input
.
new_empty
(
DeformConv2dFunction
.
_output_size
(
ctx
,
input
,
weight
))
output
=
input
.
new_empty
([
int
(
i
)
for
i
in
DeformConv2dFunction
.
_output_size
(
ctx
,
input
,
weight
)
])
ctx
.
bufs_
=
[
input
.
new_empty
(
0
),
input
.
new_empty
(
0
)]
# columns, ones
...
...
mmcv/ops/modulated_deform_conv.py
View file @
a6d7dde7
...
...
@@ -136,8 +136,10 @@ class ModulatedDeformConv2dFunction(Function):
ctx
,
input
,
offset
,
mask
,
weight
,
bias
)
return
output
ctx
.
save_for_backward
(
input
,
offset
,
mask
,
weight
,
bias
)
output
=
input
.
new_empty
(
ModulatedDeformConv2dFunction
.
_output_size
(
ctx
,
input
,
weight
))
output
=
input
.
new_empty
([
int
(
i
)
for
i
in
ModulatedDeformConv2dFunction
.
_output_size
(
ctx
,
input
,
weight
)
])
ctx
.
_bufs
=
[
input
.
new_empty
(
0
),
input
.
new_empty
(
0
)]
ext_module
.
modulated_deform_conv_forward
(
input
,
...
...
tests/test_ops/test_onnx.py
View file @
a6d7dde7
...
...
@@ -3,7 +3,6 @@ import os
import
numpy
as
np
import
onnx
import
onnxruntime
as
rt
import
pytest
import
torch
import
torch.nn
as
nn
...
...
@@ -37,6 +36,7 @@ class WrapFunction(nn.Module):
def
test_roialign
():
rt
=
pytest
.
importorskip
(
'onnxruntime'
)
try
:
from
mmcv.ops
import
roi_align
except
(
ImportError
,
ModuleNotFoundError
):
...
...
@@ -106,6 +106,7 @@ def test_roialign():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'test requires GPU'
)
def
test_roipool
():
rt
=
pytest
.
importorskip
(
'onnxruntime'
)
from
mmcv.ops
import
roi_pool
# roi pool config
...
...
@@ -204,7 +205,7 @@ def test_deform_conv():
from
mmcv.ops
import
DeformConv2dPack
x
=
torch
.
randn
(
1
,
2
,
4
,
4
,
device
=
'cuda'
)
_test_symbolic
(
DeformConv2dPack
(
2
,
4
,
3
,
1
,
1
).
cuda
(),
x
,
'MMCVDeformConv2d'
)
DeformConv2dPack
(
2
,
4
,
3
,
1
,
1
).
cuda
(),
(
x
,
)
,
'MMCVDeformConv2d'
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'test requires GPU'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment