Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dcnv3
Commits
861253ca
Unverified
Commit
861253ca
authored
Apr 18, 2023
by
Zeqiang Lai
Committed by
GitHub
Apr 18, 2023
Browse files
Fix DCNv3 version compatibility (#108)
parent
8f2d1583
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
14 deletions
+19
-14
classification/ops_dcnv3/functions/dcnv3_func.py
classification/ops_dcnv3/functions/dcnv3_func.py
+19
-14
No files found.
classification/ops_dcnv3/functions/dcnv3_func.py
View file @
861253ca
...
...
@@ -15,6 +15,9 @@ from torch.autograd.function import once_differentiable
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
import
DCNv3
import
pkg_resources
dcn_version
=
float
(
pkg_resources
.
get_distribution
(
'DCNv3'
).
version
)
class
DCNv3Function
(
Function
):
@
staticmethod
...
...
@@ -38,15 +41,16 @@ class DCNv3Function(Function):
ctx
.
im2col_step
=
im2col_step
ctx
.
remove_center
=
remove_center
kwargs
=
{}
if
remove_center
:
kwargs
[
'remove_center'
]
=
remove_center
output
=
DCNv3
.
dcnv3_forward
(
args
=
[
input
,
offset
,
mask
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
group
,
group_channels
,
offset_scale
,
ctx
.
im2col_step
,
**
kwargs
)
group_channels
,
offset_scale
,
ctx
.
im2col_step
]
if
remove_center
or
dcn_version
>
1.0
:
args
.
append
(
remove_center
)
output
=
DCNv3
.
dcnv3_forward
(
*
args
)
ctx
.
save_for_backward
(
input
,
offset
,
mask
)
return
output
...
...
@@ -57,16 +61,17 @@ class DCNv3Function(Function):
def
backward
(
ctx
,
grad_output
):
input
,
offset
,
mask
=
ctx
.
saved_tensors
kwargs
=
{}
if
ctx
.
remove_center
:
kwargs
[
'remove_center'
]
=
ctx
.
remove_center
grad_input
,
grad_offset
,
grad_mask
=
\
DCNv3
.
dcnv3_backward
(
args
=
[
input
,
offset
,
mask
,
ctx
.
kernel_h
,
ctx
.
kernel_w
,
ctx
.
stride_h
,
ctx
.
stride_w
,
ctx
.
pad_h
,
ctx
.
pad_w
,
ctx
.
dilation_h
,
ctx
.
dilation_w
,
ctx
.
group
,
ctx
.
group_channels
,
ctx
.
offset_scale
,
grad_output
.
contiguous
(),
ctx
.
im2col_step
,
**
kwargs
)
ctx
.
group_channels
,
ctx
.
offset_scale
,
grad_output
.
contiguous
(),
ctx
.
im2col_step
]
if
ctx
.
remove_center
or
dcn_version
>
1.0
:
args
.
append
(
ctx
.
remove_center
)
grad_input
,
grad_offset
,
grad_mask
=
\
DCNv3
.
dcnv3_backward
(
*
args
)
return
grad_input
,
grad_offset
,
grad_mask
,
\
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment