Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
wangkx1
DCNv4-main
Commits
5b17e272
Commit
5b17e272
authored
May 27, 2026
by
wangkx1
Browse files
init
parents
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4387 additions
and
0 deletions
+4387
-0
DCNv4_op/scripts/functions/dcnv4_func.py
DCNv4_op/scripts/functions/dcnv4_func.py
+129
-0
DCNv4_op/scripts/functions/flash_deform_attn_func.py
DCNv4_op/scripts/functions/flash_deform_attn_func.py
+116
-0
DCNv4_op/scripts/functions/table.py
DCNv4_op/scripts/functions/table.py
+1356
-0
DCNv4_op/scripts/search_bwd.sh
DCNv4_op/scripts/search_bwd.sh
+3
-0
DCNv4_op/scripts/search_dcnv4.py
DCNv4_op/scripts/search_dcnv4.py
+131
-0
DCNv4_op/scripts/search_dcnv4_bwd.py
DCNv4_op/scripts/search_dcnv4_bwd.py
+200
-0
DCNv4_op/scripts/search_dcnv4_bwd_engine.py
DCNv4_op/scripts/search_dcnv4_bwd_engine.py
+25
-0
DCNv4_op/scripts/search_dcnv4_engine.py
DCNv4_op/scripts/search_dcnv4_engine.py
+25
-0
DCNv4_op/scripts/search_fwd.sh
DCNv4_op/scripts/search_fwd.sh
+3
-0
DCNv4_op/scripts/test_dcnv4.py
DCNv4_op/scripts/test_dcnv4.py
+136
-0
DCNv4_op/scripts/test_dcnv4_bwd.py
DCNv4_op/scripts/test_dcnv4_bwd.py
+222
-0
DCNv4_op/scripts/test_flash_deform_attn.py
DCNv4_op/scripts/test_flash_deform_attn.py
+174
-0
DCNv4_op/scripts/test_flash_deform_attn_backward.py
DCNv4_op/scripts/test_flash_deform_attn_backward.py
+195
-0
DCNv4_op/setup.py
DCNv4_op/setup.py
+104
-0
DCNv4_op/src/cuda/common.h
DCNv4_op/src/cuda/common.h
+216
-0
DCNv4_op/src/cuda/dcnv4_col2im_cuda.cuh
DCNv4_op/src/cuda/dcnv4_col2im_cuda.cuh
+562
-0
DCNv4_op/src/cuda/dcnv4_cuda.cu
DCNv4_op/src/cuda/dcnv4_cuda.cu
+176
-0
DCNv4_op/src/cuda/dcnv4_cuda.h
DCNv4_op/src/cuda/dcnv4_cuda.h
+34
-0
DCNv4_op/src/cuda/dcnv4_im2col_cuda.cuh
DCNv4_op/src/cuda/dcnv4_im2col_cuda.cuh
+417
-0
DCNv4_op/src/cuda/flash_deform_attn_cuda.cu
DCNv4_op/src/cuda/flash_deform_attn_cuda.cu
+163
-0
No files found.
DCNv4_op/scripts/functions/dcnv4_func.py
0 → 100644
View file @
5b17e272
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
torch
import
math
import
torch.nn.functional
as
F
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
.table
import
TABLE
,
BWDTABLE
from
DCNv4
import
ext
def
factors
(
N
):
res
=
[]
for
i
in
range
(
1
,
N
+
1
):
if
N
%
i
==
0
:
res
.
append
(
i
)
return
res
def
findspec
(
B
,
H
,
W
,
G
,
C
):
key
=
f
"
{
B
}
x
{
H
}
x
{
W
}
x
{
G
}
x
{
C
}
"
if
key
in
TABLE
:
return
TABLE
[
key
][
0
],
TABLE
[
key
][
1
]
d_stride
=
8
ms
=
factors
(
B
*
H
*
W
)
multiplier
=
1
for
m
in
ms
:
if
m
<=
64
and
(
m
*
G
*
C
//
d_stride
)
<=
512
:
multiplier
=
m
n_thread
=
multiplier
*
G
*
C
//
d_stride
key
=
f
"
{
B
}
x
{
H
}
x
{
W
}
x
{
G
}
x
{
C
}
"
TABLE
[
key
]
=
(
d_stride
,
n_thread
)
return
d_stride
,
n_thread
def
find_spec_bwd
(
B
,
H
,
W
,
G
,
C
):
key
=
f
"
{
B
}
x
{
H
}
x
{
W
}
x
{
G
}
x
{
C
}
"
if
key
in
BWDTABLE
:
return
BWDTABLE
[
key
][
0
],
BWDTABLE
[
key
][
1
]
if
C
>=
64
:
d_stride
=
2
else
:
d_stride
=
1
ms
=
factors
(
B
*
H
*
W
)
multiplier
=
1
for
m
in
ms
:
if
m
<=
64
and
(
m
*
G
*
C
//
d_stride
)
<=
256
:
multiplier
=
m
n_thread
=
multiplier
*
G
*
C
//
d_stride
return
d_stride
,
n_thread
class
DCNv4Function
(
Function
):
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
input
,
offset_mask
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
group
,
group_channels
,
offset_scale
,
im2col_step
,
remove_center
):
forward_d_stride
,
forward_block_thread
=
findspec
(
input
.
shape
[
0
],
input
.
shape
[
1
],
input
.
shape
[
2
],
group
,
group_channels
)
backward_d_stride
,
backward_block_thread
=
find_spec_bwd
(
input
.
shape
[
0
],
input
.
shape
[
1
],
input
.
shape
[
2
],
group
,
group_channels
)
ctx
.
kernel_h
=
kernel_h
ctx
.
kernel_w
=
kernel_w
ctx
.
stride_h
=
stride_h
ctx
.
stride_w
=
stride_w
ctx
.
pad_h
=
pad_h
ctx
.
pad_w
=
pad_w
ctx
.
dilation_h
=
dilation_h
ctx
.
dilation_w
=
dilation_w
ctx
.
group
=
group
ctx
.
group_channels
=
group_channels
ctx
.
offset_scale
=
offset_scale
ctx
.
im2col_step
=
im2col_step
ctx
.
remove_center
=
remove_center
ctx
.
backward_d_stride
=
backward_d_stride
ctx
.
backward_block_thread
=
backward_block_thread
args
=
[
input
,
offset_mask
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
group
,
group_channels
,
offset_scale
,
ctx
.
im2col_step
,
remove_center
,
forward_d_stride
,
forward_block_thread
,
False
,
]
output
=
ext
.
dcnv4_forward
(
*
args
)
ctx
.
save_for_backward
(
input
,
offset_mask
)
return
output
@
staticmethod
@
once_differentiable
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
input
,
offset_mask
=
ctx
.
saved_tensors
args
=
[
input
,
offset_mask
,
ctx
.
kernel_h
,
ctx
.
kernel_w
,
ctx
.
stride_h
,
ctx
.
stride_w
,
ctx
.
pad_h
,
ctx
.
pad_w
,
ctx
.
dilation_h
,
ctx
.
dilation_w
,
ctx
.
group
,
ctx
.
group_channels
,
ctx
.
offset_scale
,
ctx
.
im2col_step
,
grad_output
.
contiguous
(),
ctx
.
remove_center
,
ctx
.
backward_d_stride
,
ctx
.
backward_block_thread
,
False
]
grad_input
,
grad_offset_mask
=
\
ext
.
dcnv4_backward
(
*
args
)
return
grad_input
,
grad_offset_mask
,
\
None
,
None
,
None
,
None
,
None
,
None
,
None
,
\
None
,
None
,
None
,
None
,
None
,
None
DCNv4_op/scripts/functions/flash_deform_attn_func.py
0 → 100644
View file @
5b17e272
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
torch
import
torch.nn.functional
as
F
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
import
numpy
as
np
from
DCNv4
import
ext
shm_size_dict
=
{
"8.0"
:
163000
,
"8.6"
:
99000
,
"8.7"
:
163000
,
"8.9"
:
99000
,
"9.0"
:
227000
,
"7.5"
:
64000
,
"7.0"
:
96000
,
}
cuda_capability
=
f
"
{
torch
.
cuda
.
get_device_properties
(
0
).
major
}
.
{
torch
.
cuda
.
get_device_properties
(
0
).
minor
}
"
cuda_capability
=
"8.7"
if
cuda_capability
not
in
shm_size_dict
:
raise
NotImplementedError
shm_size_cap
=
shm_size_dict
[
cuda_capability
]
def
factors
(
N
):
res
=
[]
for
i
in
range
(
1
,
N
+
1
):
if
N
%
i
==
0
:
res
.
append
(
i
)
return
res
def
findspec
(
B
,
Q
,
G
,
C
):
d_stride
=
8
ms
=
factors
(
B
*
Q
)
multiplier
=
1
for
m
in
ms
:
if
m
<=
64
and
(
m
*
G
*
C
//
d_stride
)
<=
512
:
multiplier
=
m
n_thread
=
multiplier
*
G
*
C
//
d_stride
return
d_stride
,
n_thread
def
findspec_bwd
(
B
,
Q
,
G
,
C
):
if
C
>=
64
:
d_stride
=
2
else
:
d_stride
=
1
ms
=
factors
(
B
*
Q
)
multiplier
=
1
for
m
in
ms
:
if
m
<=
64
and
(
m
*
G
*
C
//
d_stride
)
<=
256
:
multiplier
=
m
n_thread
=
multiplier
*
G
*
C
//
d_stride
return
d_stride
,
n_thread
class
FlashDeformAttnFunction
(
Function
):
@
staticmethod
@
torch
.
autocast
(
"cuda"
,
enabled
=
True
,
dtype
=
torch
.
float16
)
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_loc_attn
,
im2col_step
,
K
=
8
):
ctx
.
im2col_step
=
im2col_step
ctx
.
K
=
K
d_stride
,
blockthread
=
findspec
(
value
.
shape
[
0
],
sampling_loc_attn
.
shape
[
1
],
value
.
shape
[
2
],
value
.
shape
[
3
])
d_stride_backward
,
blockthread_backward
=
findspec_bwd
(
value
.
shape
[
0
],
sampling_loc_attn
.
shape
[
1
],
value
.
shape
[
2
],
value
.
shape
[
3
])
ctx
.
d_stride_backward
=
d_stride_backward
ctx
.
blockthread_backward
=
blockthread_backward
output
=
ext
.
flash_deform_attn_forward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_loc_attn
,
ctx
.
im2col_step
,
K
,
d_stride
,
blockthread
,
)
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_loc_attn
)
return
output
@
staticmethod
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_loc_attn
=
ctx
.
saved_tensors
grad_value
,
grad_sampling_loc_attn
=
ext
.
flash_deform_attn_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_loc_attn
,
grad_output
.
contiguous
(),
ctx
.
im2col_step
,
ctx
.
K
,
ctx
.
d_stride_backward
,
ctx
.
blockthread_backward
,
)
return
grad_value
,
None
,
None
,
grad_sampling_loc_attn
,
None
,
None
DCNv4_op/scripts/functions/table.py
0 → 100644
View file @
5b17e272
TABLE
=
{
"64x56x56x4x16"
:
[
8
,
448
,
56
],
"64x28x28x4x16"
:
[
8
,
448
,
56
],
"64x14x14x4x16"
:
[
8
,
32
,
4
],
"64x7x7x4x16"
:
[
8
,
56
,
7
],
"1x200x320x4x16"
:
[
8
,
32
,
4
],
"1x100x160x4x16"
:
[
8
,
32
,
4
],
"1x50x80x4x16"
:
[
4
,
512
,
32
],
"1x25x40x4x16"
:
[
4
,
320
,
20
],
"1x64x64x4x16"
:
[
8
,
512
,
64
],
"64x56x56x5x16"
:
[
8
,
490
,
49
],
"64x28x28x5x16"
:
[
8
,
490
,
49
],
"64x14x14x5x16"
:
[
8
,
280
,
28
],
"64x7x7x5x16"
:
[
4
,
140
,
7
],
"1x200x320x5x16"
:
[
4
,
400
,
20
],
"1x100x160x5x16"
:
[
4
,
400
,
20
],
"1x50x80x5x16"
:
[
8
,
500
,
50
],
"1x25x40x5x16"
:
[
8
,
20
,
2
],
"1x64x64x5x16"
:
[
8
,
320
,
32
],
"64x56x56x6x16"
:
[
8
,
768
,
64
],
"64x28x28x6x16"
:
[
8
,
672
,
56
],
"64x14x14x6x16"
:
[
8
,
336
,
28
],
"64x7x7x6x16"
:
[
8
,
84
,
7
],
"1x200x320x6x16"
:
[
4
,
600
,
25
],
"1x100x160x6x16"
:
[
4
,
600
,
25
],
"1x50x80x6x16"
:
[
8
,
600
,
50
],
"1x25x40x6x16"
:
[
2
,
240
,
5
],
"1x64x64x6x16"
:
[
8
,
384
,
32
],
"64x56x56x7x16"
:
[
8
,
896
,
64
],
"64x28x28x7x16"
:
[
8
,
686
,
49
],
"64x14x14x7x16"
:
[
8
,
392
,
28
],
"64x7x7x7x16"
:
[
8
,
686
,
49
],
"1x200x320x7x16"
:
[
8
,
700
,
50
],
"1x100x160x7x16"
:
[
8
,
700
,
50
],
"1x50x80x7x16"
:
[
8
,
700
,
50
],
"1x25x40x7x16"
:
[
8
,
70
,
5
],
"1x64x64x7x16"
:
[
8
,
448
,
32
],
"64x56x56x8x16"
:
[
8
,
448
,
28
],
"64x28x28x8x16"
:
[
8
,
448
,
28
],
"64x14x14x8x16"
:
[
8
,
448
,
28
],
"64x7x7x8x16"
:
[
8
,
784
,
49
],
"1x200x320x8x16"
:
[
8
,
800
,
50
],
"1x100x160x8x16"
:
[
4
,
640
,
20
],
"1x50x80x8x16"
:
[
8
,
800
,
50
],
"1x25x40x8x16"
:
[
4
,
64
,
2
],
"1x64x64x8x16"
:
[
8
,
256
,
16
],
"64x56x56x4x32"
:
[
8
,
448
,
28
],
"64x28x28x4x32"
:
[
8
,
448
,
28
],
"64x14x14x4x32"
:
[
8
,
448
,
28
],
"64x7x7x4x32"
:
[
8
,
112
,
7
],
"1x200x320x4x32"
:
[
8
,
512
,
32
],
"1x100x160x4x32"
:
[
8
,
800
,
50
],
"1x50x80x4x32"
:
[
8
,
800
,
50
],
"1x25x40x4x32"
:
[
4
,
128
,
4
],
"1x64x64x4x32"
:
[
8
,
128
,
8
],
"64x56x56x5x32"
:
[
8
,
560
,
28
],
"64x28x28x5x32"
:
[
8
,
560
,
28
],
"64x14x14x5x32"
:
[
8
,
560
,
28
],
"64x7x7x5x32"
:
[
8
,
980
,
49
],
"1x200x320x5x32"
:
[
8
,
500
,
25
],
"1x100x160x5x32"
:
[
8
,
800
,
40
],
"1x50x80x5x32"
:
[
8
,
1000
,
50
],
"1x25x40x5x32"
:
[
4
,
200
,
5
],
"1x64x64x5x32"
:
[
8
,
640
,
32
],
"64x56x56x6x32"
:
[
8
,
336
,
14
],
"64x28x28x6x32"
:
[
8
,
336
,
14
],
"64x14x14x6x32"
:
[
8
,
336
,
14
],
"64x7x7x6x32"
:
[
16
,
588
,
49
],
"1x200x320x6x32"
:
[
8
,
480
,
20
],
"1x100x160x6x32"
:
[
8
,
480
,
20
],
"1x50x80x6x32"
:
[
16
,
600
,
50
],
"1x25x40x6x32"
:
[
8
,
96
,
4
],
"1x64x64x6x32"
:
[
8
,
768
,
32
],
"64x56x56x7x32"
:
[
8
,
448
,
16
],
"64x28x28x7x32"
:
[
8
,
448
,
16
],
"64x14x14x7x32"
:
[
8
,
196
,
7
],
"64x7x7x7x32"
:
[
8
,
28
,
1
],
"1x200x320x7x32"
:
[
8
,
448
,
16
],
"1x100x160x7x32"
:
[
8
,
448
,
16
],
"1x50x80x7x32"
:
[
8
,
700
,
25
],
"1x25x40x7x32"
:
[
8
,
56
,
2
],
"1x64x64x7x32"
:
[
8
,
896
,
32
],
"64x56x56x8x32"
:
[
8
,
448
,
14
],
"64x28x28x8x32"
:
[
8
,
448
,
14
],
"64x14x14x8x32"
:
[
8
,
448
,
14
],
"64x7x7x8x32"
:
[
8
,
32
,
1
],
"1x200x320x8x32"
:
[
8
,
512
,
16
],
"1x100x160x8x32"
:
[
8
,
800
,
25
],
"1x50x80x8x32"
:
[
8
,
800
,
25
],
"1x25x40x8x32"
:
[
4
,
512
,
8
],
"1x64x64x8x32"
:
[
8
,
32
,
1
],
"64x56x56x4x64"
:
[
8
,
448
,
14
],
"64x28x28x4x64"
:
[
8
,
448
,
14
],
"64x14x14x4x64"
:
[
8
,
448
,
14
],
"64x7x7x4x64"
:
[
8
,
32
,
1
],
"1x200x320x4x64"
:
[
8
,
512
,
16
],
"1x100x160x4x64"
:
[
8
,
512
,
16
],
"1x50x80x4x64"
:
[
8
,
800
,
25
],
"1x25x40x4x64"
:
[
8
,
640
,
20
],
"1x64x64x4x64"
:
[
8
,
512
,
16
],
"64x56x56x5x64"
:
[
8
,
560
,
14
],
"64x28x28x5x64"
:
[
8
,
560
,
14
],
"64x14x14x5x64"
:
[
8
,
560
,
14
],
"64x7x7x5x64"
:
[
8
,
280
,
7
],
"1x200x320x5x64"
:
[
8
,
800
,
20
],
"1x100x160x5x64"
:
[
8
,
800
,
20
],
"1x50x80x5x64"
:
[
8
,
1000
,
25
],
"1x25x40x5x64"
:
[
8
,
80
,
2
],
"1x64x64x5x64"
:
[
8
,
320
,
8
],
"64x56x56x6x64"
:
[
8
,
768
,
16
],
"64x28x28x6x64"
:
[
8
,
768
,
16
],
"64x14x14x6x64"
:
[
8
,
336
,
7
],
"64x7x7x6x64"
:
[
8
,
336
,
7
],
"1x200x320x6x64"
:
[
8
,
768
,
16
],
"1x100x160x6x64"
:
[
8
,
480
,
10
],
"1x50x80x6x64"
:
[
16
,
240
,
10
],
"1x25x40x6x64"
:
[
8
,
240
,
5
],
"1x64x64x6x64"
:
[
8
,
768
,
16
],
"64x56x56x7x64"
:
[
8
,
896
,
16
],
"64x28x28x7x64"
:
[
8
,
448
,
8
],
"64x14x14x7x64"
:
[
8
,
392
,
7
],
"64x7x7x7x64"
:
[
8
,
56
,
1
],
"1x200x320x7x64"
:
[
8
,
896
,
16
],
"1x100x160x7x64"
:
[
8
,
448
,
8
],
"1x50x80x7x64"
:
[
8
,
448
,
8
],
"1x25x40x7x64"
:
[
8
,
448
,
8
],
"1x64x64x7x64"
:
[
8
,
448
,
8
],
"64x56x56x8x64"
:
[
8
,
896
,
14
],
"64x28x28x8x64"
:
[
8
,
896
,
14
],
"64x14x14x8x64"
:
[
8
,
448
,
7
],
"64x7x7x8x64"
:
[
8
,
64
,
1
],
"1x200x320x8x64"
:
[
8
,
512
,
8
],
"1x100x160x8x64"
:
[
8
,
512
,
8
],
"1x50x80x8x64"
:
[
8
,
512
,
8
],
"1x25x40x8x64"
:
[
8
,
512
,
8
],
"1x64x64x8x64"
:
[
8
,
512
,
8
]
}
BWDTABLE
=
{
"64x56x56x4x16"
:
[
1
,
256
,
4
],
"64x56x56x5x16"
:
[
1
,
320
,
4
],
"64x56x56x6x16"
:
[
1
,
192
,
2
],
"64x56x56x7x16"
:
[
1
,
224
,
2
],
"64x56x56x8x16"
:
[
1
,
256
,
2
],
"64x56x56x4x32"
:
[
1
,
256
,
2
],
"64x56x56x5x32"
:
[
1
,
160
,
1
],
"64x56x56x6x32"
:
[
1
,
192
,
1
],
"64x56x56x7x32"
:
[
1
,
224
,
1
],
"64x56x56x8x32"
:
[
1
,
256
,
1
],
"64x56x56x4x64"
:
[
2
,
512
,
4
],
"64x56x56x5x64"
:
[
2
,
640
,
4
],
"64x56x56x6x64"
:
[
2
,
384
,
2
],
"64x56x56x7x64"
:
[
2
,
224
,
1
],
"64x56x56x8x64"
:
[
2
,
1024
,
4
],
"64x28x28x4x16"
:
[
1
,
128
,
2
],
"64x28x28x5x16"
:
[
1
,
320
,
4
],
"64x28x28x6x16"
:
[
1
,
96
,
1
],
"64x28x28x7x16"
:
[
1
,
224
,
2
],
"64x28x28x8x16"
:
[
1
,
128
,
1
],
"64x28x28x4x32"
:
[
1
,
128
,
1
],
"64x28x28x5x32"
:
[
1
,
320
,
2
],
"64x28x28x6x32"
:
[
1
,
192
,
1
],
"64x28x28x7x32"
:
[
1
,
224
,
1
],
"64x28x28x8x32"
:
[
1
,
256
,
1
],
"64x28x28x4x64"
:
[
2
,
512
,
4
],
"64x28x28x5x64"
:
[
2
,
640
,
4
],
"64x28x28x6x64"
:
[
2
,
384
,
2
],
"64x28x28x7x64"
:
[
2
,
224
,
1
],
"64x28x28x8x64"
:
[
2
,
512
,
2
],
"64x14x14x4x16"
:
[
1
,
128
,
2
],
"64x14x14x5x16"
:
[
1
,
320
,
4
],
"64x14x14x6x16"
:
[
1
,
192
,
2
],
"64x14x14x7x16"
:
[
1
,
224
,
2
],
"64x14x14x8x16"
:
[
1
,
128
,
1
],
"64x14x14x4x32"
:
[
1
,
256
,
2
],
"64x14x14x5x32"
:
[
1
,
160
,
1
],
"64x14x14x6x32"
:
[
1
,
192
,
1
],
"64x14x14x7x32"
:
[
1
,
224
,
1
],
"64x14x14x8x32"
:
[
1
,
256
,
1
],
"64x14x14x4x64"
:
[
2
,
128
,
1
],
"64x14x14x5x64"
:
[
2
,
160
,
1
],
"64x14x14x6x64"
:
[
2
,
384
,
2
],
"64x14x14x7x64"
:
[
2
,
224
,
1
],
"64x14x14x8x64"
:
[
2
,
256
,
1
],
"64x7x7x4x16"
:
[
4
,
784
,
49
],
"64x7x7x5x16"
:
[
2
,
280
,
7
],
"64x7x7x6x16"
:
[
2
,
48
,
1
],
"64x7x7x7x16"
:
[
2
,
392
,
7
],
"64x7x7x8x16"
:
[
1
,
128
,
1
],
"64x7x7x4x32"
:
[
1
,
128
,
1
],
"64x7x7x5x32"
:
[
1
,
160
,
1
],
"64x7x7x6x32"
:
[
2
,
96
,
1
],
"64x7x7x7x32"
:
[
2
,
112
,
1
],
"64x7x7x8x32"
:
[
2
,
128
,
1
],
"64x7x7x4x64"
:
[
2
,
896
,
7
],
"64x7x7x5x64"
:
[
2
,
160
,
1
],
"64x7x7x6x64"
:
[
2
,
192
,
1
],
"64x7x7x7x64"
:
[
2
,
224
,
1
],
"64x7x7x8x64"
:
[
2
,
256
,
1
],
"1x200x320x4x16"
:
[
1
,
320
,
5
],
"1x200x320x5x16"
:
[
1
,
320
,
4
],
"1x200x320x6x16"
:
[
1
,
96
,
1
],
"1x200x320x7x16"
:
[
1
,
224
,
2
],
"1x200x320x8x16"
:
[
1
,
640
,
5
],
"1x200x320x4x32"
:
[
1
,
128
,
1
],
"1x200x320x5x32"
:
[
1
,
320
,
2
],
"1x200x320x6x32"
:
[
1
,
384
,
2
],
"1x200x320x7x32"
:
[
1
,
224
,
1
],
"1x200x320x8x32"
:
[
1
,
256
,
1
],
"1x200x320x4x64"
:
[
2
,
640
,
5
],
"1x200x320x5x64"
:
[
2
,
800
,
5
],
"1x200x320x6x64"
:
[
2
,
768
,
4
],
"1x200x320x7x64"
:
[
2
,
448
,
2
],
"1x200x320x8x64"
:
[
2
,
1024
,
4
],
"1x100x160x4x16"
:
[
1
,
320
,
5
],
"1x100x160x5x16"
:
[
1
,
640
,
8
],
"1x100x160x6x16"
:
[
1
,
96
,
1
],
"1x100x160x7x16"
:
[
1
,
224
,
2
],
"1x100x160x8x16"
:
[
1
,
640
,
5
],
"1x100x160x4x32"
:
[
1
,
256
,
2
],
"1x100x160x5x32"
:
[
1
,
160
,
1
],
"1x100x160x6x32"
:
[
1
,
384
,
2
],
"1x100x160x7x32"
:
[
1
,
224
,
1
],
"1x100x160x8x32"
:
[
1
,
512
,
2
],
"1x100x160x4x64"
:
[
2
,
128
,
1
],
"1x100x160x5x64"
:
[
2
,
160
,
1
],
"1x100x160x6x64"
:
[
2
,
384
,
2
],
"1x100x160x7x64"
:
[
2
,
448
,
2
],
"1x100x160x8x64"
:
[
2
,
512
,
2
],
"1x50x80x4x16"
:
[
1
,
320
,
5
],
"1x50x80x5x16"
:
[
1
,
320
,
4
],
"1x50x80x6x16"
:
[
1
,
96
,
1
],
"1x50x80x7x16"
:
[
1
,
112
,
1
],
"1x50x80x8x16"
:
[
1
,
512
,
4
],
"1x50x80x4x32"
:
[
1
,
128
,
1
],
"1x50x80x5x32"
:
[
1
,
320
,
2
],
"1x50x80x6x32"
:
[
1
,
384
,
2
],
"1x50x80x7x32"
:
[
1
,
224
,
1
],
"1x50x80x8x32"
:
[
1
,
256
,
1
],
"1x50x80x4x64"
:
[
2
,
256
,
2
],
"1x50x80x5x64"
:
[
2
,
640
,
4
],
"1x50x80x6x64"
:
[
2
,
768
,
4
],
"1x50x80x7x64"
:
[
2
,
448
,
2
],
"1x50x80x8x64"
:
[
2
,
1024
,
4
],
"1x25x40x4x16"
:
[
1
,
320
,
5
],
"1x25x40x5x16"
:
[
2
,
400
,
10
],
"1x25x40x6x16"
:
[
1
,
192
,
2
],
"1x25x40x7x16"
:
[
4
,
224
,
8
],
"1x25x40x8x16"
:
[
4
,
160
,
5
],
"1x25x40x4x32"
:
[
2
,
128
,
2
],
"1x25x40x5x32"
:
[
1
,
320
,
2
],
"1x25x40x6x32"
:
[
2
,
96
,
1
],
"1x25x40x7x32"
:
[
2
,
112
,
1
],
"1x25x40x8x32"
:
[
2
,
640
,
5
],
"1x25x40x4x64"
:
[
2
,
128
,
1
],
"1x25x40x5x64"
:
[
2
,
160
,
1
],
"1x25x40x6x64"
:
[
2
,
192
,
1
],
"1x25x40x7x64"
:
[
2
,
896
,
4
],
"1x25x40x8x64"
:
[
2
,
512
,
2
],
"1x64x64x4x16"
:
[
1
,
256
,
4
],
"1x64x64x5x16"
:
[
2
,
40
,
1
],
"1x64x64x6x16"
:
[
1
,
192
,
2
],
"1x64x64x7x16"
:
[
1
,
224
,
2
],
"1x64x64x8x16"
:
[
1
,
512
,
4
],
"1x64x64x4x32"
:
[
2
,
64
,
1
],
"1x64x64x5x32"
:
[
1
,
320
,
2
],
"1x64x64x6x32"
:
[
1
,
192
,
1
],
"1x64x64x7x32"
:
[
1
,
224
,
1
],
"1x64x64x8x32"
:
[
1
,
256
,
1
],
"1x64x64x4x64"
:
[
2
,
512
,
4
],
"1x64x64x5x64"
:
[
2
,
640
,
4
],
"1x64x64x6x64"
:
[
2
,
192
,
1
],
"1x64x64x7x64"
:
[
2
,
224
,
1
],
"1x64x64x8x64"
:
[
2
,
256
,
1
]
}
\ No newline at end of file
DCNv4_op/scripts/search_bwd.sh
0 → 100644
View file @
5b17e272
python search_dcnv4_bwd_engine.py
>
res_bwd.txt
python find_best.py
--input
res_bwd.txt
--output
table_bwd.py
\ No newline at end of file
DCNv4_op/scripts/search_dcnv4.py
0 → 100644
View file @
5b17e272
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
time
import
math
import
torch
import
torch.nn
as
nn
import
math
from
torch.autograd
import
gradcheck
import
pandas
as
pd
from
easydict
import
EasyDict
as
edict
import
argparse
from
torch.cuda
import
Event
from
functions.dcnv3_func
import
DCNv3Function
,
dcnv3_core_pytorch
from
functions.dcnv4_func
import
DCNv4Function
torch
.
set_printoptions
(
threshold
=
10000
)
torch
.
manual_seed
(
3
)
#@torch.no_grad()
def
speed_test
(
func
,
args
,
inputs
,
name
=
'Unknown'
):
tic
=
Event
(
enable_timing
=
True
)
toc
=
Event
(
enable_timing
=
True
)
# warmup
for
i
in
range
(
args
.
warmup_num
):
func
(
*
inputs
)
total_time
=
0
tic
.
record
()
for
i
in
range
(
args
.
test_num
):
o
=
func
(
*
inputs
)
torch
.
cuda
.
synchronize
()
toc
.
record
()
avg_time
=
tic
.
elapsed_time
(
toc
)
/
args
.
test_num
# print(
# f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return
avg_time
@
torch
.
no_grad
()
def
test
(
N
,
H_in
,
W_in
,
M
,
D
,
spec
=
None
):
Kh
,
Kw
=
3
,
3
remove_center
=
False
P
=
Kh
*
Kw
-
remove_center
offset_scale
=
2.0
pad
=
1
dilation
=
1
stride
=
1
H_out
=
(
H_in
+
2
*
pad
-
(
dilation
*
(
Kh
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
# print(input.shape)
offset
=
(
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
2
-
1
)
*
2
# offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0
mask_origin
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask_origin
=
mask_origin
.
half
()
mask
=
mask_origin
# mask = torch.nn.functional.softmax(mask_origin, dim=-1)
offset_mask
=
torch
.
cat
([
offset
.
unflatten
(
-
1
,
(
M
,
P
*
2
)),
mask_origin
.
detach
()],
dim
=-
1
).
flatten
(
-
2
)
im2col_step
=
128
input
=
input
.
half
()
offset
=
offset
.
half
()
mask
=
mask
.
half
()
offset_mask
=
offset_mask
.
half
()
dcnv3_args
=
[
input
,
offset
,
mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
,
]
output_pytorch
=
DCNv3Function
.
apply
(
*
dcnv3_args
)
input1
=
input
.
detach
()
def
pad
(
om
):
padded_zero
=
int
(
math
.
ceil
(
om
.
shape
[
3
]
/
8
)
*
8
)
-
om
.
shape
[
3
]
padded
=
torch
.
zeros
(
om
.
shape
[
0
],
om
.
shape
[
1
],
om
.
shape
[
2
],
padded_zero
).
to
(
om
)
return
torch
.
cat
([
om
,
padded
],
dim
=-
1
)
dcnv4_args
=
[
input1
,
pad
(
offset_mask
),
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
,
spec
[
0
],
spec
[
1
],
2
,
None
# 8, 512, 2, 256
]
output_flash_cuda
=
DCNv4Function
.
apply
(
*
dcnv4_args
)
fwdok
=
torch
.
allclose
(
output_flash_cuda
,
output_pytorch
,
rtol
=
1e-2
,
atol
=
1e-3
)
max_abs_err
=
(
output_flash_cuda
-
output_pytorch
).
abs
().
max
()
max_rel_err
=
((
output_flash_cuda
-
output_pytorch
).
abs
()
/
(
output_pytorch
.
abs
()
+
1e-3
)).
max
()
# print('>>> forward half')
# print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
if
not
fwdok
:
print
(
f
"Wrong:
{
N
}
x
{
H_in
}
x
{
W_in
}
x
{
M
}
x
{
D
}
\t
{
spec
[
0
]
}
/
{
spec
[
1
]
}
(
{
spec
[
2
]
}
)"
)
return
# assert(fwdok)
test_args
=
edict
({
'warmup_num'
:
10000
,
'test_num'
:
10000
})
exp_time_dcnv4
=
speed_test
(
DCNv4Function
.
apply
,
test_args
,
dcnv4_args
,
name
=
'exp'
)
torch
.
cuda
.
synchronize
()
print
(
f
"
{
N
}
x
{
H_in
}
x
{
W_in
}
x
{
M
}
x
{
D
}
\t
{
spec
[
0
]
}
/
{
spec
[
1
]
}
(
{
spec
[
2
]
}
):
{
exp_time_dcnv4
}
"
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--n"
,
type
=
int
)
parser
.
add_argument
(
"--h"
,
type
=
int
)
parser
.
add_argument
(
"--w"
,
type
=
int
)
parser
.
add_argument
(
"--g"
,
type
=
int
)
parser
.
add_argument
(
"--c"
,
type
=
int
)
parser
.
add_argument
(
"--dstride"
,
type
=
int
)
parser
.
add_argument
(
"--blockthread"
,
type
=
int
)
parser
.
add_argument
(
"--multiplier"
,
type
=
int
)
args
=
parser
.
parse_args
()
test
(
args
.
n
,
args
.
h
,
args
.
w
,
args
.
g
,
args
.
c
,
(
args
.
dstride
,
args
.
blockthread
,
args
.
multiplier
))
DCNv4_op/scripts/search_dcnv4_bwd.py
0 → 100644
View file @
5b17e272
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
time
import
torch
import
torch.nn
as
nn
import
math
from
torch.autograd
import
gradcheck
import
pandas
as
pd
from
easydict
import
EasyDict
as
edict
import
argparse
from
torch.cuda
import
Event
from
functions
import
DCNv4Function
,
DCNv3Function
torch
.
set_printoptions
(
threshold
=
10000
)
torch
.
manual_seed
(
3
)
def
speed_test_backward
(
func
,
args
,
inputs
,
name
=
'Unknown'
):
# warmup
# for i in range(args.warmup_num):
# o = func(*inputs)
# o.sum().backward()
total_time
=
0
len_input
=
len
(
inputs
)
for
i
in
range
(
args
.
warmup_num
+
args
.
test_num
):
tic
=
Event
(
enable_timing
=
True
)
toc
=
Event
(
enable_timing
=
True
)
inputs
[
0
]
=
inputs
[
0
].
detach
()
inputs
[
0
].
requires_grad
=
True
if
len_input
>
1
and
isinstance
(
inputs
[
1
],
torch
.
Tensor
):
inputs
[
1
]
=
inputs
[
1
].
detach
()
inputs
[
1
].
requires_grad
=
True
if
len_input
>
2
and
isinstance
(
inputs
[
2
],
torch
.
Tensor
):
inputs
[
2
]
=
inputs
[
2
].
detach
()
inputs
[
2
].
requires_grad
=
True
o
=
func
(
*
inputs
)
torch
.
cuda
.
synchronize
()
tic
.
record
()
o
.
sum
().
backward
()
toc
.
record
()
torch
.
cuda
.
synchronize
()
_time
=
tic
.
elapsed_time
(
toc
)
if
i
>=
args
.
warmup_num
:
total_time
+=
_time
o
=
o
.
detach
()
# toc.record()
# torch.cuda.synchronize()
avg_time
=
total_time
/
args
.
test_num
#print(
# f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return
avg_time
# @torch.no_grad()
def
test
(
N
=
64
,
H_in
=
32
,
W_in
=
32
,
M
=
4
,
D
=
16
,
spec
=
None
):
"""
64x56x56x128(G=4)
2 64: 3.66
- offset_mask collection write 3.4022
- offset_mask collection 3.1968
"""
Kh
,
Kw
=
3
,
3
remove_center
=
False
P
=
Kh
*
Kw
-
remove_center
offset_scale
=
2.0
pad
=
1
dilation
=
1
stride
=
1
H_out
=
(
H_in
+
2
*
pad
-
(
dilation
*
(
Kh
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
additions
=
[
None
,
None
,
spec
[
0
],
spec
[
1
],
False
]
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
10
#offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 0
offset
=
(
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
2
-
1
)
*
2
mask_origin
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask_origin
=
mask_origin
.
half
()
mask_origin
.
requires_grad
=
True
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask_origin.detach().unsqueeze(-1)], dim=-1).flatten(-3)
# mask /= mask.sum(-1, keepdim=True)
# mask = torch.nn.functional.softmax(mask_origin, dim=-1, dtype=torch.float32)
mask
=
mask_origin
# mask = mask.reshape(N, H_out, W_out, M*P)
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask.detach().unsqueeze(-1)], dim=-1).flatten(-3)
offset_mask
=
torch
.
cat
([
offset
.
detach
().
unflatten
(
-
1
,
(
M
,
P
*
2
)),
mask_origin
.
detach
()],
dim
=-
1
).
flatten
(
-
2
)
im2col_step
=
128
input
=
input
.
half
()
offset
=
offset
.
half
()
mask
=
mask
.
half
()
input
.
requires_grad
=
True
offset
.
requires_grad
=
True
# mask.requires_grad = True
output_pytorch
=
DCNv3Function
.
apply
(
input
,
offset
,
mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
)
#.detach().cpu()
(
output_pytorch
.
sum
()
/
10
).
backward
()
def
pad
(
om
):
padded_zero
=
int
(
math
.
ceil
(
om
.
shape
[
3
]
/
8
)
*
8
)
-
om
.
shape
[
3
]
padded
=
torch
.
zeros
(
om
.
shape
[
0
],
om
.
shape
[
1
],
om
.
shape
[
2
],
padded_zero
).
to
(
om
)
return
torch
.
cat
([
om
,
padded
],
dim
=-
1
)
# value_offset_mask = input.detach()
input1
=
input
.
detach
()
input1
.
requires_grad
=
True
offset_mask
=
offset_mask
.
half
()
offset_mask
.
requires_grad
=
True
# offset_mask1.requires_grad = True
torch
.
cuda
.
profiler
.
cudart
().
cudaProfilerStart
()
output_flash_cuda
=
DCNv4Function
.
apply
(
input1
,
offset_mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
,
*
additions
)
#.detach().cpu()
(
output_flash_cuda
.
sum
()
/
10
).
backward
()
torch
.
cuda
.
profiler
.
cudart
().
cudaProfilerStop
()
input_grad
=
input
.
grad
input2_grad
=
input1
.
grad
bwdok
=
torch
.
allclose
(
input_grad
.
float
(),
input2_grad
.
float
(),
rtol
=
1e-2
,
atol
=
1e-3
)
rel_err
=
(
input_grad
.
abs
()
-
input2_grad
.
abs
())
/
(
input_grad
.
abs
()
+
1e-3
)
offset_grad1
=
offset
.
grad
offset_grad2
=
offset_mask
.
grad
.
reshape
(
N
,
H_out
,
W_out
,
M
,
P
*
3
)[...,
:
P
*
2
].
reshape
(
N
,
H_out
,
W_out
,
M
*
P
*
2
)
bwdok2
=
torch
.
allclose
(
offset_grad1
.
float
(),
offset_grad2
.
float
(),
rtol
=
1e-2
,
atol
=
1e-3
)
rel_err
=
(
offset_grad1
-
offset_grad2
).
abs
()
/
(
offset_grad1
.
abs
()
+
1e-3
)
mask_grad1
=
mask_origin
.
grad
mask_grad2
=
offset_mask
.
grad
.
reshape
(
N
,
H_out
,
W_out
,
M
,
P
*
3
)[...,
P
*
2
:].
reshape
(
N
,
H_out
,
W_out
,
M
,
P
)
bwdok3
=
torch
.
allclose
(
mask_grad1
,
mask_grad2
,
rtol
=
1e-2
,
atol
=
1e-3
)
rel_err
=
(
mask_grad1
-
mask_grad2
).
abs
()
/
(
mask_grad1
.
abs
()
+
1e-3
)
fwdok
=
torch
.
allclose
(
output_flash_cuda
,
output_pytorch
,
rtol
=
1e-2
,
atol
=
1e-3
)
max_abs_err
=
(
output_flash_cuda
-
output_pytorch
).
abs
().
max
()
max_rel_err
=
((
output_flash_cuda
-
output_pytorch
).
abs
()
/
(
output_pytorch
.
abs
()
+
1e-3
)).
max
()
if
not
(
bwdok
and
bwdok2
and
bwdok3
):
print
(
f
"Wrong:
{
N
}
x
{
H_in
}
x
{
W_in
}
x
{
M
}
x
{
D
}
\t
{
spec
[
0
]
}
/
{
spec
[
1
]
}
(
{
spec
[
2
]
}
)"
)
return
# fn_args = [
# input,
# offset,
# mask,
# Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
# im2col_step, remove_center
# ]
flash_dcn_fn_args
=
[
input1
,
offset_mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
,
*
additions
]
test_args
=
edict
({
'warmup_num'
:
1000
,
'test_num'
:
1000
})
try
:
exp_time
=
speed_test_backward
(
DCNv4Function
.
apply
,
test_args
,
flash_dcn_fn_args
,
name
=
'exp'
)
except
:
print
(
f
"Wrong:
{
N
}
x
{
H_in
}
x
{
W_in
}
x
{
M
}
x
{
D
}
\t
{
spec
[
0
]
}
/
{
spec
[
1
]
}
(
{
spec
[
2
]
}
)"
)
return
torch
.
cuda
.
synchronize
()
print
(
f
"
{
N
}
x
{
H_in
}
x
{
W_in
}
x
{
M
}
x
{
D
}
\t
{
spec
[
0
]
}
/
{
spec
[
1
]
}
(
{
spec
[
2
]
}
):
{
exp_time
}
"
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--n"
,
type
=
int
)
parser
.
add_argument
(
"--h"
,
type
=
int
)
parser
.
add_argument
(
"--w"
,
type
=
int
)
parser
.
add_argument
(
"--g"
,
type
=
int
)
parser
.
add_argument
(
"--c"
,
type
=
int
)
parser
.
add_argument
(
"--dstride"
,
type
=
int
)
parser
.
add_argument
(
"--blockthread"
,
type
=
int
)
parser
.
add_argument
(
"--multiplier"
,
type
=
int
)
args
=
parser
.
parse_args
()
test
(
args
.
n
,
args
.
h
,
args
.
w
,
args
.
g
,
args
.
c
,
(
args
.
dstride
,
args
.
blockthread
,
args
.
multiplier
))
DCNv4_op/scripts/search_dcnv4_bwd_engine.py
0 → 100644
View file @
5b17e272
import
os
def
factors
(
N
):
res
=
[]
for
i
in
range
(
1
,
N
+
1
):
if
N
%
i
==
0
:
res
.
append
(
i
)
return
res
if
__name__
==
'__main__'
:
BATCH
=
64
for
N
,
Hin
,
Win
in
[(
BATCH
,
56
,
56
),
(
BATCH
,
28
,
28
),
(
BATCH
,
14
,
14
),
(
BATCH
,
7
,
7
),
(
1
,
200
,
320
),
(
1
,
100
,
160
),
(
1
,
50
,
80
),
(
1
,
25
,
40
),
(
1
,
64
,
64
)]:
for
group_channel
in
[
16
,
32
,
64
]:
for
group
in
[
4
,
5
,
6
,
7
,
8
]:
for
d_stride
in
[
1
,
2
,
4
]:
for
m
in
factors
(
N
*
Hin
*
Win
):
if
m
>
64
:
break
block_thread
=
group
*
(
group_channel
//
d_stride
)
*
m
if
block_thread
>
1024
:
break
cmd
=
f
"python search_dcnv4_bwd.py --n
{
N
}
--h
{
Hin
}
--w
{
Win
}
--g
{
group
}
--c
{
group_channel
}
--dstride
{
d_stride
}
--blockthread
{
block_thread
}
--multiplier
{
m
}
"
os
.
system
(
cmd
)
\ No newline at end of file
DCNv4_op/scripts/search_dcnv4_engine.py
0 → 100644
View file @
5b17e272
import
os
def
factors
(
N
):
res
=
[]
for
i
in
range
(
1
,
N
+
1
):
if
N
%
i
==
0
:
res
.
append
(
i
)
return
res
if
__name__
==
'__main__'
:
BATCH
=
64
for
group_channel
in
[
16
,
32
,
64
]:
for
group
in
[
4
,
5
,
6
,
7
,
8
]:
for
N
,
Hin
,
Win
in
[(
BATCH
,
56
,
56
),
(
BATCH
,
28
,
28
),
(
BATCH
,
14
,
14
),
(
BATCH
,
7
,
7
),
(
1
,
200
,
320
),
(
1
,
100
,
160
),
(
1
,
50
,
80
),
(
1
,
25
,
40
),
(
1
,
64
,
64
)]:
for
d_stride
in
[
2
,
4
,
8
,
16
]:
for
m
in
factors
(
N
*
Hin
*
Win
):
if
m
>
64
:
break
block_thread
=
group
*
(
group_channel
//
d_stride
)
*
m
if
block_thread
>
1024
:
break
cmd
=
f
"python search_dcnv4.py --n
{
N
}
--h
{
Hin
}
--w
{
Win
}
--g
{
group
}
--c
{
group_channel
}
--dstride
{
d_stride
}
--blockthread
{
block_thread
}
--multiplier
{
m
}
"
os
.
system
(
cmd
)
\ No newline at end of file
DCNv4_op/scripts/search_fwd.sh
0 → 100644
View file @
5b17e272
python search_dcnv4_engine.py
>
res.txt
python find_best.py
--input
res.txt
--output
table.py
\ No newline at end of file
DCNv4_op/scripts/test_dcnv4.py
0 → 100644
View file @
5b17e272
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
time
import
torch
import
torch.nn
as
nn
import
math
from
torch.autograd
import
gradcheck
import
pandas
as
pd
from
easydict
import
EasyDict
as
edict
from
torch.cuda
import
Event
# from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from
functions.dcnv4_func
import
DCNv4Function
torch
.
set_printoptions
(
threshold
=
10000
)
H_in
,
W_in
=
56
,
56
N
,
M
,
D
=
64
,
4
,
32
# H_in, W_in = 28, 28
# N, M, D = 64, 8, 32
# H_in, W_in = 14, 14
# N, M, D = 64, 16, 32
# H_in, W_in = 7, 7
# N, M, D = 64, 32, 32
# H_in, W_in = 8, 8
# N, M, D = 128, 4, 16
Kh
,
Kw
=
3
,
3
remove_center
=
False
P
=
Kh
*
Kw
-
remove_center
offset_scale
=
2.0
pad
=
1
dilation
=
1
stride
=
1
H_out
=
(
H_in
+
2
*
pad
-
(
dilation
*
(
Kh
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
torch
.
manual_seed
(
3
)
#@torch.no_grad()
def
speed_test
(
func
,
args
,
inputs
,
name
=
'Unknown'
):
tic
=
Event
(
enable_timing
=
True
)
toc
=
Event
(
enable_timing
=
True
)
# warmup
for
i
in
range
(
args
.
warmup_num
):
func
(
*
inputs
)
total_time
=
0
tic
.
record
()
for
i
in
range
(
args
.
test_num
):
o
=
func
(
*
inputs
)
torch
.
cuda
.
synchronize
()
toc
.
record
()
avg_time
=
tic
.
elapsed_time
(
toc
)
/
args
.
test_num
print
(
f
'>>>
{
name
:
<
10
}
finished
{
args
.
test_num
}
running, avg_time:
{
avg_time
:.
6
f
}
ms'
)
return
avg_time
@
torch
.
no_grad
()
def
check_forward_equal_with_pytorch_half
():
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
print
(
input
.
shape
)
offset
=
(
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
2
-
1
)
*
10
# offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0
mask_origin
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask_origin
=
mask_origin
.
half
()
mask
=
mask_origin
# mask = torch.nn.functional.softmax(mask_origin, dim=-1)
offset_mask
=
torch
.
cat
([
offset
.
unflatten
(
-
1
,
(
M
,
P
*
2
)),
mask_origin
.
detach
()],
dim
=-
1
).
flatten
(
-
2
)
im2col_step
=
128
input
=
input
.
half
()
offset
=
offset
.
half
()
mask
=
mask
.
half
()
offset_mask
=
offset_mask
.
half
()
input1
=
input
.
detach
()
def
pad
(
om
):
padded_zero
=
int
(
math
.
ceil
(
om
.
shape
[
3
]
/
8
)
*
8
)
-
om
.
shape
[
3
]
padded
=
torch
.
zeros
(
om
.
shape
[
0
],
om
.
shape
[
1
],
om
.
shape
[
2
],
padded_zero
).
to
(
om
)
return
torch
.
cat
([
om
,
padded
],
dim
=-
1
)
dcnv4_args
=
[
input1
,
pad
(
offset_mask
),
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
,
8
,
512
,
2
,
256
,
True
,
True
,
]
output_flash_cuda
=
DCNv4Function
.
apply
(
*
dcnv4_args
)
print
(
f
"test success"
)
# fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
# max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
# max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
# (output_pytorch.abs()+ 1e-3)).max()
# print('>>> forward half')
# print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
# assert(fwdok)
# test_args = edict({'warmup_num': 1000, 'test_num': 1000})
# exp_time_dcnv4 = speed_test(DCNv4Function.apply, test_args, dcnv4_args, name='exp')
# torch.cuda.synchronize()
# results = [{}]
# results[0]['dcnv3_time'] = exp_time_dcnv3
# results[0]['dcnv4_time'] = exp_time_dcnv4
# columns = list(results[0].keys())
# outputs = pd.DataFrame(results, columns=columns)
# with pd.option_context(
# 'display.max_rows', None, 'display.max_columns', None,
# 'display.max_colwidth', None, 'display.width', None,
# 'display.precision', 4, ):
# print(outputs)
if
__name__
==
'__main__'
:
check_forward_equal_with_pytorch_half
()
DCNv4_op/scripts/test_dcnv4_bwd.py
0 → 100644
View file @
5b17e272
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
time
import
torch
import
torch.nn
as
nn
import
math
from
torch.autograd
import
gradcheck
import
pandas
as
pd
from
easydict
import
EasyDict
as
edict
from
torch.cuda
import
Event
from
functions
import
DCNv4Function
,
DCNv3Function
torch
.
set_printoptions
(
threshold
=
10000
)
H_in
,
W_in
=
56
,
56
N
,
M
,
D
=
64
,
4
,
32
# H_in, W_in = 28, 28
# N, M, D = 64, 16, 16
# H_in, W_in = 14, 14
# N, M, D = 64, 32, 16
# H_in, W_in = 7, 7
# N, M, D = 64, 64, 16
# H_in, W_in = 8, 8
# N, M, D = 128, 4, 16
Kh
,
Kw
=
3
,
3
remove_center
=
False
P
=
Kh
*
Kw
-
remove_center
offset_scale
=
2.0
pad
=
1
dilation
=
1
stride
=
1
H_out
=
(
H_in
+
2
*
pad
-
(
dilation
*
(
Kh
-
1
)
+
1
))
//
stride
+
1
W_out
=
(
W_in
+
2
*
pad
-
(
dilation
*
(
Kw
-
1
)
+
1
))
//
stride
+
1
torch
.
manual_seed
(
3
)
def
speed_test_backward
(
func
,
args
,
inputs
,
name
=
'Unknown'
):
# warmup
# for i in range(args.warmup_num):
# o = func(*inputs)
# o.sum().backward()
total_time
=
0
len_input
=
len
(
inputs
)
for
i
in
range
(
args
.
warmup_num
+
args
.
test_num
):
tic
=
Event
(
enable_timing
=
True
)
toc
=
Event
(
enable_timing
=
True
)
inputs
[
0
]
=
inputs
[
0
].
detach
()
inputs
[
0
].
requires_grad
=
True
if
len_input
>
1
and
isinstance
(
inputs
[
1
],
torch
.
Tensor
):
inputs
[
1
]
=
inputs
[
1
].
detach
()
inputs
[
1
].
requires_grad
=
True
if
len_input
>
2
and
isinstance
(
inputs
[
2
],
torch
.
Tensor
):
inputs
[
2
]
=
inputs
[
2
].
detach
()
inputs
[
2
].
requires_grad
=
True
o
=
func
(
*
inputs
)
torch
.
cuda
.
synchronize
()
tic
.
record
()
o
.
sum
().
backward
()
toc
.
record
()
torch
.
cuda
.
synchronize
()
_time
=
tic
.
elapsed_time
(
toc
)
if
i
>=
args
.
warmup_num
:
total_time
+=
_time
o
=
o
.
detach
()
# toc.record()
# torch.cuda.synchronize()
avg_time
=
total_time
/
args
.
test_num
#print(
# f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return
avg_time
# @torch.no_grad()
def
check_forward_equal_with_pytorch_half
():
"""
64x56x56x128(G=4)
2 64: 3.66
- offset_mask collection write 3.4022
- offset_mask collection 3.1968
"""
additions
=
[
8
,
128
,
2
,
256
,
False
]
input
=
torch
.
rand
(
N
,
H_in
,
W_in
,
M
*
D
).
cuda
()
*
10
print
(
input
.
shape
)
#offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 0
offset
=
(
torch
.
rand
(
N
,
H_out
,
W_out
,
M
*
P
*
2
).
cuda
()
*
2
-
1
)
*
2
mask_origin
=
torch
.
rand
(
N
,
H_out
,
W_out
,
M
,
P
).
cuda
()
+
1e-5
mask_origin
=
mask_origin
.
half
()
mask_origin
.
requires_grad
=
True
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask_origin.detach().unsqueeze(-1)], dim=-1).flatten(-3)
# mask /= mask.sum(-1, keepdim=True)
# mask = torch.nn.functional.softmax(mask_origin, dim=-1, dtype=torch.float32)
mask
=
mask_origin
# mask = mask.reshape(N, H_out, W_out, M*P)
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask.detach().unsqueeze(-1)], dim=-1).flatten(-3)
offset_mask
=
torch
.
cat
([
offset
.
detach
().
unflatten
(
-
1
,
(
M
,
P
*
2
)),
mask_origin
.
detach
()],
dim
=-
1
).
flatten
(
-
2
)
im2col_step
=
128
input
=
input
.
half
()
offset
=
offset
.
half
()
mask
=
mask
.
half
()
input
.
requires_grad
=
True
offset
.
requires_grad
=
True
# mask.requires_grad = True
output_pytorch
=
DCNv3Function
.
apply
(
input
,
offset
,
mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
)
#.detach().cpu()
(
output_pytorch
.
sum
()
/
10
).
backward
()
def
pad
(
om
):
padded_zero
=
int
(
math
.
ceil
(
om
.
shape
[
3
]
/
8
)
*
8
)
-
om
.
shape
[
3
]
padded
=
torch
.
zeros
(
om
.
shape
[
0
],
om
.
shape
[
1
],
om
.
shape
[
2
],
padded_zero
).
to
(
om
)
return
torch
.
cat
([
om
,
padded
],
dim
=-
1
)
# value_offset_mask = input.detach()
input1
=
input
.
detach
()
input1
.
requires_grad
=
True
offset_mask
=
offset_mask
.
half
()
offset_mask
.
requires_grad
=
True
# offset_mask1.requires_grad = True
torch
.
cuda
.
profiler
.
cudart
().
cudaProfilerStart
()
output_flash_cuda
=
DCNv4Function
.
apply
(
input1
,
offset_mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
,
*
additions
)
#.detach().cpu()
(
output_flash_cuda
.
sum
()
/
10
).
backward
()
torch
.
cuda
.
profiler
.
cudart
().
cudaProfilerStop
()
input_grad
=
input
.
grad
input2_grad
=
input1
.
grad
bwdok
=
torch
.
allclose
(
input_grad
.
float
(),
input2_grad
.
float
(),
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
"bwdok"
)
print
(
bwdok
)
rel_err
=
(
input_grad
.
abs
()
-
input2_grad
.
abs
())
/
(
input_grad
.
abs
()
+
1e-3
)
print
(
rel_err
.
max
())
offset_grad1
=
offset
.
grad
offset_grad2
=
offset_mask
.
grad
.
reshape
(
N
,
H_out
,
W_out
,
M
,
P
*
3
)[...,
:
P
*
2
].
reshape
(
N
,
H_out
,
W_out
,
M
*
P
*
2
)
# print(offset_grad1)
# print("====================")
# print(offset_grad2)
bwdok2
=
torch
.
allclose
(
offset_grad1
.
float
(),
offset_grad2
.
float
(),
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
"bwdok2"
)
print
(
bwdok2
)
rel_err
=
(
offset_grad1
-
offset_grad2
).
abs
()
/
(
offset_grad1
.
abs
()
+
1e-3
)
print
(
rel_err
.
max
())
mask_grad1
=
mask_origin
.
grad
mask_grad2
=
offset_mask
.
grad
.
reshape
(
N
,
H_out
,
W_out
,
M
,
P
*
3
)[...,
P
*
2
:].
reshape
(
N
,
H_out
,
W_out
,
M
,
P
)
bwdok3
=
torch
.
allclose
(
mask_grad1
,
mask_grad2
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
"bwdok3"
)
print
(
bwdok3
)
rel_err
=
(
mask_grad1
-
mask_grad2
).
abs
()
/
(
mask_grad1
.
abs
()
+
1e-3
)
print
(
rel_err
.
max
())
fwdok
=
torch
.
allclose
(
output_flash_cuda
,
output_pytorch
,
rtol
=
1e-2
,
atol
=
1e-3
)
max_abs_err
=
(
output_flash_cuda
-
output_pytorch
).
abs
().
max
()
max_rel_err
=
((
output_flash_cuda
-
output_pytorch
).
abs
()
/
(
output_pytorch
.
abs
()
+
1e-3
)).
max
()
print
(
'>>> forward half'
)
print
(
f
'*
{
fwdok
}
check_forward_equal_with_pytorch_float: max_abs_err
{
max_abs_err
:.
2
e
}
max_rel_err
{
max_rel_err
:.
2
e
}
'
)
fn_args
=
[
input
,
offset
,
mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
]
flash_dcn_fn_args
=
[
input1
,
offset_mask
,
Kh
,
Kw
,
stride
,
stride
,
Kh
//
2
,
Kw
//
2
,
dilation
,
dilation
,
M
,
D
,
offset_scale
,
im2col_step
,
remove_center
,
*
additions
]
test_args
=
edict
({
'warmup_num'
:
1000
,
'test_num'
:
1000
})
exp_time
=
speed_test_backward
(
DCNv4Function
.
apply
,
test_args
,
flash_dcn_fn_args
,
name
=
'exp'
)
exp_time_base
=
speed_test_backward
(
DCNv3Function
.
apply
,
test_args
,
fn_args
,
name
=
'exp'
)
results
=
[{}]
results
[
0
][
'time'
]
=
exp_time
results
[
0
][
'time_base'
]
=
exp_time_base
columns
=
list
(
results
[
0
].
keys
())
outputs
=
pd
.
DataFrame
(
results
,
columns
=
columns
)
with
pd
.
option_context
(
'display.max_rows'
,
None
,
'display.max_columns'
,
None
,
'display.max_colwidth'
,
None
,
'display.width'
,
None
,
'display.precision'
,
4
,
):
print
(
outputs
)
if
__name__
==
'__main__'
:
check_forward_equal_with_pytorch_half
()
\ No newline at end of file
DCNv4_op/scripts/test_flash_deform_attn.py
0 → 100644
View file @
5b17e272
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
from
easydict
import
EasyDict
as
edict
from
torch.cuda
import
Event
import
pandas
as
pd
import
time
import
torch
import
torch.nn
as
nn
from
torch.autograd
import
gradcheck
from
functions
import
MSDeformAttnFunction
,
FlashDeformAttnFunction
,
ms_deform_attn_core_pytorch
# N, M, D = 1, 4, 8
# # Lq, L, P = 2, 2, 2
# # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Lq, L, P = 1, 2, 8
# shapes = torch.as_tensor([(8, 16), (4, 8)], dtype=torch.long).cuda()
# N, M, D = 1, 8, 32
# # Lq, L, P = 2, 2, 2
# # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Lq, L, P = 300, 4, 4
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (16, 16)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(17, 19), (4, 4)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(100, 151), (50, 76), (25, 38), (13, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(110, 151)], dtype=torch.long).cuda()
# B:6
# H:232
# W:400
# G:5
# D: 16
# channels: 80
# kernel: 3 points = 3 * 3
# num_split = 45 = kernel *kernel * G
H
=
256
W
=
256
N
,
M
,
D
=
1
,
8
,
32
Lq
,
L
,
P
=
100
*
152
,
4
,
8
shapes
=
torch
.
Tensor
([[
100
,
152
],
[
50
,
76
],
[
25
,
38
],
[
13
,
19
]]).
long
().
cuda
()
# x = x.reshape([B, H*W, G, D + self.num_split * 3])
# shapes = torch.as_tensor([(H, W)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda()
level_start_index
=
torch
.
cat
((
shapes
.
new_zeros
((
1
,)),
shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
]))
S
=
sum
([(
H
*
W
).
item
()
for
H
,
W
in
shapes
])
print
(
S
)
def
get_reference_points
(
spatial_shapes
,
device
):
reference_points_list
=
[]
for
lvl
,
(
H_
,
W_
)
in
enumerate
(
spatial_shapes
):
ref_y
,
ref_x
=
torch
.
meshgrid
(
torch
.
linspace
(
0.5
,
H_
-
0.5
,
H_
,
dtype
=
torch
.
float32
,
device
=
device
),
torch
.
linspace
(
0.5
,
W_
-
0.5
,
W_
,
dtype
=
torch
.
float32
,
device
=
device
))
ref_y
=
ref_y
.
reshape
(
-
1
)[
None
]
/
(
H_
)
ref_x
=
ref_x
.
reshape
(
-
1
)[
None
]
/
(
W_
)
ref
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
reference_points_list
.
append
(
ref
)
reference_points
=
torch
.
cat
(
reference_points_list
,
1
)
# reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return
reference_points
torch
.
manual_seed
(
3
)
@
torch
.
no_grad
()
def
speed_test
(
func
,
args
,
inputs
,
name
=
'Unknown'
):
tic
=
Event
(
enable_timing
=
True
)
toc
=
Event
(
enable_timing
=
True
)
# warmup
for
i
in
range
(
args
.
warmup_num
):
func
(
*
inputs
)
tic
.
record
()
for
i
in
range
(
args
.
test_num
):
func
(
*
inputs
)
toc
.
record
()
torch
.
cuda
.
synchronize
()
avg_time
=
tic
.
elapsed_time
(
toc
)
/
args
.
test_num
print
(
f
'>>>
{
name
:
<
10
}
finished
{
args
.
test_num
}
running, avg_time:
{
avg_time
:.
6
f
}
ms'
)
return
avg_time
@
torch
.
no_grad
()
def
check_forward_equal_with_pytorch_half
():
value
=
torch
.
rand
(
N
,
S
,
M
,
D
).
cuda
()
*
0.01
# offset = (torch.rand(N, Lq, M, L, P, 2).cuda() * 2 - 1) / 10
sampling_locations
=
torch
.
rand
(
N
,
Lq
,
M
,
L
,
P
,
2
).
cuda
()
attention_weights
=
torch
.
rand
(
N
,
Lq
,
M
,
L
,
P
).
cuda
()
+
1e-5
sampling_loc_attn
=
torch
.
cat
([
sampling_locations
.
reshape
(
N
,
Lq
,
M
,
L
*
P
*
2
),
attention_weights
.
reshape
(
N
,
Lq
,
M
,
L
*
P
)],
dim
=-
1
)
attention_weights
=
torch
.
nn
.
functional
.
softmax
(
attention_weights
.
flatten
(
-
2
,
-
1
),
dim
=-
1
).
unflatten
(
-
1
,
(
L
,
P
))
im2col_step
=
128
flash_fn_args
=
(
value
.
half
(),
shapes
,
level_start_index
,
sampling_loc_attn
.
half
(),
im2col_step
,
P
,
16
)
output_cuda
=
(
FlashDeformAttnFunction
.
apply
(
*
flash_fn_args
)
.
detach
()
.
cpu
()
).
double
()
fn_args
=
(
value
,
shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
,
)
output_pytorch
=
(
MSDeformAttnFunction
.
apply
(
*
fn_args
)
.
detach
().
double
()
.
cpu
()
)
max_abs_err
=
(
output_cuda
-
output_pytorch
).
abs
().
max
()
max_rel_err
=
((
output_cuda
-
output_pytorch
).
abs
()
/
output_pytorch
.
abs
()).
max
()
fwdok
=
torch
.
allclose
(
output_cuda
,
output_pytorch
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
f
"*
{
fwdok
}
check_forward_equal_with_pytorch_float: max_abs_err
{
max_abs_err
:.
2
e
}
max_rel_err
{
max_rel_err
:.
2
e
}
"
)
test_args
=
edict
({
'warmup_num'
:
1000
,
'test_num'
:
1000
})
exp_time_base
=
speed_test
(
MSDeformAttnFunction
.
apply
,
test_args
,
fn_args
,
name
=
'exp'
)
exp_time
=
speed_test
(
FlashDeformAttnFunction
.
apply
,
test_args
,
flash_fn_args
,
name
=
'exp'
)
results
=
[{}]
results
[
0
][
'time'
]
=
exp_time
results
[
0
][
'time_base'
]
=
exp_time_base
columns
=
list
(
results
[
0
].
keys
())
outputs
=
pd
.
DataFrame
(
results
,
columns
=
columns
)
with
pd
.
option_context
(
'display.max_rows'
,
None
,
'display.max_columns'
,
None
,
'display.max_colwidth'
,
None
,
'display.width'
,
None
,
'display.precision'
,
4
,
):
print
(
outputs
)
if
__name__
==
"__main__"
:
check_forward_equal_with_pytorch_half
()
DCNv4_op/scripts/test_flash_deform_attn_backward.py
0 → 100644
View file @
5b17e272
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
from
easydict
import
EasyDict
as
edict
from
torch.cuda
import
Event
import
pandas
as
pd
import
time
import
torch
import
torch.nn
as
nn
from
torch.autograd
import
gradcheck
from
functions
import
MSDeformAttnFunction
,
ms_deform_attn_core_pytorch
,
FlashDeformAttnFunction
H
=
256
W
=
256
N
,
M
,
D
=
1
,
8
,
16
Lq
,
L
,
P
=
H
*
W
,
1
,
8
# x = x.reshape([B, H*W, G, D + self.num_split * 3])
shapes
=
torch
.
as_tensor
([(
H
,
W
)],
dtype
=
torch
.
long
).
cuda
()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda()
H
=
256
W
=
256
N
,
M
,
D
=
1
,
8
,
32
Lq
,
L
,
P
=
100
*
152
,
4
,
8
shapes
=
torch
.
Tensor
([[
100
,
152
],
[
50
,
76
],
[
25
,
38
],
[
13
,
19
]]).
long
().
cuda
()
level_start_index
=
torch
.
cat
((
shapes
.
new_zeros
((
1
,)),
shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
]))
S
=
sum
([(
H
*
W
).
item
()
for
H
,
W
in
shapes
])
def
get_reference_points
(
spatial_shapes
,
device
):
reference_points_list
=
[]
for
lvl
,
(
H_
,
W_
)
in
enumerate
(
spatial_shapes
):
ref_y
,
ref_x
=
torch
.
meshgrid
(
torch
.
linspace
(
0.5
,
H_
-
0.5
,
H_
,
dtype
=
torch
.
float32
,
device
=
device
),
torch
.
linspace
(
0.5
,
W_
-
0.5
,
W_
,
dtype
=
torch
.
float32
,
device
=
device
))
ref_y
=
ref_y
.
reshape
(
-
1
)[
None
]
/
(
H_
)
ref_x
=
ref_x
.
reshape
(
-
1
)[
None
]
/
(
W_
)
ref
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
reference_points_list
.
append
(
ref
)
reference_points
=
torch
.
cat
(
reference_points_list
,
1
)
# reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return
reference_points
torch
.
manual_seed
(
3
)
@
torch
.
no_grad
()
def
speed_test
(
func
,
args
,
inputs
,
name
=
'Unknown'
):
tic
=
Event
(
enable_timing
=
True
)
toc
=
Event
(
enable_timing
=
True
)
# warmup
for
i
in
range
(
args
.
warmup_num
):
func
(
*
inputs
)
tic
.
record
()
for
i
in
range
(
args
.
test_num
):
func
(
*
inputs
)
toc
.
record
()
torch
.
cuda
.
synchronize
()
avg_time
=
tic
.
elapsed_time
(
toc
)
/
args
.
test_num
print
(
f
'>>>
{
name
:
<
10
}
finished
{
args
.
test_num
}
running, avg_time:
{
avg_time
:.
6
f
}
ms'
)
return
avg_time
def
check_forward_equal_with_pytorch_half
():
value
=
torch
.
rand
(
N
,
S
,
M
,
D
).
cuda
()
*
0.01
offset
=
(
torch
.
rand
(
N
,
Lq
,
M
,
L
,
P
,
2
).
cuda
()
*
2
-
1
)
/
10
sampling_locations
=
torch
.
rand
(
N
,
Lq
,
M
,
L
,
P
,
2
).
cuda
()
attention_weights_origin
=
torch
.
rand
(
N
,
Lq
,
M
,
L
,
P
).
cuda
()
+
1e-5
attention_weights_origin
.
requires_grad
=
True
sampling_loc_attn
=
torch
.
cat
([
sampling_locations
.
detach
().
reshape
(
N
,
Lq
,
M
,
L
*
P
*
2
),
attention_weights_origin
.
detach
().
reshape
(
N
,
Lq
,
M
,
L
*
P
)],
dim
=-
1
)
attention_weights
=
torch
.
nn
.
functional
.
softmax
(
attention_weights_origin
.
flatten
(
-
2
,
-
1
),
dim
=-
1
).
unflatten
(
-
1
,
(
L
,
P
))
im2col_step
=
128
value
.
requires_grad
=
True
sampling_loc_attn
.
requires_grad
=
True
output_cuda
=
(
FlashDeformAttnFunction
.
apply
(
value
.
float
(),
shapes
,
level_start_index
,
sampling_loc_attn
.
float
(),
im2col_step
,
)
)
(
output_cuda
.
float
().
sum
()
/
10
).
backward
()
value1
=
value
.
detach
()
value1
.
requires_grad
=
True
sampling_locations
.
requires_grad
=
True
#attention_weights.requires_grad = True
output_pytorch
=
(
ms_deform_attn_core_pytorch
(
value1
,
shapes
,
sampling_locations
,
attention_weights
)
)
(
output_pytorch
.
sum
()
/
10
).
backward
()
max_abs_err
=
(
output_cuda
.
float
()
-
output_pytorch
).
abs
().
max
()
max_rel_err
=
((
output_cuda
.
float
()
-
output_pytorch
).
abs
()
/
output_pytorch
.
abs
()).
max
()
fwdok
=
torch
.
allclose
(
output_cuda
.
float
(),
output_pytorch
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
fwdok
)
print
(
max_abs_err
,
max_rel_err
)
#exit()
bwdok1
=
torch
.
allclose
(
value
.
grad
,
value1
.
grad
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
bwdok1
)
# rel_err = (sampling_locations.grad - sampling_loc_attn.grad[..., :L*P*2].reshape(*sampling_locations.shape)).abs()/(sampling_locations.grad.abs()+1e-3)
# print(rel_err.max())
locgrad1
=
sampling_locations
.
grad
locgrad2
=
sampling_loc_attn
.
grad
[...,
:
L
*
P
*
2
].
reshape
(
*
sampling_locations
.
shape
)
bwdok2
=
torch
.
allclose
(
locgrad1
,
locgrad2
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
bwdok2
)
rel_err
=
(
locgrad1
-
locgrad2
).
abs
()
/
(
locgrad1
.
abs
()
+
1e-3
)
print
(
rel_err
.
max
())
attngrad1
=
attention_weights_origin
.
grad
attngrad2
=
sampling_loc_attn
.
grad
[...,
L
*
P
*
2
:].
reshape
(
*
attention_weights_origin
.
shape
)
bwdok3
=
torch
.
allclose
(
locgrad1
,
locgrad2
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
bwdok3
)
rel_err
=
(
attngrad1
-
attngrad2
).
abs
()
/
(
attngrad1
.
abs
()
+
1e-3
)
print
(
rel_err
.
max
())
exit
()
#exit()
# pdb.set_trace()
max_abs_err
=
(
output_cuda
-
output_pytorch
).
abs
().
max
()
max_rel_err
=
((
output_cuda
-
output_pytorch
).
abs
()
/
output_pytorch
.
abs
()).
max
()
fwdok
=
torch
.
allclose
(
output_cuda
,
output_pytorch
,
rtol
=
1e-2
,
atol
=
1e-3
)
print
(
f
"*
{
fwdok
}
check_forward_equal_with_pytorch_float: max_abs_err
{
max_abs_err
:.
2
e
}
max_rel_err
{
max_rel_err
:.
2
e
}
"
)
fn_args
=
(
value
,
shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
,
)
flash_dcn_fn_args
=
(
value
.
half
(),
shapes
,
level_start_index
,
sampling_loc_attn
.
half
(),
im2col_step
,
)
test_args
=
edict
({
'warmup_num'
:
50
,
'test_num'
:
100
})
exp_time
=
speed_test
(
FlashMSDeformAttnFunction
.
apply
,
test_args
,
flash_dcn_fn_args
,
name
=
'exp'
)
exp_time_base
=
speed_test
(
MSDeformAttnFunction
.
apply
,
test_args
,
fn_args
,
name
=
'exp'
)
results
=
[{}]
results
[
0
][
'time'
]
=
exp_time
results
[
0
][
'time_base'
]
=
exp_time_base
columns
=
list
(
results
[
0
].
keys
())
outputs
=
pd
.
DataFrame
(
results
,
columns
=
columns
)
with
pd
.
option_context
(
'display.max_rows'
,
None
,
'display.max_columns'
,
None
,
'display.max_colwidth'
,
None
,
'display.width'
,
None
,
'display.precision'
,
4
,
):
print
(
outputs
)
if
__name__
==
"__main__"
:
check_forward_equal_with_pytorch_half
()
\ No newline at end of file
DCNv4_op/setup.py
0 → 100644
View file @
5b17e272
# ------------------------------------------------------------------------------------------------
# Deformable Convolution v4
# Copyright (c) 2024 OpenGVLab
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
import
os
import
glob
import
torch
# 导入打包相关库
from
setuptools
import
find_packages
,
setup
from
torch.utils.cpp_extension
import
CUDA_HOME
,
CppExtension
,
CUDAExtension
# 定义获取扩展的函数(保持原样,供非打包模式使用)
def
get_extensions
():
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
extensions_dir
=
os
.
path
.
join
(
this_dir
,
"src"
)
main_file
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
"*.cpp"
))
source_cpu
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
"cpu"
,
"*.cpp"
))
source_cuda
=
glob
.
glob
(
os
.
path
.
join
(
extensions_dir
,
"cuda"
,
"*.cu"
))
sources
=
main_file
+
source_cpu
extension
=
CppExtension
extra_compile_args
=
{
"cxx"
:
[]}
define_macros
=
[]
if
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
:
extension
=
CUDAExtension
sources
+=
source_cuda
define_macros
+=
[(
"WITH_CUDA"
,
None
)]
extra_compile_args
[
"nvcc"
]
=
[
"-DCUDA_HAS_FP16=1"
,
"-D__CUDA_NO_HALF_OPERATORS__"
,
"-D__CUDA_NO_HALF_CONVERSIONS__"
,
"-D__CUDA_NO_HALF2_OPERATORS__"
,
"-O3"
,
]
else
:
raise
NotImplementedError
(
'Cuda is not available'
)
sources
=
[
os
.
path
.
join
(
extensions_dir
,
s
)
for
s
in
sources
]
include_dirs
=
[
extensions_dir
]
ext_modules
=
[
extension
(
"DCNv4.ext"
,
# 注意:这里保持原模块名,方便后面替换
sources
,
include_dirs
=
include_dirs
,
define_macros
=
define_macros
,
extra_compile_args
=
extra_compile_args
,
)
]
return
ext_modules
# --- 核心修改逻辑 ---
# 检查是否是构建 Wheel 的模式
# 如果是构建 Wheel,我们不编译,而是将现有的 .so 作为包数据处理
# 注意:setuptools 打包扩展模块和打包数据文件的逻辑是冲突的,所以我们需要在构建 Wheel 时禁用 ext_modules
if
__name__
==
"__main__"
:
# 检查环境变量,决定是否跳过编译
# 你也可以直接写一个布尔值,或者检查某个文件是否存在
build_so
=
int
(
os
.
getenv
(
"DCNv4_BUILD_SO"
,
"0"
))
# 准备参数
kwargs
=
{
"name"
:
"DCNv4"
,
"version"
:
"1.0.0.post2"
,
"author"
:
"Yuwen Xiong, Feng Wang"
,
"url"
:
""
,
"description"
:
"PyTorch Wrapper for CUDA Functions of DCNv4"
,
"packages"
:
[
'DCNv4'
,
'DCNv4/functions'
,
'DCNv4/modules'
],
"package_data"
:
{
"DCNv4"
:
[
"ext.so"
],
# 假设 ext.so 生成在 DCNv4 目录下
# "DCNv4": ["ext.cpython-310-x86_64-linux-gnu.so"], # 假设 ext.so 生成在 DCNv4 目录下
},
"cmdclass"
:
{
"build_ext"
:
torch
.
utils
.
cpp_extension
.
BuildExtension
},
# 确保生成正确的 .dist-info
"zip_safe"
:
False
,
# 添加以下参数来避免生成 .egg-info 在当前目录
"options"
:
{
'egg_info'
:
{
'egg_base'
:
'/tmp'
# 将 egg-info 生成到临时目录
}
},
}
if
build_so
:
# 正常开发模式,进行编译
kwargs
[
"ext_modules"
]
=
get_extensions
()
else
:
print
(
"=== BUILD WHEEL MODE: Skipping compilation, using existing ext.so ==="
)
# 在构建 Wheel 时,不要传入 ext_modules
# 我们依赖 MANIFEST.in 或 package_data 将 .so 文件包含进去
# 但是 setuptools 的 bdist_wheel 默认会忽略 .so,所以我们需要确保 .so 在包目录里
# 这里我们不传入 ext_modules,而是依靠外部脚本或 MANIFEST.in
# 更简单的方法:直接在 setup 里不写 ext_modules,确保 .so 已经在 DCNv4/ 目录下
kwargs
[
"ext_modules"
]
=
[]
# 强制不编译
setup
(
**
kwargs
)
\ No newline at end of file
DCNv4_op/src/cuda/common.h
0 → 100644
View file @
5b17e272
#ifndef FMSDACOMMON
#define FMSDACOMMON
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#ifdef _WIN32
#define uint unsigned int
#endif
constexpr
int
kWarpSize
=
32
;
#define opmath_t at::opmath_type<scalar_t>
inline
int
GET_BLOCKS
(
const
int
N
,
const
int
num_threads
)
{
return
(
N
+
num_threads
-
1
)
/
num_threads
;
}
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
inline
bool
check_backward_warpp
(
int
d_stride
,
int
D
){
int
n_group_threads
=
D
/
d_stride
;
return
(
n_group_threads
<=
kWarpSize
)
&&
(
kWarpSize
%
n_group_threads
==
0
);
}
template
<
typename
scalar_t
,
typename
transfer_t
,
int
c_per_thread
>
__device__
void
ms_deform_attn_im2col_bilinear
(
opmath_t
out_reg_array
[],
const
scalar_t
*&
p_value
,
const
int
&
height
,
const
int
&
width
,
const
opmath_t
&
h_px
,
const
opmath_t
&
w_px
,
const
opmath_t
&
attn
,
const
int
&
w_stride
,
const
int
&
base_ptr
)
{
const
int
h_low
=
floor
(
h_px
);
const
int
w_low
=
floor
(
w_px
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
opmath_t
lh
=
h_px
-
h_low
;
const
opmath_t
lw
=
w_px
-
w_low
;
const
opmath_t
hh
=
1
-
lh
;
const
opmath_t
hw
=
1
-
lw
;
const
opmath_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
int
idx1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
int
idx2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
int
idx3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
int
idx4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
scalar_t
v1_array
[
c_per_thread
]
=
{
0.
};
scalar_t
v2_array
[
c_per_thread
]
=
{
0.
};
scalar_t
v3_array
[
c_per_thread
]
=
{
0.
};
scalar_t
v4_array
[
c_per_thread
]
=
{
0.
};
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
auto
p1
=
p_value
+
idx1
;
*
(
transfer_t
*
)(
v1_array
)
=
*
(
transfer_t
*
)(
p1
);
}
if
(
h_low
>=
0
&&
w_high
<
width
)
{
auto
p2
=
p_value
+
idx2
;
*
(
transfer_t
*
)(
v2_array
)
=
*
(
transfer_t
*
)(
p2
);
}
if
(
h_high
<
height
&&
w_low
>=
0
)
{
auto
p3
=
p_value
+
idx3
;
*
(
transfer_t
*
)(
v3_array
)
=
*
(
transfer_t
*
)(
p3
);
}
if
(
h_high
<
height
&&
w_high
<
width
)
{
auto
p4
=
p_value
+
idx4
;
*
(
transfer_t
*
)(
v4_array
)
=
*
(
transfer_t
*
)(
p4
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
c_per_thread
;
i
++
)
{
out_reg_array
[
i
]
+=
(
opmath_t
)
attn
*
(
w1
*
(
opmath_t
)
v1_array
[
i
]
+
w2
*
(
opmath_t
)
v2_array
[
i
]
+
w3
*
(
opmath_t
)
v3_array
[
i
]
+
w4
*
(
opmath_t
)
v4_array
[
i
]);
}
}
template
<
typename
scalar_t
,
typename
transfer_t
,
int
c_per_thread
>
__device__
void
ms_deform_attn_col2im_bilinear
(
const
scalar_t
*&
p_value
,
const
int
&
height
,
const
int
&
width
,
const
opmath_t
&
h_px
,
const
opmath_t
&
w_px
,
const
opmath_t
&
attn
,
const
int
&
w_stride
,
const
int
&
base_ptr
,
const
opmath_t
offset_scale_h
,
const
opmath_t
offset_scale_w
,
const
scalar_t
*&
top_grad
,
opmath_t
*&
grad_im
,
opmath_t
*
grad_offset
)
{
const
int
h_low
=
floor
(
h_px
);
const
int
w_low
=
floor
(
w_px
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
opmath_t
lh
=
h_px
-
h_low
;
const
opmath_t
lw
=
w_px
-
w_low
;
const
opmath_t
hh
=
1
-
lh
;
const
opmath_t
hw
=
1
-
lw
;
const
opmath_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
scalar_t
_top_grad_array
[
c_per_thread
]
=
{
0.
};
*
(
transfer_t
*
)(
_top_grad_array
)
=
*
(
transfer_t
*
)(
top_grad
);
opmath_t
top_grad_array
[
c_per_thread
]
=
{
0.
};
for
(
int
i
=
0
;
i
<
c_per_thread
;
++
i
)
{
top_grad_array
[
i
]
=
(
opmath_t
)(
_top_grad_array
[
i
]);
}
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
int
idx1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
int
idx2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
int
idx3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
int
idx4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
scalar_t
v1_array
[
c_per_thread
]
=
{
0.
};
scalar_t
v2_array
[
c_per_thread
]
=
{
0.
};
scalar_t
v3_array
[
c_per_thread
]
=
{
0.
};
scalar_t
v4_array
[
c_per_thread
]
=
{
0.
};
opmath_t
grad_h_weight
[
c_per_thread
]
=
{
0.
};
opmath_t
grad_w_weight
[
c_per_thread
]
=
{
0.
};
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
auto
p1
=
p_value
+
idx1
;
*
(
transfer_t
*
)(
v1_array
)
=
*
(
transfer_t
*
)(
p1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
c_per_thread
;
++
i
)
{
grad_h_weight
[
i
]
-=
hw
*
v1_array
[
i
];
grad_w_weight
[
i
]
-=
hh
*
v1_array
[
i
];
atomicAdd
(
grad_im
+
idx1
+
i
,
top_grad_array
[
i
]
*
attn
*
w1
);
}
}
if
(
h_low
>=
0
&&
w_high
<
width
)
{
auto
p2
=
p_value
+
idx2
;
*
(
transfer_t
*
)(
v2_array
)
=
*
(
transfer_t
*
)(
p2
);
#pragma unroll
for
(
int
i
=
0
;
i
<
c_per_thread
;
++
i
)
{
grad_h_weight
[
i
]
-=
lw
*
v2_array
[
i
];
grad_w_weight
[
i
]
+=
hh
*
v2_array
[
i
];
atomicAdd
(
grad_im
+
idx2
+
i
,
top_grad_array
[
i
]
*
attn
*
w2
);
}
}
if
(
h_high
<
height
&&
w_low
>=
0
)
{
auto
p3
=
p_value
+
idx3
;
*
(
transfer_t
*
)(
v3_array
)
=
*
(
transfer_t
*
)(
p3
);
#pragma unroll
for
(
int
i
=
0
;
i
<
c_per_thread
;
++
i
)
{
grad_h_weight
[
i
]
+=
hw
*
v3_array
[
i
];
grad_w_weight
[
i
]
-=
lh
*
v3_array
[
i
];
atomicAdd
(
grad_im
+
idx3
+
i
,
top_grad_array
[
i
]
*
attn
*
w3
);
}
}
if
(
h_high
<
height
&&
w_high
<
width
)
{
auto
p4
=
p_value
+
idx4
;
*
(
transfer_t
*
)(
v4_array
)
=
*
(
transfer_t
*
)(
p4
);
#pragma unroll
for
(
int
i
=
0
;
i
<
c_per_thread
;
++
i
)
{
grad_h_weight
[
i
]
+=
lw
*
v4_array
[
i
];
grad_w_weight
[
i
]
+=
lh
*
v4_array
[
i
];
atomicAdd
(
grad_im
+
idx4
+
i
,
top_grad_array
[
i
]
*
attn
*
w4
);
}
}
opmath_t
_grad_offset_x
=
0
;
opmath_t
_grad_offset_y
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
c_per_thread
;
++
i
)
{
_grad_offset_x
+=
grad_w_weight
[
i
]
*
top_grad_array
[
i
];
// channel aware term
_grad_offset_y
+=
grad_h_weight
[
i
]
*
top_grad_array
[
i
];
// channel aware term
}
_grad_offset_x
*=
(
offset_scale_w
*
attn
);
// channel shared term
_grad_offset_y
*=
(
offset_scale_h
*
attn
);
// channel shared term
*
grad_offset
=
_grad_offset_x
;
*
(
grad_offset
+
1
)
=
_grad_offset_y
;
opmath_t
current_val
;
opmath_t
_grad_offset_z
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
c_per_thread
;
i
++
)
{
current_val
=
(
opmath_t
)(
w1
*
v1_array
[
i
]
+
w2
*
v2_array
[
i
]
+
w3
*
v3_array
[
i
]
+
w4
*
v4_array
[
i
]);
_grad_offset_z
+=
current_val
*
top_grad_array
[
i
];
}
*
(
grad_offset
+
2
)
=
_grad_offset_z
;
}
#endif
DCNv4_op/src/cuda/dcnv4_col2im_cuda.cuh
0 → 100644
View file @
5b17e272
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "common.h"
template
<
typename
scalar_t
,
int
d_stride
,
typename
transfer_t
,
int
L
,
int
K
,
bool
softmax
>
__global__
void
backward_kernel_dcn
(
const
scalar_t
*
p_value
,
const
scalar_t
*
p_offset
,
const
scalar_t
*
grad_output
,
const
int
G
,
const
int
D
,
const
int
Q
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
const
int
block_multiplier
,
opmath_t
*
grad_im
,
opmath_t
*
grad_offset
,
const
int
padded_offset_dim
)
{
extern
__shared__
char
_s
[];
const
int
&
qi
=
(
blockIdx
.
x
*
block_multiplier
%
Q
)
+
threadIdx
.
z
;
const
int
&
bi
=
blockIdx
.
x
*
block_multiplier
/
Q
;
const
int
&
di_s
=
threadIdx
.
x
*
d_stride
;
const
int
&
gi
=
threadIdx
.
y
;
constexpr
int
li
=
0
;
opmath_t
*
const
cache_g_mask_before_softmax
=
(
opmath_t
*
)(
_s
);
// mG x K
opmath_t
*
const
cache_grad_offset
=
(
opmath_t
*
)(
cache_g_mask_before_softmax
+
block_multiplier
*
G
*
K
);
// mG x blockDim.x x 3
opmath_t
*
const
p_mask_shm
=
(
opmath_t
*
)(
cache_grad_offset
+
block_multiplier
*
G
*
blockDim
.
x
*
3
)
+
(
threadIdx
.
z
*
G
+
gi
)
*
K
;
const
scalar_t
*
p_offset_ptr
=
p_offset
+
(
bi
*
Q
+
qi
)
*
padded_offset_dim
+
gi
*
K
*
3
;
const
int
mask_length
=
K
;
const
int
num_thread
=
(
D
/
d_stride
);
const
int
num_iter
=
mask_length
/
num_thread
;
const
int
remainder
=
mask_length
-
num_iter
*
num_thread
;
const
scalar_t
*
top_grad
=
grad_output
+
((
bi
*
Q
+
qi
)
*
G
+
gi
)
*
D
+
di_s
;
__syncthreads
();
for
(
int
i
=
0
;
i
<
num_iter
;
i
++
)
{
*
(
p_mask_shm
+
num_thread
*
i
+
threadIdx
.
x
)
=
*
(
scalar_t
*
)(
p_offset_ptr
+
K
*
2
+
num_thread
*
i
+
threadIdx
.
x
);
}
if
(
remainder
>
0
&&
threadIdx
.
x
<
remainder
)
{
*
(
p_mask_shm
+
num_thread
*
num_iter
+
threadIdx
.
x
)
=
*
(
scalar_t
*
)(
p_offset_ptr
+
K
*
2
+
num_thread
*
num_iter
+
threadIdx
.
x
);
}
if
(
softmax
)
{
__syncthreads
();
// transfer offset from global memory to shared memory >
// Calculate softmax over L and K
if
(
threadIdx
.
x
==
0
)
{
// gi != 0, di = 0, li = 0
opmath_t
softmax_max
=
-
1e100
;
opmath_t
softmax_sum
=
0.0
;
// get max
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
softmax_max
=
max
(
softmax_max
,
p_mask_shm
[
j
]);
}
// get sumexp
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
opmath_t
exp_results
=
exp
(
p_mask_shm
[
j
]
-
softmax_max
);
p_mask_shm
[
j
]
=
exp_results
;
softmax_sum
+=
exp_results
;
}
// normalize
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
p_mask_shm
[
j
]
/=
softmax_sum
;
}
}
__syncthreads
();
}
int
offset_idx
=
0
;
int
mask_idx
=
0
;
const
int
w_stride
=
G
*
D
;
const
int
base_ptr
=
gi
*
D
+
di_s
;
const
scalar_t
*
p_value_ptr
=
p_value
+
(
bi
*
(
height_in
*
width_in
))
*
(
G
*
D
);
opmath_t
*
grad_im_ptr
=
grad_im
+
(
bi
*
(
height_in
*
width_in
))
*
(
G
*
D
);
const
int
p0_w
=
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
-
pad_w
+
(
qi
%
width_out
)
*
stride_w
;
const
int
p0_h
=
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
-
pad_h
+
(
qi
/
width_out
)
*
stride_h
;
const
opmath_t
p0_w_
=
p0_w
-
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
*
offset_scale
;
const
opmath_t
p0_h_
=
p0_h
-
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
*
offset_scale
;
const
int
center_h
=
kernel_h
/
2
;
const
int
center_w
=
kernel_w
/
2
;
grad_offset
+=
(
bi
*
Q
+
qi
)
*
padded_offset_dim
+
gi
*
K
*
3
;
opmath_t
*
grad_offset_softmax
=
grad_offset
+
K
*
2
;
int
cache_grad_off_idx
=
((
threadIdx
.
z
*
G
+
threadIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
)
*
3
;
for
(
int
i
=
0
;
i
<
kernel_w
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_h
;
++
j
)
{
if
(
i
!=
center_w
||
j
!=
center_h
||
!
remove_center
)
{
const
opmath_t
w_im
=
p0_w_
+
(
i
*
dilation_w
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
])
*
offset_scale
;
const
opmath_t
h_im
=
p0_h_
+
(
j
*
dilation_h
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
+
1
])
*
offset_scale
;
const
opmath_t
attn
=
p_mask_shm
[
mask_idx
];
cache_grad_offset
[
cache_grad_off_idx
]
=
0
;
cache_grad_offset
[
cache_grad_off_idx
+
1
]
=
0
;
cache_grad_offset
[
cache_grad_off_idx
+
2
]
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height_in
&&
w_im
<
width_in
)
{
ms_deform_attn_col2im_bilinear
<
scalar_t
,
transfer_t
,
d_stride
>
(
p_value_ptr
,
height_in
,
width_in
,
h_im
,
w_im
,
attn
,
w_stride
,
base_ptr
,
offset_scale
,
offset_scale
,
top_grad
,
grad_im_ptr
,
cache_grad_offset
+
cache_grad_off_idx
);
}
// aggregated across different channel for offset
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
//
int
_didx
=
(
threadIdx
.
z
*
G
+
threadIdx
.
y
)
*
blockDim
.
x
*
3
;
opmath_t
_grad_w
=
cache_grad_offset
[
_didx
],
_grad_h
=
cache_grad_offset
[
_didx
+
1
],
_grad_a
=
cache_grad_offset
[
_didx
+
2
];
for
(
int
c_id
=
1
;
c_id
<
blockDim
.
x
;
++
c_id
)
{
_grad_w
+=
cache_grad_offset
[
_didx
+
3
*
c_id
];
_grad_h
+=
cache_grad_offset
[
_didx
+
3
*
c_id
+
1
];
_grad_a
+=
cache_grad_offset
[
_didx
+
3
*
c_id
+
2
];
}
*
(
grad_offset
)
=
_grad_w
;
// B x H x W x G x L x K x 3
*
(
grad_offset
+
1
)
=
_grad_h
;
// B x H x W x G x L x K x 3
if
(
softmax
)
{
cache_g_mask_before_softmax
[(
threadIdx
.
z
*
G
+
threadIdx
.
y
)
*
K
+
mask_idx
]
=
_grad_a
*
attn
;
}
else
{
grad_offset_softmax
[
mask_idx
]
=
_grad_a
;
}
}
__syncthreads
();
offset_idx
+=
2
;
mask_idx
+=
1
;
grad_offset
+=
2
;
}
}
}
// backward for softmax
if
(
softmax
){
if
(
threadIdx
.
x
==
0
)
{
const
opmath_t
*
group_g_mask
=
cache_g_mask_before_softmax
+
(
threadIdx
.
z
*
G
+
threadIdx
.
y
)
*
K
;
#pragma unroll
for
(
int
i
=
0
;
i
<
K
;
++
i
)
{
opmath_t
sum
=
0.
;
for
(
int
j
=
0
;
j
<
K
;
++
j
)
{
sum
+=
group_g_mask
[
j
];
// dL/di * di/dj
}
*
(
grad_offset_softmax
)
=
group_g_mask
[
i
]
-
p_mask_shm
[
i
]
*
sum
;
grad_offset_softmax
+=
1
;
}
}
__syncthreads
();
}
}
template
<
typename
scalar_t
,
int
d_stride
,
typename
transfer_t
,
int
L
,
int
K
,
bool
softmax
>
__global__
void
backward_kernel_dcn_warp_primitive
(
const
scalar_t
*
p_value
,
const
scalar_t
*
p_offset
,
const
scalar_t
*
grad_output
,
const
int
G
,
const
int
D
,
const
int
Q
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
const
int
block_multiplier
,
opmath_t
*
grad_im
,
opmath_t
*
grad_offset
,
const
int
padded_offset_dim
)
{
extern
__shared__
char
_s
[];
const
int
&
qi
=
(
blockIdx
.
x
*
block_multiplier
%
Q
)
+
threadIdx
.
z
;
const
int
&
bi
=
blockIdx
.
x
*
block_multiplier
/
Q
;
const
int
&
di_s
=
threadIdx
.
x
*
d_stride
;
const
int
&
gi
=
threadIdx
.
y
;
constexpr
int
li
=
0
;
const
int
tid
=
(
threadIdx
.
z
*
blockDim
.
y
+
threadIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
lane_id
=
tid
%
kWarpSize
;
// find the position of current group in the current warp
const
int
group_per_warp
=
kWarpSize
/
blockDim
.
x
;
const
int
group_in_warp_id
=
(
threadIdx
.
z
*
G
+
threadIdx
.
y
)
%
group_per_warp
;
const
unsigned
lane_mask
=
((
1
<<
blockDim
.
x
)
-
1
)
<<
(
group_in_warp_id
*
blockDim
.
x
);
opmath_t
*
const
p_mask_shm
=
(
opmath_t
*
)(
_s
)
+
(
threadIdx
.
z
*
G
+
gi
)
*
K
;
opmath_t
*
cache_g_mask_before_softmax
=
(
opmath_t
*
)((
opmath_t
*
)(
_s
)
+
block_multiplier
*
G
*
K
)
+
(
threadIdx
.
z
*
G
+
gi
)
*
K
;
// only used by threadIdx.x = 0
const
scalar_t
*
p_offset_ptr
=
p_offset
+
(
bi
*
Q
+
qi
)
*
padded_offset_dim
+
gi
*
K
*
3
;
const
int
mask_length
=
K
;
const
int
num_thread
=
(
D
/
d_stride
);
const
int
num_iter
=
mask_length
/
num_thread
;
const
int
remainder
=
mask_length
-
num_iter
*
num_thread
;
const
scalar_t
*
top_grad
=
grad_output
+
((
bi
*
Q
+
qi
)
*
G
+
gi
)
*
D
+
di_s
;
__syncthreads
();
for
(
int
i
=
0
;
i
<
num_iter
;
i
++
)
{
*
(
p_mask_shm
+
num_thread
*
i
+
threadIdx
.
x
)
=
*
(
scalar_t
*
)(
p_offset_ptr
+
K
*
2
+
num_thread
*
i
+
threadIdx
.
x
);
}
if
(
remainder
>
0
&&
threadIdx
.
x
<
remainder
)
{
*
(
p_mask_shm
+
num_thread
*
num_iter
+
threadIdx
.
x
)
=
*
(
scalar_t
*
)(
p_offset_ptr
+
K
*
2
+
num_thread
*
num_iter
+
threadIdx
.
x
);
}
if
(
softmax
)
{
__syncthreads
();
// transfer offset from global memory to shared memory >
// Calculate softmax over L and K
if
(
threadIdx
.
x
==
0
)
{
// gi != 0, di = 0, li = 0
opmath_t
softmax_max
=
-
1e100
;
opmath_t
softmax_sum
=
0.0
;
// get max
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
softmax_max
=
max
(
softmax_max
,
p_mask_shm
[
j
]);
}
// get sumexp
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
opmath_t
exp_results
=
exp
(
p_mask_shm
[
j
]
-
softmax_max
);
p_mask_shm
[
j
]
=
exp_results
;
softmax_sum
+=
exp_results
;
}
// normalize
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
p_mask_shm
[
j
]
/=
softmax_sum
;
}
}
__syncthreads
();
}
int
offset_idx
=
0
;
int
mask_idx
=
0
;
const
int
w_stride
=
G
*
D
;
const
int
base_ptr
=
gi
*
D
+
di_s
;
const
scalar_t
*
p_value_ptr
=
p_value
+
(
bi
*
(
height_in
*
width_in
))
*
(
G
*
D
);
opmath_t
*
grad_im_ptr
=
grad_im
+
(
bi
*
(
height_in
*
width_in
))
*
(
G
*
D
);
const
int
p0_w
=
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
-
pad_w
+
(
qi
%
width_out
)
*
stride_w
;
const
int
p0_h
=
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
-
pad_h
+
(
qi
/
width_out
)
*
stride_h
;
const
opmath_t
p0_w_
=
p0_w
-
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
*
offset_scale
;
const
opmath_t
p0_h_
=
p0_h
-
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
*
offset_scale
;
const
int
center_h
=
kernel_h
/
2
;
const
int
center_w
=
kernel_w
/
2
;
grad_offset
+=
(
bi
*
Q
+
qi
)
*
padded_offset_dim
+
gi
*
K
*
3
;
opmath_t
*
grad_offset_softmax
=
grad_offset
+
K
*
2
;
int
cache_grad_off_idx
=
((
threadIdx
.
z
*
G
+
threadIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
)
*
3
;
opmath_t
reg_grad_offset
[
3
]
=
{
0.
};
for
(
int
i
=
0
;
i
<
kernel_w
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_h
;
++
j
)
{
if
(
i
!=
center_w
||
j
!=
center_h
||
!
remove_center
)
{
const
opmath_t
w_im
=
p0_w_
+
(
i
*
dilation_w
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
])
*
offset_scale
;
const
opmath_t
h_im
=
p0_h_
+
(
j
*
dilation_h
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
+
1
])
*
offset_scale
;
const
opmath_t
attn
=
p_mask_shm
[
mask_idx
];
reg_grad_offset
[
0
]
=
0
;
reg_grad_offset
[
1
]
=
0
;
reg_grad_offset
[
2
]
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height_in
&&
w_im
<
width_in
)
{
ms_deform_attn_col2im_bilinear
<
scalar_t
,
transfer_t
,
d_stride
>
(
p_value_ptr
,
height_in
,
width_in
,
h_im
,
w_im
,
attn
,
w_stride
,
base_ptr
,
offset_scale
,
offset_scale
,
top_grad
,
grad_im_ptr
,
reg_grad_offset
);
}
// aggregated across different channel for offset
for
(
uint32_t
offset
=
blockDim
.
x
>>
1
;
offset
>
0
;
offset
>>=
1
){
reg_grad_offset
[
0
]
+=
__shfl_down_sync
(
lane_mask
,
reg_grad_offset
[
0
],
offset
);
reg_grad_offset
[
1
]
+=
__shfl_down_sync
(
lane_mask
,
reg_grad_offset
[
1
],
offset
);
reg_grad_offset
[
2
]
+=
__shfl_down_sync
(
lane_mask
,
reg_grad_offset
[
2
],
offset
);
}
if
(
threadIdx
.
x
==
0
)
{
//
*
(
grad_offset
)
=
reg_grad_offset
[
0
];
// B x H x W x G x L x K x 3
*
(
grad_offset
+
1
)
=
reg_grad_offset
[
1
];
// B x H x W x G x L x K x 3
if
(
softmax
)
{
cache_g_mask_before_softmax
[
mask_idx
]
=
reg_grad_offset
[
2
]
*
attn
;
}
else
{
grad_offset_softmax
[
mask_idx
]
=
reg_grad_offset
[
2
];
}
}
offset_idx
+=
2
;
mask_idx
+=
1
;
grad_offset
+=
2
;
}
}
}
// backward for softmax
if
(
softmax
){
if
(
threadIdx
.
x
==
0
)
{
opmath_t
sum
=
0.
;
#pragma unroll
for
(
int
i
=
0
;
i
<
K
;
++
i
){
sum
+=
cache_g_mask_before_softmax
[
i
];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
K
;
++
i
)
{
*
(
grad_offset_softmax
)
=
cache_g_mask_before_softmax
[
i
]
-
p_mask_shm
[
i
]
*
sum
;
grad_offset_softmax
+=
1
;
}
}
}
}
template
<
typename
scalar_t
,
typename
stride_type
,
int
d_stride
>
void
_dcnv4_col2im_cuda
(
cudaStream_t
stream
,
const
scalar_t
*
value
,
// B, H * W, (G * D)
const
scalar_t
*
p_offset
,
// B, H * W, (G*K*3)
const
scalar_t
*
grad_output
,
// B, H_out*W_out, G * D
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
G
,
const
int
D
,
const
int
B
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
opmath_t
*
grad_im
,
opmath_t
*
grad_offset
,
const
int
block_thread
,
const
bool
softmax
,
const
int
padded_offset_dim
)
{
constexpr
int
L
=
1
;
auto
kernel
=
backward_kernel_dcn_warp_primitive
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
false
>
;
int
N
=
height_in
*
width_in
;
int
Q
=
height_out
*
width_out
;
int
K
=
kernel_h
*
kernel_w
;
if
(
remove_center
)
{
K
-=
1
;
}
if
(
softmax
)
{
switch
(
K
)
{
case
9
:
if
(
check_backward_warpp
(
d_stride
,
D
)){
kernel
=
backward_kernel_dcn_warp_primitive
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
true
>
;
}
else
{
kernel
=
backward_kernel_dcn
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
true
>
;
}
break
;
case
8
:
if
(
check_backward_warpp
(
d_stride
,
D
)){
kernel
=
backward_kernel_dcn_warp_primitive
<
scalar_t
,
d_stride
,
stride_type
,
1
,
8
,
true
>
;
}
else
{
kernel
=
backward_kernel_dcn
<
scalar_t
,
d_stride
,
stride_type
,
1
,
8
,
true
>
;
}
break
;
default:
printf
(
"K=%ld
\n
"
,
K
);
throw
std
::
invalid_argument
(
"invalid kernel shape"
);
}
}
else
{
switch
(
K
)
{
case
9
:
if
(
check_backward_warpp
(
d_stride
,
D
)){
kernel
=
backward_kernel_dcn_warp_primitive
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
false
>
;
}
else
{
kernel
=
backward_kernel_dcn
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
false
>
;
}
break
;
case
8
:
if
(
check_backward_warpp
(
d_stride
,
D
)){
kernel
=
backward_kernel_dcn_warp_primitive
<
scalar_t
,
d_stride
,
stride_type
,
1
,
8
,
false
>
;
}
else
{
kernel
=
backward_kernel_dcn
<
scalar_t
,
d_stride
,
stride_type
,
1
,
8
,
false
>
;
}
break
;
default:
printf
(
"K=%ld
\n
"
,
K
);
throw
std
::
invalid_argument
(
"invalid kernel shape"
);
}
}
const
int
block_multiplier
=
block_thread
/
(
D
/
d_stride
)
/
G
;
assert
((
B
*
Q
)
%
block_multiplier
==
0
);
dim3
num_blocks
(
B
*
Q
/
block_multiplier
);
dim3
num_threads
(
D
/
d_stride
,
G
,
block_multiplier
);
const
int
blockdimX
=
D
/
d_stride
;
int
shm_size
=
sizeof
(
opmath_t
)
*
(
G
*
block_multiplier
*
K
)
*
2
;
if
(
!
check_backward_warpp
(
d_stride
,
D
)){
shm_size
=
sizeof
(
opmath_t
)
*
((
G
*
block_multiplier
*
K
)
*
2
+
G
*
block_multiplier
*
blockdimX
*
3
);
}
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shm_size
);
kernel
<<<
num_blocks
,
num_threads
,
shm_size
,
stream
>>>
(
value
,
p_offset
,
grad_output
,
G
,
D
,
Q
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_multiplier
,
grad_im
,
grad_offset
,
padded_offset_dim
);
cudaError_t
err
=
cudaGetLastError
();
if
(
err
!=
cudaSuccess
)
{
printf
(
"error in dcnv4_im2col_cuda: %s
\n
"
,
cudaGetErrorString
(
err
));
printf
(
"launch arguments: gridDim=(%d, %d, %d), blockDim=(%d, %d, %d), "
"shm_size=%d
\n\n
"
,
num_blocks
.
x
,
num_blocks
.
y
,
num_blocks
.
z
,
num_threads
.
x
,
num_threads
.
y
,
num_threads
.
z
,
shm_size
);
AT_ASSERTM
(
false
,
"kernel launch error"
);
}
}
template
<
typename
scalar_t
>
void
dcnv4_col2im_cuda
(
cudaStream_t
stream
,
const
scalar_t
*
value
,
// B, H * W, (G * D)
const
scalar_t
*
p_offset
,
// B, H * W, (G*K*3)
const
scalar_t
*
grad_output
,
// B, H_out*W_out, G * D
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
G
,
const
int
D
,
const
int
B
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
opmath_t
*
grad_im
,
opmath_t
*
grad_offset
,
const
int
d_stride
,
const
int
block_thread
,
const
bool
softmax
,
const
int
padded_offset_dim
)
{
assert
(
D
%
d_stride
==
0
);
const
int
size_scalar
=
sizeof
(
scalar_t
);
if
(
size_scalar
==
2
)
{
switch
(
d_stride
)
{
case
1
:
_dcnv4_col2im_cuda
<
scalar_t
,
scalar_t
,
1
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
2
:
_dcnv4_col2im_cuda
<
scalar_t
,
uint
,
2
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
4
:
_dcnv4_col2im_cuda
<
scalar_t
,
uint2
,
4
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
8
:
_dcnv4_col2im_cuda
<
scalar_t
,
uint4
,
8
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
16
:
_dcnv4_col2im_cuda
<
scalar_t
,
ulonglong4
,
16
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
}
}
else
{
assert
(
size_scalar
==
4
);
switch
(
d_stride
)
{
case
1
:
_dcnv4_col2im_cuda
<
scalar_t
,
uint
,
1
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
2
:
_dcnv4_col2im_cuda
<
scalar_t
,
uint2
,
2
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
4
:
_dcnv4_col2im_cuda
<
scalar_t
,
uint4
,
4
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
8
:
_dcnv4_col2im_cuda
<
scalar_t
,
ulonglong4
,
8
>
(
stream
,
value
,
p_offset
,
grad_output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_im
,
grad_offset
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
}
}
}
\ No newline at end of file
DCNv4_op/src/cuda/dcnv4_cuda.cu
0 → 100644
View file @
5b17e272
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "cuda/dcnv4_im2col_cuda.cuh"
#include "cuda/dcnv4_col2im_cuda.cuh"
#include <vector>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/torch.h>
at
::
Tensor
dcnv4_cuda_forward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
p_offset
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
group
,
const
int
group_channels
,
const
float
offset_scale
,
const
int
im2col_step
,
const
int
remove_center
,
const
int
d_stride
,
const
int
block_thread
,
const
bool
softmax
)
{
AT_ASSERTM
(
value
.
is_contiguous
(),
"input tensor has to be contiguous"
);
AT_ASSERTM
(
value
.
type
().
is_cuda
(),
"input must be a CUDA tensor"
);
AT_ASSERTM
(
p_offset
.
is_contiguous
(),
"input tensor has to be contiguous"
);
AT_ASSERTM
(
p_offset
.
type
().
is_cuda
(),
"input must be a CUDA tensor"
);
const
int
batch
=
value
.
size
(
0
);
const
int
height_in
=
value
.
size
(
1
);
const
int
width_in
=
value
.
size
(
2
);
const
int
channels
=
value
.
size
(
3
);
const
int
padded_offset_dim
=
p_offset
.
size
(
3
);
// tensor core requirement
assert
(
padded_offset_dim
%
8
==
0
);
const
int
height_out
=
(
height_in
+
2
*
pad_h
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
const
int
width_out
=
(
width_in
+
2
*
pad_w
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
const
int
im2col_step_
=
std
::
min
(
batch
,
im2col_step
);
AT_ASSERTM
(
batch
%
im2col_step_
==
0
,
"batch("
,
batch
,
") must divide im2col_step("
,
im2col_step_
,
")"
);
AT_ASSERTM
(
channels
==
(
group
*
group_channels
),
"Input channels and group times group channels wont match: (%d vs %d)."
,
channels
,
group
*
group_channels
);
auto
output
=
at
::
zeros
(
{
batch
,
height_out
,
width_out
,
group
*
group_channels
},
value
.
options
());
const
int
batch_n
=
im2col_step_
;
auto
output_n
=
output
.
view
({
batch
/
batch_n
,
batch_n
,
height_out
,
width_out
,
group
*
group_channels
});
auto
per_value_size
=
height_in
*
width_in
*
channels
;
auto
per_offset_size
=
height_out
*
width_out
*
padded_offset_dim
;
for
(
int
n
=
0
;
n
<
batch
/
im2col_step_
;
++
n
)
{
auto
columns
=
output_n
.
select
(
0
,
n
);
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
value
.
scalar_type
(),
"dcnv4_forward_cuda"
,
([
&
]
{
dcnv4_im2col_cuda
(
at
::
cuda
::
getCurrentCUDAStream
(),
value
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_value_size
,
p_offset
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_offset_size
,
columns
.
data_ptr
<
scalar_t
>
(),
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
group
,
group_channels
,
batch_n
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
d_stride
,
block_thread
,
softmax
,
padded_offset_dim
);
}));
}
return
output
;
}
std
::
vector
<
at
::
Tensor
>
dcnv4_cuda_backward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
p_offset
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
group
,
const
int
group_channels
,
const
float
offset_scale
,
const
int
im2col_step
,
const
at
::
Tensor
&
grad_output
,
const
int
remove_center
,
const
int
d_stride
,
const
int
block_thread
,
const
bool
softmax
)
{
AT_ASSERTM
(
value
.
is_contiguous
(),
"input tensor has to be contiguous"
);
AT_ASSERTM
(
p_offset
.
is_contiguous
(),
"offset tensor has to be contiguous"
);
AT_ASSERTM
(
grad_output
.
is_contiguous
(),
"grad_output tensor has to be contiguous"
);
AT_ASSERTM
(
value
.
type
().
is_cuda
(),
"input must be a CUDA tensor"
);
AT_ASSERTM
(
p_offset
.
type
().
is_cuda
(),
"offset must be a CUDA tensor"
);
AT_ASSERTM
(
grad_output
.
type
().
is_cuda
(),
"grad_output must be a CUDA tensor"
);
const
int
batch
=
value
.
size
(
0
);
const
int
height_in
=
value
.
size
(
1
);
const
int
width_in
=
value
.
size
(
2
);
const
int
channels
=
value
.
size
(
3
);
const
int
padded_offset_dim
=
p_offset
.
size
(
3
);
assert
(
padded_offset_dim
%
8
==
0
);
const
int
height_out
=
(
height_in
+
2
*
pad_h
-
(
dilation_h
*
(
kernel_h
-
1
)
+
1
))
/
stride_h
+
1
;
const
int
width_out
=
(
width_in
+
2
*
pad_w
-
(
dilation_w
*
(
kernel_w
-
1
)
+
1
))
/
stride_w
+
1
;
const
int
im2col_step_
=
std
::
min
(
batch
,
im2col_step
);
AT_ASSERTM
(
batch
%
im2col_step_
==
0
,
"batch("
,
batch
,
") must divide im2col_step("
,
im2col_step_
,
")"
);
AT_ASSERTM
(
channels
==
(
group
*
group_channels
),
"Input channels and group times group channels wont match: (%d vs %d)."
,
channels
,
group
*
group_channels
);
auto
dtype
=
value
.
dtype
();
if
(
dtype
==
at
::
kHalf
){
dtype
=
at
::
kFloat
;
}
auto
grad_input
=
at
::
zeros_like
(
value
,
dtype
);
auto
grad_offset
=
at
::
zeros_like
(
p_offset
,
dtype
);
const
int
batch_n
=
im2col_step_
;
auto
grad_output_n
=
grad_output
.
view
({
batch
/
batch_n
,
batch_n
,
height_out
,
width_out
,
group
,
group_channels
});
auto
per_value_size
=
height_in
*
width_in
*
channels
;
auto
per_offset_size
=
height_out
*
width_out
*
padded_offset_dim
;
for
(
int
n
=
0
;
n
<
batch
/
im2col_step_
;
++
n
)
{
auto
columns
=
grad_output_n
.
select
(
0
,
n
);
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
value
.
scalar_type
(),
"dcnv4_backward_cuda"
,
([
&
]
{
dcnv4_col2im_cuda
(
at
::
cuda
::
getCurrentCUDAStream
(),
value
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_value_size
,
p_offset
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_offset_size
,
columns
.
data_ptr
<
scalar_t
>
(),
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
group
,
group_channels
,
batch_n
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
grad_input
.
data
<
opmath_t
>
()
+
n
*
im2col_step_
*
per_value_size
,
grad_offset
.
data
<
opmath_t
>
()
+
n
*
im2col_step_
*
per_offset_size
,
d_stride
,
block_thread
,
softmax
,
padded_offset_dim
);
}));
}
if
(
value
.
dtype
()
==
torch
::
kHalf
){
return
{
grad_input
.
to
(
torch
::
kHalf
),
grad_offset
.
to
(
torch
::
kHalf
)};
}
else
{
return
{
grad_input
,
grad_offset
};
}
}
DCNv4_op/src/cuda/dcnv4_cuda.h
0 → 100644
View file @
5b17e272
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at
::
Tensor
dcnv4_cuda_forward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
p_offset
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
group
,
const
int
group_channels
,
const
float
offset_scale
,
const
int
im2col_step
,
const
int
remove_center
,
const
int
d_stride
,
const
int
block_thread
,
const
bool
softmax
);
std
::
vector
<
at
::
Tensor
>
dcnv4_cuda_backward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
p_offset
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
group
,
const
int
group_channels
,
const
float
offset_scale
,
const
int
im2col_step
,
const
at
::
Tensor
&
grad_output
,
const
int
remove_center
,
const
int
d_stride
,
const
int
block_thread
,
const
bool
softmax
);
\ No newline at end of file
DCNv4_op/src/cuda/dcnv4_im2col_cuda.cuh
0 → 100644
View file @
5b17e272
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "common.h"
template
<
typename
scalar_t
,
int
d_stride
,
typename
transfer_t
,
int
L
,
int
K
,
bool
softmax
>
__global__
void
forward_kernel_dcn
(
const
scalar_t
*
p_value
,
const
scalar_t
*
p_offset
,
scalar_t
*
p_output
,
const
int
G
,
const
int
D
,
const
int
Q
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
const
int
block_multiplier
,
const
int
padded_offset_dim
)
{
const
int
&
qi
=
(
blockIdx
.
x
*
block_multiplier
%
Q
)
+
threadIdx
.
z
;
const
int
&
bi
=
blockIdx
.
x
*
block_multiplier
/
Q
;
const
int
&
di_s
=
threadIdx
.
x
*
d_stride
;
const
int
&
gi
=
threadIdx
.
y
;
constexpr
int
li
=
0
;
extern
__shared__
char
_s
[];
opmath_t
*
const
p_mask_shm
=
(
opmath_t
*
)(
_s
)
+
((
threadIdx
.
z
*
G
+
gi
)
*
L
+
li
)
*
K
;
opmath_t
p_out_shm
[
d_stride
]
=
{
0.
};
const
scalar_t
*
p_offset_ptr
=
p_offset
+
(
bi
*
Q
+
qi
)
*
padded_offset_dim
+
gi
*
K
*
3
;
const
int
mask_length
=
K
;
const
int
num_thread
=
(
D
/
d_stride
);
const
int
num_iter
=
mask_length
/
num_thread
;
const
int
remainder
=
mask_length
-
num_iter
*
num_thread
;
for
(
int
i
=
0
;
i
<
num_iter
;
i
++
)
{
*
(
p_mask_shm
+
num_thread
*
i
+
threadIdx
.
x
)
=
*
(
scalar_t
*
)(
p_offset_ptr
+
K
*
2
+
num_thread
*
i
+
threadIdx
.
x
);
}
if
(
remainder
>
0
&&
threadIdx
.
x
<
remainder
)
{
*
(
p_mask_shm
+
num_thread
*
num_iter
+
threadIdx
.
x
)
=
*
(
scalar_t
*
)(
p_offset_ptr
+
K
*
2
+
num_thread
*
num_iter
+
threadIdx
.
x
);
}
int
mask_idx
;
if
(
softmax
)
{
__syncthreads
();
// Calculate softmax over L and K
if
(
threadIdx
.
x
==
0
)
{
// gi != 0, di = 0, li = 0
opmath_t
softmax_max
=
-
1e100
;
opmath_t
softmax_sum
=
0.0
;
// get max
// #pragma unroll
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
softmax_max
=
max
(
softmax_max
,
p_mask_shm
[
j
]);
}
// get sumexp
// #pragma unroll
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
opmath_t
exp_results
=
exp
(
p_mask_shm
[
j
]
-
softmax_max
);
p_mask_shm
[
j
]
=
exp_results
;
softmax_sum
+=
exp_results
;
}
// normalize
// #pragma unroll
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
p_mask_shm
[
j
]
/=
softmax_sum
;
}
}
__syncthreads
();
}
int
offset_idx
=
0
;
mask_idx
=
0
;
const
int
w_stride
=
G
*
D
;
const
int
base_ptr
=
gi
*
D
+
di_s
;
const
scalar_t
*
p_value_ptr
=
p_value
+
(
bi
*
(
height_in
*
width_in
))
*
(
G
*
D
);
const
int
p0_w
=
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
-
pad_w
+
(
qi
%
width_out
)
*
stride_w
;
const
int
p0_h
=
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
-
pad_h
+
(
qi
/
width_out
)
*
stride_h
;
const
opmath_t
p0_w_
=
p0_w
-
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
*
offset_scale
;
const
opmath_t
p0_h_
=
p0_h
-
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
*
offset_scale
;
const
int
center_h
=
kernel_h
/
2
;
const
int
center_w
=
kernel_w
/
2
;
int
out_idx
=
((
bi
*
Q
+
qi
)
*
G
+
gi
)
*
D
+
di_s
;
for
(
int
i
=
0
;
i
<
kernel_w
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_h
;
++
j
)
{
if
(
i
!=
center_w
||
j
!=
center_h
||
!
remove_center
)
{
const
opmath_t
w_im
=
p0_w_
+
(
i
*
dilation_w
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
])
*
offset_scale
;
const
opmath_t
h_im
=
p0_h_
+
(
j
*
dilation_h
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
+
1
])
*
offset_scale
;
const
opmath_t
attn
=
p_mask_shm
[
mask_idx
];
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height_in
&&
w_im
<
width_in
)
{
ms_deform_attn_im2col_bilinear
<
scalar_t
,
transfer_t
,
d_stride
>
(
p_out_shm
,
p_value_ptr
,
height_in
,
width_in
,
h_im
,
w_im
,
attn
,
w_stride
,
base_ptr
);
}
offset_idx
+=
2
;
mask_idx
+=
1
;
}
}
}
scalar_t
*
fp16_regs
=
(
scalar_t
*
)(
p_out_shm
);
#pragma unroll
for
(
int
ds
=
0
;
ds
<
d_stride
;
ds
++
)
{
fp16_regs
[
ds
]
=
p_out_shm
[
ds
];
}
*
(
transfer_t
*
)(
p_output
+
out_idx
)
=
*
(
transfer_t
*
)(
p_out_shm
);
}
template
<
typename
scalar_t
,
int
d_stride
,
typename
transfer_t
,
int
L
,
int
K
,
bool
softmax
>
__global__
void
forward_kernel_dcn_reg
(
const
scalar_t
*
p_value
,
const
scalar_t
*
p_offset
,
scalar_t
*
p_output
,
const
int
G
,
const
int
D
,
const
int
Q
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
const
int
block_multiplier
,
const
int
padded_offset_dim
)
{
const
int
&
qi
=
(
blockIdx
.
x
*
block_multiplier
%
Q
)
+
threadIdx
.
z
;
const
int
&
bi
=
blockIdx
.
x
*
block_multiplier
/
Q
;
const
int
&
di_s
=
threadIdx
.
x
*
d_stride
;
const
int
&
gi
=
threadIdx
.
y
;
constexpr
int
li
=
0
;
opmath_t
p_mask_shm
[
K
]
=
{
0.
};
opmath_t
p_out_shm
[
d_stride
]
=
{
0.
};
const
scalar_t
*
p_offset_ptr
=
p_offset
+
(
bi
*
Q
+
qi
)
*
padded_offset_dim
+
gi
*
K
*
3
;
const
int
mask_length
=
K
;
const
int
num_thread
=
(
D
/
d_stride
);
const
int
num_iter
=
mask_length
/
num_thread
;
const
int
remainder
=
mask_length
-
num_iter
*
num_thread
;
for
(
int
i
=
0
;
i
<
K
;
i
++
){
p_mask_shm
[
i
]
=
*
(
p_offset_ptr
+
K
*
2
+
i
);
}
if
(
softmax
)
{
// Calculate softmax over L and K
opmath_t
softmax_max
=
-
1e100
;
opmath_t
softmax_sum
=
0.0
;
// get max
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
softmax_max
=
max
(
softmax_max
,
p_mask_shm
[
j
]);
}
// get sumexp
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
opmath_t
exp_results
=
exp
(
p_mask_shm
[
j
]
-
softmax_max
);
p_mask_shm
[
j
]
=
exp_results
;
softmax_sum
+=
exp_results
;
}
// normalize
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
p_mask_shm
[
j
]
/=
softmax_sum
;
}
}
int
offset_idx
=
0
;
int
mask_idx
=
0
;
const
int
w_stride
=
G
*
D
;
const
int
base_ptr
=
gi
*
D
+
di_s
;
const
scalar_t
*
p_value_ptr
=
p_value
+
(
bi
*
(
height_in
*
width_in
))
*
(
G
*
D
);
const
int
p0_w
=
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
-
pad_w
+
(
qi
%
width_out
)
*
stride_w
;
const
int
p0_h
=
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
-
pad_h
+
(
qi
/
width_out
)
*
stride_h
;
const
opmath_t
p0_w_
=
p0_w
-
((
dilation_w
*
(
kernel_w
-
1
))
>>
1
)
*
offset_scale
;
const
opmath_t
p0_h_
=
p0_h
-
((
dilation_h
*
(
kernel_h
-
1
))
>>
1
)
*
offset_scale
;
const
int
center_h
=
kernel_h
/
2
;
const
int
center_w
=
kernel_w
/
2
;
int
out_idx
=
((
bi
*
Q
+
qi
)
*
G
+
gi
)
*
D
+
di_s
;
for
(
int
i
=
0
;
i
<
kernel_w
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_h
;
++
j
)
{
if
(
i
!=
center_w
||
j
!=
center_h
||
!
remove_center
)
{
const
opmath_t
w_im
=
p0_w_
+
(
i
*
dilation_w
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
])
*
offset_scale
;
const
opmath_t
h_im
=
p0_h_
+
(
j
*
dilation_h
+
(
opmath_t
)
p_offset_ptr
[
offset_idx
+
1
])
*
offset_scale
;
const
opmath_t
attn
=
p_mask_shm
[
mask_idx
];
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height_in
&&
w_im
<
width_in
)
{
ms_deform_attn_im2col_bilinear
<
scalar_t
,
transfer_t
,
d_stride
>
(
p_out_shm
,
p_value_ptr
,
height_in
,
width_in
,
h_im
,
w_im
,
attn
,
w_stride
,
base_ptr
);
}
offset_idx
+=
2
;
mask_idx
+=
1
;
}
}
}
scalar_t
*
fp16_regs
=
(
scalar_t
*
)(
p_out_shm
);
#pragma unroll
for
(
int
ds
=
0
;
ds
<
d_stride
;
ds
++
)
{
fp16_regs
[
ds
]
=
p_out_shm
[
ds
];
}
*
(
transfer_t
*
)(
p_output
+
out_idx
)
=
*
(
transfer_t
*
)(
p_out_shm
);
}
template
<
typename
scalar_t
,
typename
stride_type
,
int
d_stride
>
void
_dcnv4_im2col_cuda
(
cudaStream_t
stream
,
const
scalar_t
*
value
,
// B, H * W, (G * D)
const
scalar_t
*
p_offset
,
// B, H * W, G * K * 3)
scalar_t
*
output
,
// B, H_out*W_out, G * D
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
G
,
const
int
D
,
const
int
B
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
const
int
block_thread
,
const
int
softmax
,
const
int
padded_offset_dim
)
{
constexpr
int
L
=
1
;
auto
kernel
=
forward_kernel_dcn_reg
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
true
>
;
int
N
=
height_in
*
width_in
;
int
Q
=
height_out
*
width_out
;
int
K
=
kernel_h
*
kernel_w
;
if
(
remove_center
)
{
K
-=
1
;
}
if
(
softmax
)
{
switch
(
K
)
{
case
9
:
kernel
=
forward_kernel_dcn_reg
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
true
>
;
break
;
case
8
:
kernel
=
forward_kernel_dcn_reg
<
scalar_t
,
d_stride
,
stride_type
,
1
,
8
,
true
>
;
break
;
default:
printf
(
"K=%ld
\n
"
,
K
);
throw
std
::
invalid_argument
(
"invalid kernel shape"
);
}
}
else
{
switch
(
K
)
{
case
9
:
kernel
=
forward_kernel_dcn_reg
<
scalar_t
,
d_stride
,
stride_type
,
1
,
9
,
false
>
;
break
;
case
8
:
kernel
=
forward_kernel_dcn_reg
<
scalar_t
,
d_stride
,
stride_type
,
1
,
8
,
false
>
;
break
;
default:
printf
(
"K=%ld
\n
"
,
K
);
throw
std
::
invalid_argument
(
"invalid kernel shape"
);
}
}
const
int
block_multiplier
=
block_thread
/
(
D
/
d_stride
)
/
G
;
assert
((
B
*
Q
)
%
block_multiplier
==
0
);
dim3
num_blocks
(
B
*
Q
/
block_multiplier
);
dim3
num_threads
(
D
/
d_stride
,
G
,
block_multiplier
);
int
shm_size
=
0
;
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shm_size
);
kernel
<<<
num_blocks
,
num_threads
,
shm_size
,
stream
>>>
(
value
,
p_offset
,
output
,
G
,
D
,
Q
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_multiplier
,
padded_offset_dim
);
cudaError_t
err
=
cudaGetLastError
();
if
(
err
!=
cudaSuccess
)
{
printf
(
"error in dcnv4_im2col_cuda: %s
\n
"
,
cudaGetErrorString
(
err
));
printf
(
"launch arguments: gridDim=(%d, %d, %d), blockDim=(%d, %d, %d), "
"shm_size=%d
\n\n
"
,
num_blocks
.
x
,
num_blocks
.
y
,
num_blocks
.
z
,
num_threads
.
x
,
num_threads
.
y
,
num_threads
.
z
,
shm_size
);
AT_ASSERTM
(
false
,
"kernel launch error"
);
}
}
template
<
typename
scalar_t
>
void
dcnv4_im2col_cuda
(
cudaStream_t
stream
,
const
scalar_t
*
value
,
// B, H * W, (G * D)
const
scalar_t
*
p_offset
,
// B, H * W, G * K * 3)
scalar_t
*
output
,
// B, H_out*W_out, G * D
const
int
kernel_h
,
const
int
kernel_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
G
,
const
int
D
,
const
int
B
,
const
int
height_in
,
const
int
width_in
,
const
int
height_out
,
const
int
width_out
,
const
opmath_t
offset_scale
,
const
int
remove_center
,
const
int
d_stride
,
const
int
block_thread
,
const
bool
softmax
,
const
int
padded_offset_dim
)
{
assert
(
D
%
d_stride
==
0
);
if
(
sizeof
(
scalar_t
)
==
2
)
{
switch
(
d_stride
)
{
case
1
:
_dcnv4_im2col_cuda
<
scalar_t
,
scalar_t
,
1
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
2
:
_dcnv4_im2col_cuda
<
scalar_t
,
uint
,
2
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
4
:
_dcnv4_im2col_cuda
<
scalar_t
,
uint2
,
4
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
8
:
_dcnv4_im2col_cuda
<
scalar_t
,
uint4
,
8
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
16
:
_dcnv4_im2col_cuda
<
scalar_t
,
ulonglong4
,
16
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
}
}
else
{
assert
(
sizeof
(
scalar_t
)
==
4
);
switch
(
d_stride
)
{
case
1
:
_dcnv4_im2col_cuda
<
scalar_t
,
uint
,
1
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
2
:
_dcnv4_im2col_cuda
<
scalar_t
,
uint2
,
2
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
4
:
_dcnv4_im2col_cuda
<
scalar_t
,
uint4
,
4
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
case
8
:
_dcnv4_im2col_cuda
<
scalar_t
,
ulonglong4
,
8
>
(
stream
,
value
,
p_offset
,
output
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
pad_h
,
pad_w
,
dilation_h
,
dilation_w
,
G
,
D
,
B
,
height_in
,
width_in
,
height_out
,
width_out
,
offset_scale
,
remove_center
,
block_thread
,
softmax
,
padded_offset_dim
);
break
;
default:
printf
(
"not supported for d_stride > 8 for fp32"
);
throw
std
::
invalid_argument
(
"invalid d_stride"
);
}
}
}
\ No newline at end of file
DCNv4_op/src/cuda/flash_deform_attn_cuda.cu
0 → 100644
View file @
5b17e272
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "cuda/flash_deform_im2col_cuda.cuh"
#include "cuda/flash_deform_col2im_cuda.cuh"
#include <vector>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/torch.h>
at
::
Tensor
flash_deform_attn_cuda_forward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
spatial_shapes
,
const
at
::
Tensor
&
level_start_index
,
const
at
::
Tensor
&
sampling_loc_attn
,
const
int
im2col_step
=
64
,
const
int
K
=
8
,
const
int
d_stride
=
8
,
const
int
block_thread
=
0
)
{
AT_ASSERTM
(
value
.
is_contiguous
(),
"value tensor has to be contiguous"
);
AT_ASSERTM
(
spatial_shapes
.
is_contiguous
(),
"spatial_shapes tensor has to be contiguous"
);
AT_ASSERTM
(
level_start_index
.
is_contiguous
(),
"level_start_index tensor has to be contiguous"
);
AT_ASSERTM
(
sampling_loc_attn
.
is_contiguous
(),
"sampling_loc_attn tensor has to be contiguous"
);
AT_ASSERTM
(
value
.
type
().
is_cuda
(),
"value must be a CUDA tensor"
);
AT_ASSERTM
(
spatial_shapes
.
type
().
is_cuda
(),
"spatial_shapes must be a CUDA tensor"
);
AT_ASSERTM
(
level_start_index
.
type
().
is_cuda
(),
"level_start_index must be a CUDA tensor"
);
AT_ASSERTM
(
sampling_loc_attn
.
type
().
is_cuda
(),
"sampling_loc_attn must be a CUDA tensor"
);
const
int
batch
=
value
.
size
(
0
);
const
int
spatial_size
=
value
.
size
(
1
);
const
int
num_heads
=
value
.
size
(
2
);
const
int
num_channels
=
value
.
size
(
3
);
const
int
num_levels
=
spatial_shapes
.
size
(
0
);
const
int
num_query
=
sampling_loc_attn
.
size
(
1
);
const
int
num_point
=
K
;
const
int
im2col_step_
=
std
::
min
(
batch
,
im2col_step
);
AT_ASSERTM
(
batch
%
im2col_step_
==
0
,
"batch("
,
batch
,
") must divide im2col_step("
,
im2col_step_
,
")"
);
auto
output
=
at
::
zeros
({
batch
,
num_query
,
num_heads
,
num_channels
},
value
.
options
());
auto
per_value_size
=
spatial_size
*
num_heads
*
num_channels
;
auto
per_offset_size
=
num_query
*
num_heads
*
num_levels
*
num_point
*
3
;
auto
per_out_size
=
num_query
*
num_heads
*
num_channels
;
for
(
int
n
=
0
;
n
<
batch
/
im2col_step_
;
++
n
)
{
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
value
.
scalar_type
(),
"flash_deform_attn_forward_cuda"
,
([
&
]
{
flash_deformable_im2col_cuda
(
at
::
cuda
::
getCurrentCUDAStream
(),
value
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_value_size
,
spatial_shapes
.
data
<
int64_t
>
(),
level_start_index
.
data
<
int64_t
>
(),
sampling_loc_attn
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_offset_size
,
output
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_out_size
,
im2col_step_
,
spatial_size
,
num_heads
,
num_channels
,
num_levels
,
num_query
,
num_point
,
d_stride
,
block_thread
,
true
);
}));
}
output
=
output
.
view
({
batch
,
num_query
,
num_heads
*
num_channels
});
return
output
;
}
std
::
vector
<
at
::
Tensor
>
flash_deform_attn_cuda_backward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
spatial_shapes
,
const
at
::
Tensor
&
level_start_index
,
const
at
::
Tensor
&
sampling_loc_attn
,
const
at
::
Tensor
&
grad_output
,
const
int
im2col_step
=
64
,
const
int
K
=
8
,
const
int
d_stride
=
2
,
const
int
block_thread
=
0
)
{
AT_ASSERTM
(
value
.
is_contiguous
(),
"value tensor has to be contiguous"
);
AT_ASSERTM
(
spatial_shapes
.
is_contiguous
(),
"spatial_shapes tensor has to be contiguous"
);
AT_ASSERTM
(
level_start_index
.
is_contiguous
(),
"level_start_index tensor has to be contiguous"
);
AT_ASSERTM
(
sampling_loc_attn
.
is_contiguous
(),
"sampling_loc_attn tensor has to be contiguous"
);
AT_ASSERTM
(
grad_output
.
is_contiguous
(),
"grad_output tensor has to be contiguous"
);
AT_ASSERTM
(
value
.
type
().
is_cuda
(),
"value must be a CUDA tensor"
);
AT_ASSERTM
(
spatial_shapes
.
type
().
is_cuda
(),
"spatial_shapes must be a CUDA tensor"
);
AT_ASSERTM
(
level_start_index
.
type
().
is_cuda
(),
"level_start_index must be a CUDA tensor"
);
AT_ASSERTM
(
sampling_loc_attn
.
type
().
is_cuda
(),
"sampling_loc_attn must be a CUDA tensor"
);
AT_ASSERTM
(
grad_output
.
type
().
is_cuda
(),
"grad_output must be a CUDA tensor"
);
const
int
batch
=
value
.
size
(
0
);
const
int
spatial_size
=
value
.
size
(
1
);
const
int
num_heads
=
value
.
size
(
2
);
const
int
num_channels
=
value
.
size
(
3
);
const
int
num_levels
=
spatial_shapes
.
size
(
0
);
const
int
num_query
=
sampling_loc_attn
.
size
(
1
);
const
int
num_point
=
K
;
const
int
im2col_step_
=
std
::
min
(
batch
,
im2col_step
);
AT_ASSERTM
(
batch
%
im2col_step_
==
0
,
"batch("
,
batch
,
") must divide im2col_step("
,
im2col_step_
,
")"
);
auto
dtype
=
value
.
dtype
();
if
(
dtype
==
at
::
kHalf
){
dtype
=
at
::
kFloat
;
}
auto
grad_input
=
at
::
zeros_like
(
value
,
dtype
);
auto
grad_offset
=
at
::
zeros_like
(
sampling_loc_attn
,
dtype
);
auto
per_value_size
=
spatial_size
*
num_heads
*
num_channels
;
auto
per_offset_size
=
num_query
*
num_heads
*
num_levels
*
num_point
*
3
;
auto
per_out_size
=
num_query
*
num_heads
*
num_channels
;
for
(
int
n
=
0
;
n
<
batch
/
im2col_step_
;
++
n
)
{
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
value
.
scalar_type
(),
"flash_deform_attn_backward_cuda"
,
([
&
]
{
flash_deformable_col2im_cuda
(
at
::
cuda
::
getCurrentCUDAStream
(),
value
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_value_size
,
spatial_shapes
.
data
<
int64_t
>
(),
level_start_index
.
data
<
int64_t
>
(),
sampling_loc_attn
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_offset_size
,
grad_output
.
data_ptr
<
scalar_t
>
()
+
n
*
im2col_step_
*
per_out_size
,
im2col_step_
,
spatial_size
,
num_heads
,
num_channels
,
num_levels
,
num_query
,
num_point
,
grad_input
.
data
<
opmath_t
>
()
+
n
*
im2col_step_
*
per_value_size
,
grad_offset
.
data
<
opmath_t
>
()
+
n
*
im2col_step_
*
per_offset_size
,
d_stride
,
block_thread
);
}));
}
if
(
value
.
dtype
()
==
torch
::
kHalf
){
return
{
grad_input
.
to
(
torch
::
kHalf
),
grad_offset
.
to
(
torch
::
kHalf
)};
}
else
{
return
{
grad_input
,
grad_offset
};
}
}
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment