Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
ff189047
"src/vscode:/vscode.git/clone" did not exist on "010bc4ea198eeaf379926316b214c22a1dab8d17"
Unverified
Commit
ff189047
authored
Oct 24, 2022
by
q.yao
Committed by
GitHub
Oct 24, 2022
Browse files
[Fix] Fix Correlation op (#2274)
* fix correlation * fix lint
parent
7fd7058a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
15 deletions
+17
-15
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
+16
-14
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
+1
-1
No files found.
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
View file @
ff189047
...
@@ -36,7 +36,8 @@ template <typename scalar_t>
...
@@ -36,7 +36,8 @@ template <typename scalar_t>
__global__
void
correlation_forward_cuda_kernel
(
__global__
void
correlation_forward_cuda_kernel
(
const
TensorAcc4R
rInput1
,
const
TensorAcc4R
rInput2
,
TensorAcc5R
output
,
const
TensorAcc4R
rInput1
,
const
TensorAcc4R
rInput2
,
TensorAcc5R
output
,
int
kH
,
int
kW
,
int
patchH
,
int
patchW
,
int
padH
,
int
padW
,
int
dilationH
,
int
kH
,
int
kW
,
int
patchH
,
int
patchW
,
int
padH
,
int
padW
,
int
dilationH
,
int
dilationW
,
int
dilation_patchH
,
int
dilation_patchW
,
int
dH
,
int
dW
)
{
int
dilationW
,
int
dilation_patchH
,
int
dilation_patchW
,
int
dH
,
int
dW
,
int
oH
,
int
oW
)
{
const
int
iH
=
rInput1
.
size
(
1
);
const
int
iH
=
rInput1
.
size
(
1
);
const
int
iW
=
rInput1
.
size
(
2
);
const
int
iW
=
rInput1
.
size
(
2
);
const
int
C
=
rInput1
.
size
(
3
);
const
int
C
=
rInput1
.
size
(
3
);
...
@@ -44,6 +45,9 @@ __global__ void correlation_forward_cuda_kernel(
...
@@ -44,6 +45,9 @@ __global__ void correlation_forward_cuda_kernel(
const
int
n
=
blockIdx
.
x
;
const
int
n
=
blockIdx
.
x
;
const
int
h
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
int
h
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
int
w
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
const
int
w
=
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
;
if
(
h
>=
oH
||
w
>=
oW
)
return
;
const
int
thread
=
threadIdx
.
x
;
const
int
thread
=
threadIdx
.
x
;
const
int
start_i
=
-
padH
+
h
*
dH
;
const
int
start_i
=
-
padH
+
h
*
dH
;
...
@@ -60,13 +64,11 @@ __global__ void correlation_forward_cuda_kernel(
...
@@ -60,13 +64,11 @@ __global__ void correlation_forward_cuda_kernel(
for
(
int
i
=
0
;
i
<
kH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kH
;
++
i
)
{
int
i1
=
start_i
+
i
*
dilationH
;
int
i1
=
start_i
+
i
*
dilationH
;
int
i2
=
i1
+
ph_dilated
;
int
i2
=
i1
+
ph_dilated
;
if
if
(
WITHIN_BOUNDS
(
i1
,
i2
,
iH
,
iH
))
{
WITHIN_BOUNDS
(
i1
,
i2
,
iH
,
iH
)
{
for
(
int
j
=
0
;
j
<
kW
;
++
j
)
{
for
(
int
j
=
0
;
j
<
kW
;
++
j
)
{
int
j1
=
start_j
+
j
*
dilationW
;
int
j1
=
start_j
+
j
*
dilationW
;
int
j2
=
j1
+
pw_dilated
;
int
j2
=
j1
+
pw_dilated
;
if
if
(
WITHIN_BOUNDS
(
j1
,
j2
,
iW
,
iW
))
{
WITHIN_BOUNDS
(
j1
,
j2
,
iW
,
iW
)
{
for
(
int
c
=
thread
;
c
<
C
;
c
+=
WARP_SIZE
)
{
for
(
int
c
=
thread
;
c
<
C
;
c
+=
WARP_SIZE
)
{
scalar_t
v1
=
rInput1
[
n
][
i1
][
j1
][
c
];
scalar_t
v1
=
rInput1
[
n
][
i1
][
j1
][
c
];
scalar_t
v2
=
rInput2
[
n
][
i2
][
j2
][
c
];
scalar_t
v2
=
rInput2
[
n
][
i2
][
j2
][
c
];
...
...
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
View file @
ff189047
...
@@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
...
@@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
trInput1_acc
,
trInput2_acc
,
output_acc
,
kH
,
kW
,
patchH
,
patchW
,
trInput1_acc
,
trInput2_acc
,
output_acc
,
kH
,
kW
,
patchH
,
patchW
,
padH
,
padW
,
dilationH
,
dilationW
,
dilation_patchH
,
padH
,
padW
,
dilationH
,
dilationW
,
dilation_patchH
,
dilation_patchW
,
dH
,
dW
);
dilation_patchW
,
dH
,
dW
,
oH
,
oW
);
}));
}));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment