Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
85b37e7b
Unverified
Commit
85b37e7b
authored
Apr 10, 2022
by
q.yao
Committed by
GitHub
Apr 10, 2022
Browse files
[Enhancement] Optimize correlation op (#1814)
* optimize forward * fast backward * fix bugs of grad input2
parent
cff3fecc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
104 additions
and
109 deletions
+104
-109
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
+83
-89
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
+21
-20
No files found.
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
View file @
85b37e7b
...
@@ -29,8 +29,8 @@ using namespace torch;
...
@@ -29,8 +29,8 @@ using namespace torch;
#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
#define
THREADS_FORWARD
32
#define
WARP_SIZE
32
#define
THREADS_BACKWARD 16
#define
FULL_MASK 0xffffffff
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
correlation_forward_cuda_kernel
(
__global__
void
correlation_forward_cuda_kernel
(
...
@@ -42,8 +42,8 @@ __global__ void correlation_forward_cuda_kernel(
...
@@ -42,8 +42,8 @@ __global__ void correlation_forward_cuda_kernel(
const
int
C
=
rInput1
.
size
(
3
);
const
int
C
=
rInput1
.
size
(
3
);
const
int
n
=
blockIdx
.
x
;
const
int
n
=
blockIdx
.
x
;
const
int
h
=
blockIdx
.
y
;
const
int
h
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
int
w
=
blockIdx
.
z
;
const
int
w
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
const
int
thread
=
threadIdx
.
x
;
const
int
thread
=
threadIdx
.
x
;
const
int
start_i
=
-
padH
+
h
*
dH
;
const
int
start_i
=
-
padH
+
h
*
dH
;
...
@@ -52,13 +52,11 @@ __global__ void correlation_forward_cuda_kernel(
...
@@ -52,13 +52,11 @@ __global__ void correlation_forward_cuda_kernel(
const
int
patchRadH
=
dilation_patchH
*
(
patchH
-
1
)
/
2
;
const
int
patchRadH
=
dilation_patchH
*
(
patchH
-
1
)
/
2
;
const
int
patchRadW
=
dilation_patchW
*
(
patchW
-
1
)
/
2
;
const
int
patchRadW
=
dilation_patchW
*
(
patchW
-
1
)
/
2
;
__shared__
scalar_t
prod_sum
[
THREADS_FORWARD
];
for
(
int
ph
=
0
;
ph
<
patchH
;
++
ph
)
{
for
(
int
ph
=
0
;
ph
<
patchH
;
++
ph
)
{
int
ph_dilated
=
ph
*
dilation_patchH
-
patchRadH
;
int
ph_dilated
=
ph
*
dilation_patchH
-
patchRadH
;
for
(
int
pw
=
0
;
pw
<
patchW
;
++
pw
)
{
for
(
int
pw
=
0
;
pw
<
patchW
;
++
pw
)
{
int
pw_dilated
=
pw
*
dilation_patchW
-
patchRadW
;
int
pw_dilated
=
pw
*
dilation_patchW
-
patchRadW
;
prod_sum
[
thread
]
=
0
;
scalar_t
prod_sum
=
0
.0
f
;
for
(
int
i
=
0
;
i
<
kH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kH
;
++
i
)
{
int
i1
=
start_i
+
i
*
dilationH
;
int
i1
=
start_i
+
i
*
dilationH
;
int
i2
=
i1
+
ph_dilated
;
int
i2
=
i1
+
ph_dilated
;
...
@@ -69,23 +67,20 @@ __global__ void correlation_forward_cuda_kernel(
...
@@ -69,23 +67,20 @@ __global__ void correlation_forward_cuda_kernel(
int
j2
=
j1
+
pw_dilated
;
int
j2
=
j1
+
pw_dilated
;
if
if
WITHIN_BOUNDS
(
j1
,
j2
,
iW
,
iW
)
{
WITHIN_BOUNDS
(
j1
,
j2
,
iW
,
iW
)
{
for
(
int
c
=
thread
;
c
<
C
;
c
+=
THREADS_FORWARD
)
{
for
(
int
c
=
thread
;
c
<
C
;
c
+=
WARP_SIZE
)
{
scalar_t
v1
=
rInput1
[
n
][
i1
][
j1
][
c
];
scalar_t
v1
=
rInput1
[
n
][
i1
][
j1
][
c
];
scalar_t
v2
=
rInput2
[
n
][
i2
][
j2
][
c
];
scalar_t
v2
=
rInput2
[
n
][
i2
][
j2
][
c
];
prod_sum
[
thread
]
+=
v1
*
v2
;
prod_sum
+=
v1
*
v2
;
}
}
}
}
}
}
}
}
}
}
// accumulate
// accumulate
__syncthreads
();
for
(
int
offset
=
16
;
offset
>
0
;
offset
/=
2
)
prod_sum
+=
__shfl_down_sync
(
FULL_MASK
,
float
(
prod_sum
),
offset
);
if
(
thread
==
0
)
{
if
(
thread
==
0
)
{
scalar_t
reduce_sum
=
0
;
output
[
n
][
ph
][
pw
][
h
][
w
]
=
prod_sum
;
for
(
int
index
=
0
;
index
<
THREADS_FORWARD
;
++
index
)
{
reduce_sum
+=
prod_sum
[
index
];
}
output
[
n
][
ph
][
pw
][
h
][
w
]
=
reduce_sum
;
}
}
}
}
}
}
...
@@ -97,9 +92,10 @@ __global__ void correlation_backward_cuda_kernel_input1(
...
@@ -97,9 +92,10 @@ __global__ void correlation_backward_cuda_kernel_input1(
TensorAcc4R
grad_input1
,
const
int
kH
,
const
int
kW
,
const
int
patchH
,
TensorAcc4R
grad_input1
,
const
int
kH
,
const
int
kW
,
const
int
patchH
,
const
int
patchW
,
const
int
padH
,
const
int
padW
,
const
int
dilationH
,
const
int
patchW
,
const
int
padH
,
const
int
padW
,
const
int
dilationH
,
const
int
dilationW
,
const
int
dilation_patchH
,
const
int
dilation_patchW
,
const
int
dilationW
,
const
int
dilation_patchH
,
const
int
dilation_patchW
,
const
int
dH
,
const
int
dW
,
const
int
batch
)
{
const
int
dH
,
const
int
dW
)
{
const
int
iH
=
input2
.
size
(
2
);
const
int
iH
=
input2
.
size
(
1
);
const
int
iW
=
input2
.
size
(
3
);
const
int
iW
=
input2
.
size
(
2
);
const
int
C
=
input2
.
size
(
3
);
const
int
H
=
grad_output
.
size
(
3
);
const
int
H
=
grad_output
.
size
(
3
);
const
int
W
=
grad_output
.
size
(
4
);
const
int
W
=
grad_output
.
size
(
4
);
...
@@ -107,54 +103,53 @@ __global__ void correlation_backward_cuda_kernel_input1(
...
@@ -107,54 +103,53 @@ __global__ void correlation_backward_cuda_kernel_input1(
const
int
patchRadH
=
(
patchH
-
1
)
/
2
;
const
int
patchRadH
=
(
patchH
-
1
)
/
2
;
const
int
patchRadW
=
(
patchW
-
1
)
/
2
;
const
int
patchRadW
=
(
patchW
-
1
)
/
2
;
const
int
n
=
batch
;
const
int
n
=
blockIdx
.
x
;
const
int
c
=
blockIdx
.
x
;
const
int
h
=
blockIdx
.
y
;
const
int
h
=
blockIdx
.
y
;
const
int
w
=
blockIdx
.
z
;
const
int
w
=
blockIdx
.
z
;
const
int
ph_off
=
threadIdx
.
x
;
const
int
pw_off
=
threadIdx
.
y
;
const
int
h_2
=
h
+
padH
;
const
int
h_2
=
h
+
padH
;
const
int
w_2
=
w
+
padW
;
const
int
w_2
=
w
+
padW
;
const
int
min_h
=
h_2
-
kH
*
dilationH
;
const
int
min_h
=
h_2
-
kH
*
dilationH
;
const
int
min_w
=
w_2
-
kW
*
dilationW
;
const
int
min_w
=
w_2
-
kW
*
dilationW
;
__shared__
scalar_t
prod_sum
[
THREADS_BACKWARD
][
THREADS_BACKWARD
];
extern
__shared__
__align__
(
sizeof
(
4
))
unsigned
char
grad_cache_char
[];
prod_sum
[
ph_off
][
pw_off
]
=
0
;
scalar_t
*
grad_cache
=
reinterpret_cast
<
scalar_t
*>
(
grad_cache_char
);
for
(
int
i
=
threadIdx
.
x
;
i
<
patchH
*
patchW
;
i
+=
blockDim
.
x
)
{
for
(
int
ph
=
ph_off
;
ph
<
patchH
;
ph
+=
THREADS_BACKWARD
)
{
const
int
ph
=
i
/
patchW
;
const
int
pw
=
i
%
patchW
;
int
i1
=
h
+
dilation_patchH
*
(
ph
-
patchRadH
);
int
i1
=
h
+
dilation_patchH
*
(
ph
-
patchRadH
);
for
(
int
pw
=
pw_off
;
pw
<
patchW
;
pw
+=
THREADS_BACKWARD
)
{
int
j1
=
w
+
dilation_patchW
*
(
pw
-
patchRadW
);
int
j1
=
w
+
dilation_patchW
*
(
pw
-
patchRadW
);
if
(
WITHIN_BOUNDS
(
i1
,
j1
,
iH
,
iW
))
{
if
(
WITHIN_BOUNDS
(
i1
,
j1
,
iH
,
iW
))
{
scalar_t
val
=
input2
[
n
][
c
][
i1
][
j1
];
scalar_t
grad_val
=
0.0
f
;
for
(
int
h_3
=
h_2
;
h_3
>
min_h
;
h_3
-=
dilationH
)
{
for
(
int
h_3
=
h_2
;
h_3
>
min_h
;
h_3
-=
dilationH
)
{
int
i2
=
(
h_3
)
/
dH
;
int
i2
=
(
h_3
)
/
dH
;
if
(
i2
*
dH
!=
h_3
)
continue
;
if
(
i2
*
dH
!=
h_3
)
continue
;
for
(
int
w_3
=
w_2
;
w_3
>
min_w
;
w_3
-=
dilationW
)
{
for
(
int
w_3
=
w_2
;
w_3
>
min_w
;
w_3
-=
dilationW
)
{
int
j2
=
(
w_3
)
/
dW
;
int
j2
=
(
w_3
)
/
dW
;
if
(
j2
*
dW
!=
w_3
)
continue
;
if
(
j2
*
dW
!=
w_3
)
continue
;
if
if
(
WITHIN_BOUNDS
(
i2
,
j2
,
H
,
W
))
{
WITHIN_BOUNDS
(
i2
,
j2
,
H
,
W
)
{
grad_val
+=
grad_output
[
n
][
ph
][
pw
][
i2
][
j2
];
prod_sum
[
ph_off
][
pw_off
]
+=
grad_output
[
n
][
ph
][
pw
][
i2
][
j2
]
*
val
;
}
}
}
}
}
}
}
grad_cache
[
i
]
=
grad_val
;
}
}
}
}
__syncthreads
();
__syncthreads
();
if
(
ph_off
==
0
&&
pw_off
==
0
)
{
for
(
int
c
=
threadIdx
.
x
;
c
<
C
;
c
+=
blockDim
.
x
)
{
scalar_t
reduce_sum
=
0
;
scalar_t
grad_input_val
=
0.0
f
;
for
(
int
ph
=
0
;
ph
<
THREADS_BACKWARD
;
++
ph
)
{
for
(
int
ph
=
0
;
ph
<
patchH
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
THREADS_BACKWARD
;
++
pw
)
{
int
i1
=
h
+
dilation_patchH
*
(
ph
-
patchRadH
);
reduce_sum
+=
prod_sum
[
ph
][
pw
];
for
(
int
pw
=
0
;
pw
<
patchW
;
++
pw
)
{
int
j1
=
w
+
dilation_patchW
*
(
pw
-
patchRadW
);
if
(
WITHIN_BOUNDS
(
i1
,
j1
,
iH
,
iW
))
{
grad_input_val
+=
input2
[
n
][
i1
][
j1
][
c
]
*
grad_cache
[
ph
*
patchW
+
pw
];
}
}
}
}
}
grad_input1
[
n
][
c
][
h
][
w
]
=
reduce_sum
;
grad_input1
[
n
][
c
][
h
][
w
]
=
grad_input_val
;
}
}
}
}
...
@@ -163,9 +158,10 @@ __global__ void correlation_backward_cuda_kernel_input2(
...
@@ -163,9 +158,10 @@ __global__ void correlation_backward_cuda_kernel_input2(
const
TensorAcc5R
grad_output
,
const
TensorAcc4R
input1
,
const
TensorAcc5R
grad_output
,
const
TensorAcc4R
input1
,
TensorAcc4R
grad_input2
,
int
kH
,
int
kW
,
int
patchH
,
int
patchW
,
int
padH
,
TensorAcc4R
grad_input2
,
int
kH
,
int
kW
,
int
patchH
,
int
patchW
,
int
padH
,
int
padW
,
int
dilationH
,
int
dilationW
,
int
dilation_patchH
,
int
padW
,
int
dilationH
,
int
dilationW
,
int
dilation_patchH
,
int
dilation_patchW
,
int
dH
,
int
dW
,
int
batch
)
{
int
dilation_patchW
,
int
dH
,
int
dW
)
{
const
int
iH
=
input1
.
size
(
2
);
const
int
iH
=
input1
.
size
(
1
);
const
int
iW
=
input1
.
size
(
3
);
const
int
iW
=
input1
.
size
(
2
);
const
int
C
=
input1
.
size
(
3
);
const
int
patchRadH
=
(
patchH
-
1
)
/
2
;
const
int
patchRadH
=
(
patchH
-
1
)
/
2
;
const
int
patchRadW
=
(
patchW
-
1
)
/
2
;
const
int
patchRadW
=
(
patchW
-
1
)
/
2
;
...
@@ -176,56 +172,54 @@ __global__ void correlation_backward_cuda_kernel_input2(
...
@@ -176,56 +172,54 @@ __global__ void correlation_backward_cuda_kernel_input2(
const
int
dilatedKH
=
kH
*
dilationH
;
const
int
dilatedKH
=
kH
*
dilationH
;
const
int
dilatedKW
=
kW
*
dilationW
;
const
int
dilatedKW
=
kW
*
dilationW
;
const
int
n
=
batch
;
const
int
n
=
blockIdx
.
x
;
const
int
c
=
blockIdx
.
x
;
const
int
h
=
blockIdx
.
y
;
const
int
h
=
blockIdx
.
y
;
const
int
w
=
blockIdx
.
z
;
const
int
w
=
blockIdx
.
z
;
const
int
ph_off
=
threadIdx
.
x
;
const
int
pw_off
=
threadIdx
.
y
;
__shared__
scalar_t
prod_sum
[
THREADS_BACKWARD
][
THREADS_BACKWARD
];
prod_sum
[
ph_off
][
pw_off
]
=
0
;
for
(
int
ph
=
ph_off
;
ph
<
patchH
;
ph
+=
THREADS_BACKWARD
)
{
extern
__shared__
__align__
(
sizeof
(
4
))
unsigned
char
grad_cache_char
[];
scalar_t
*
grad_cache
=
reinterpret_cast
<
scalar_t
*>
(
grad_cache_char
);
for
(
int
i
=
threadIdx
.
x
;
i
<
patchH
*
patchW
;
i
+=
blockDim
.
x
)
{
const
int
ph
=
i
/
patchW
;
const
int
pw
=
i
%
patchW
;
int
i1
=
h
-
dilation_patchH
*
(
ph
-
patchRadH
);
int
i1
=
h
-
dilation_patchH
*
(
ph
-
patchRadH
);
for
(
int
pw
=
pw_off
;
pw
<
patchW
;
pw
+=
THREADS_BACKWARD
)
{
int
j1
=
w
-
dilation_patchW
*
(
pw
-
patchRadW
);
int
j1
=
w
-
dilation_patchW
*
(
pw
-
patchRadW
);
if
if
(
WITHIN_BOUNDS
(
i1
,
j1
,
iH
,
iW
))
{
WITHIN_BOUNDS
(
i1
,
j1
,
iH
,
iW
)
{
scalar_t
grad_val
=
0.0
f
;
scalar_t
val
=
input1
[
n
][
c
][
i1
][
j1
];
const
int
h_2
=
i1
+
padH
;
const
int
h_2
=
i1
+
padH
;
const
int
w_2
=
j1
+
padW
;
const
int
w_2
=
j1
+
padW
;
const
int
min_h
=
h_2
-
dilatedKH
;
const
int
min_h
=
h_2
-
dilatedKH
;
const
int
min_w
=
w_2
-
dilatedKW
;
const
int
min_w
=
w_2
-
dilatedKW
;
for
(
int
h_3
=
h_2
;
h_3
>
min_h
;
h_3
-=
dilationH
)
{
for
(
int
h_3
=
h_2
;
h_3
>
min_h
;
h_3
-=
dilationH
)
{
int
i2
=
(
h_3
)
/
dH
;
int
i2
=
(
h_3
)
/
dH
;
if
(
i2
*
dH
!=
h_3
)
continue
;
if
(
i2
*
dH
!=
h_3
)
continue
;
for
(
int
w_3
=
w_2
;
w_3
>
min_w
;
w_3
-=
dilationW
)
{
for
(
int
w_3
=
w_2
;
w_3
>
min_w
;
w_3
-=
dilationW
)
{
int
j2
=
(
w_3
)
/
dW
;
int
j2
=
(
w_3
)
/
dW
;
if
(
j2
*
dW
!=
w_3
)
continue
;
if
(
j2
*
dW
!=
w_3
)
continue
;
if
(
WITHIN_BOUNDS
(
i2
,
j2
,
H
,
W
))
{
if
grad_val
+=
grad_output
[
n
][
ph
][
pw
][
i2
][
j2
];
WITHIN_BOUNDS
(
i2
,
j2
,
H
,
W
)
{
prod_sum
[
ph_off
][
pw_off
]
+=
grad_output
[
n
][
ph
][
pw
][
i2
][
j2
]
*
val
;
}
}
}
}
}
}
}
grad_cache
[
i
]
=
grad_val
;
}
}
}
}
__syncthreads
();
__syncthreads
();
if
(
ph_off
==
0
&&
pw_off
==
0
)
{
for
(
int
c
=
threadIdx
.
x
;
c
<
C
;
c
+=
blockDim
.
x
)
{
scalar_t
reduce_sum
=
0
;
scalar_t
grad_input_val
=
0.0
f
;
for
(
int
ph
=
0
;
ph
<
THREADS_BACKWARD
;
++
ph
)
{
for
(
int
ph
=
0
;
ph
<
patchH
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
THREADS_BACKWARD
;
++
pw
)
{
int
i1
=
h
-
dilation_patchH
*
(
ph
-
patchRadH
);
reduce_sum
+=
prod_sum
[
ph
][
pw
];
for
(
int
pw
=
0
;
pw
<
patchW
;
++
pw
)
{
int
j1
=
w
-
dilation_patchW
*
(
pw
-
patchRadW
);
if
(
WITHIN_BOUNDS
(
i1
,
j1
,
iH
,
iW
))
{
grad_input_val
+=
input1
[
n
][
i1
][
j1
][
c
]
*
grad_cache
[
ph
*
patchW
+
pw
];
}
}
}
}
}
grad_input2
[
n
][
c
][
h
][
w
]
=
reduce_sum
;
grad_input2
[
n
][
c
][
h
][
w
]
=
grad_input_val
;
}
}
}
}
#endif
#endif
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
View file @
85b37e7b
...
@@ -24,8 +24,8 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
...
@@ -24,8 +24,8 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
auto
trInput1
=
input1
.
permute
({
0
,
2
,
3
,
1
}).
contiguous
();
auto
trInput1
=
input1
.
permute
({
0
,
2
,
3
,
1
}).
contiguous
();
auto
trInput2
=
input2
.
permute
({
0
,
2
,
3
,
1
}).
contiguous
();
auto
trInput2
=
input2
.
permute
({
0
,
2
,
3
,
1
}).
contiguous
();
const
int
threads
=
THREADS_FORWARD
;
const
dim3
threads
(
WARP_SIZE
,
4
,
4
)
;
const
dim3
blocks
(
batch_size
,
oH
,
oW
);
const
dim3
blocks
(
batch_size
,
(
oH
+
3
)
>>
2
,
(
oW
+
3
)
>>
2
);
at
::
cuda
::
CUDAGuard
device_guard
(
input1
.
device
());
at
::
cuda
::
CUDAGuard
device_guard
(
input1
.
device
());
...
@@ -56,17 +56,20 @@ void CorrelationBackwardCUDAKernelLauncher(
...
@@ -56,17 +56,20 @@ void CorrelationBackwardCUDAKernelLauncher(
const
int
iW
=
input1
.
size
(
3
);
const
int
iW
=
input1
.
size
(
3
);
const
int
C
=
input1
.
size
(
1
);
const
int
C
=
input1
.
size
(
1
);
const
dim3
blocks
(
C
,
iH
,
iW
);
auto
trInput1
=
input1
.
permute
({
0
,
2
,
3
,
1
}).
contiguous
();
const
dim3
threads
(
THREADS_BACKWARD
,
THREADS_BACKWARD
);
auto
trInput2
=
input2
.
permute
({
0
,
2
,
3
,
1
}).
contiguous
();
const
dim3
blocks
(
batch_size
,
iH
,
iW
);
const
dim3
threads
(
THREADS_PER_BLOCK
);
at
::
cuda
::
CUDAGuard
device_guard
(
input1
.
device
());
at
::
cuda
::
CUDAGuard
device_guard
(
input1
.
device
());
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input1
.
scalar_type
(),
"correlation_backward_cuda"
,
([
&
]
{
input1
.
scalar_type
(),
"correlation_backward_cuda"
,
([
&
]
{
const
int
grad_cache_size
=
patchH
*
patchW
*
sizeof
(
scalar_t
);
TensorAcc4R
input1_acc
=
TensorAcc4R
input1_acc
=
i
nput1
.
packed_accessor32
<
scalar_t
,
4
,
RestrictPtrTraits
>
();
trI
nput1
.
packed_accessor32
<
scalar_t
,
4
,
RestrictPtrTraits
>
();
TensorAcc4R
input2_acc
=
TensorAcc4R
input2_acc
=
i
nput2
.
packed_accessor32
<
scalar_t
,
4
,
RestrictPtrTraits
>
();
trI
nput2
.
packed_accessor32
<
scalar_t
,
4
,
RestrictPtrTraits
>
();
TensorAcc4R
grad_input1_acc
=
TensorAcc4R
grad_input1_acc
=
grad_input1
.
packed_accessor32
<
scalar_t
,
4
,
RestrictPtrTraits
>
();
grad_input1
.
packed_accessor32
<
scalar_t
,
4
,
RestrictPtrTraits
>
();
TensorAcc4R
grad_input2_acc
=
TensorAcc4R
grad_input2_acc
=
...
@@ -74,20 +77,18 @@ void CorrelationBackwardCUDAKernelLauncher(
...
@@ -74,20 +77,18 @@ void CorrelationBackwardCUDAKernelLauncher(
TensorAcc5R
grad_output_acc
=
TensorAcc5R
grad_output_acc
=
grad_output
.
packed_accessor32
<
scalar_t
,
5
,
RestrictPtrTraits
>
();
grad_output
.
packed_accessor32
<
scalar_t
,
5
,
RestrictPtrTraits
>
();
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
correlation_backward_cuda_kernel_input1
<
scalar_t
>
correlation_backward_cuda_kernel_input1
<
scalar_t
>
<<<
blocks
,
threads
,
grad_cache_size
,
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_output_acc
,
input2_acc
,
grad_input1_acc
,
kH
,
kW
,
patchH
,
grad_output_acc
,
input2_acc
,
grad_input1_acc
,
kH
,
kW
,
patchH
,
patchW
,
padH
,
padW
,
dilationH
,
dilationW
,
dilation_patchH
,
patchW
,
padH
,
padW
,
dilationH
,
dilationW
,
dilation_patchH
,
dilation_patchW
,
dH
,
dW
,
n
);
dilation_patchW
,
dH
,
dW
);
}
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
correlation_backward_cuda_kernel_input2
<
scalar_t
>
correlation_backward_cuda_kernel_input2
<
scalar_t
>
<<<
blocks
,
threads
,
grad_cache_size
,
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_output_acc
,
input1_acc
,
grad_input2_acc
,
kH
,
kW
,
patchH
,
grad_output_acc
,
input1_acc
,
grad_input2_acc
,
kH
,
kW
,
patchH
,
patchW
,
padH
,
padW
,
dilationH
,
dilationW
,
dilation_patchH
,
patchW
,
padH
,
padW
,
dilationH
,
dilationW
,
dilation_patchH
,
dilation_patchW
,
dH
,
dW
,
n
);
dilation_patchW
,
dH
,
dW
);
}
}));
}));
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment