Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
f27d308f
Commit
f27d308f
authored
Jun 07, 2020
by
yinchimaoliang
Browse files
merge master
parents
c66ae813
27ebcfac
Changes
80
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1457 additions
and
1163 deletions
+1457
-1163
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
+78
-68
mmdet3d/ops/group_points/src/group_points.cpp
mmdet3d/ops/group_points/src/group_points.cpp
+34
-28
mmdet3d/ops/group_points/src/group_points_cuda.cu
mmdet3d/ops/group_points/src/group_points_cuda.cu
+77
-64
mmdet3d/ops/interpolate/src/interpolate.cpp
mmdet3d/ops/interpolate/src/interpolate.cpp
+62
-53
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
+92
-80
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
+68
-56
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
+345
-296
mmdet3d/ops/roiaware_pool3d/__init__.py
mmdet3d/ops/roiaware_pool3d/__init__.py
+6
-2
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
+26
-0
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
+76
-0
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
+5
-0
mmdet3d/ops/sparse_block.py
mmdet3d/ops/sparse_block.py
+31
-7
mmdet3d/ops/spconv/include/paramsgrid.h
mmdet3d/ops/spconv/include/paramsgrid.h
+9
-3
mmdet3d/ops/spconv/include/prettyprint.h
mmdet3d/ops/spconv/include/prettyprint.h
+442
-394
mmdet3d/ops/spconv/include/spconv/box_iou.h
mmdet3d/ops/spconv/include/spconv/box_iou.h
+13
-14
mmdet3d/ops/spconv/include/spconv/geometry.h
mmdet3d/ops/spconv/include/spconv/geometry.h
+10
-10
mmdet3d/ops/spconv/include/spconv/indice.cu.h
mmdet3d/ops/spconv/include/spconv/indice.cu.h
+14
-19
mmdet3d/ops/spconv/include/spconv/indice.h
mmdet3d/ops/spconv/include/spconv/indice.h
+49
-48
mmdet3d/ops/spconv/include/spconv/maxpool.h
mmdet3d/ops/spconv/include/spconv/maxpool.h
+9
-14
mmdet3d/ops/spconv/include/spconv/mp_helper.h
mmdet3d/ops/spconv/include/spconv/mp_helper.h
+11
-7
No files found.
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
View file @
f27d308f
...
@@ -3,82 +3,92 @@
...
@@ -3,82 +3,92 @@
#define TOTAL_THREADS 1024
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
gather_points_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
__global__
void
gather_points_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
const
float
*
__restrict__
points
,
// points: (B, C, N)
const
int
*
__restrict__
idx
,
// idx: (B, M)
float
*
__restrict__
out
)
{
// output:
// points: (B, C, N)
// out: (B, C, M)
// idx: (B, M)
// output:
int
bs_idx
=
blockIdx
.
z
;
// out: (B, C, M)
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
bs_idx
=
blockIdx
.
z
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
idx
+=
bs_idx
*
m
+
pt_idx
;
points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
out
[
0
]
=
points
[
idx
[
0
]];
idx
+=
bs_idx
*
m
+
pt_idx
;
points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
0
]
=
points
[
idx
[
0
]];
}
}
void
gather_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
void
gather_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
const
float
*
points
,
const
int
*
idx
,
// points: (B, C, N)
float
*
out
,
cudaStream_t
stream
)
{
// idx: (B, npoints)
// points: (B, C, N)
// output:
// idx: (B, npoints)
// out: (B, C, npoints)
// output:
// out: (B, C, npoints)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
cudaError_t
err
;
dim3
threads
(
THREADS_PER_BLOCK
);
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
gather_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
dim3
threads
(
THREADS_PER_BLOCK
);
err
=
cudaGetLastError
();
gather_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
if
(
cudaSuccess
!=
err
)
{
idx
,
out
);
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
err
=
cudaGetLastError
();
}
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
}
__global__
void
gather_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
__global__
void
gather_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
const
float
*
__restrict__
grad_out
,
// grad_out: (B, C, M)
const
int
*
__restrict__
idx
,
// idx: (B, M)
float
*
__restrict__
grad_points
)
{
// output:
// grad_out: (B, C, M)
// grad_points: (B, C, N)
// idx: (B, M)
// output:
int
bs_idx
=
blockIdx
.
z
;
// grad_points: (B, C, N)
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
bs_idx
=
blockIdx
.
z
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
grad_out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
idx
+=
bs_idx
*
m
+
pt_idx
;
grad_points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
grad_out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]);
grad_points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]);
}
}
void
gather_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
void
gather_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
const
float
*
grad_out
,
const
int
*
idx
,
// grad_out: (B, C, npoints)
float
*
grad_points
,
// idx: (B, npoints)
cudaStream_t
stream
)
{
// output:
// grad_out: (B, C, npoints)
// grad_points: (B, C, N)
// idx: (B, npoints)
// output:
cudaError_t
err
;
// grad_points: (B, C, N)
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
gather_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
gather_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
exit
(
-
1
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
}
mmdet3d/ops/group_points/src/group_points.cpp
View file @
f27d308f
#include <
torch/serialize/tensor
.h>
#include <
THC/THC
.h>
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern
THCState
*
state
;
extern
THCState
*
state
;
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
float
*
grad_points
=
grad_points_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
const
int
*
idx
=
idx_tensor
.
data
_ptr
<
int
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
<
float
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
_ptr
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_grad_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
,
stream
);
group_points_grad_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
return
1
;
grad_points
,
stream
);
return
1
;
}
}
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
)
{
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
float
*
points
=
points_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
const
int
*
idx
=
idx_tensor
.
data
_ptr
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
float
*
out
=
out_tensor
.
data
_ptr
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
stream
);
group_points_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
return
1
;
stream
);
return
1
;
}
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
group_points_wrapper
,
"group_points_wrapper"
);
m
.
def
(
"forward"
,
&
group_points_wrapper
,
"group_points_wrapper"
);
m
.
def
(
"backward"
,
&
group_points_grad_wrapper
,
"group_points_grad_wrapper"
);
m
.
def
(
"backward"
,
&
group_points_grad_wrapper
,
"group_points_grad_wrapper"
);
}
}
mmdet3d/ops/group_points/src/group_points_cuda.cu
View file @
f27d308f
...
@@ -2,84 +2,97 @@
...
@@ -2,84 +2,97 @@
#include <stdlib.h>
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m,
n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
group_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
__global__
void
group_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
int
nsample
,
// grad_out: (B, C, npoints, nsample)
const
float
*
__restrict__
grad_out
,
// idx: (B, npoints, nsample)
const
int
*
__restrict__
idx
,
// output:
float
*
__restrict__
grad_points
)
{
// grad_points: (B, C, N)
// grad_out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
// idx: (B, npoints, nsample)
int
c_idx
=
blockIdx
.
y
;
// output:
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// grad_points: (B, C, N)
int
pt_idx
=
index
/
nsample
;
int
bs_idx
=
blockIdx
.
z
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
int
sample_idx
=
index
%
nsample
;
grad_out
+=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
grad_out
+=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
atomicAdd
(
grad_points
+
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
]
,
grad_out
[
0
]);
atomicAdd
(
grad_points
+
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
],
grad_out
[
0
]);
}
}
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
int
nsample
,
const
float
*
grad_out
,
// grad_out: (B, C, npoints, nsample)
const
int
*
idx
,
float
*
grad_points
,
// idx: (B, npoints, nsample)
cudaStream_t
stream
)
{
// output:
// grad_out: (B, C, npoints, nsample)
// grad_points: (B, C, N)
// idx: (B, npoints, nsample)
cudaError_t
err
;
// output:
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
// grad_points: (B, C, N)
dim3
threads
(
THREADS_PER_BLOCK
);
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
group_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
exit
(
-
1
);
}
}
}
}
__global__
void
group_points_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
__global__
void
group_points_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
int
sample_idx
=
index
%
nsample
;
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
in_idx
=
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
];
int
out_idx
=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
out
[
out_idx
]
=
points
[
in_idx
];
int
in_idx
=
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
];
int
out_idx
=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
out
[
out_idx
]
=
points
[
in_idx
];
}
}
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
const
float
*
points
,
const
int
*
idx
,
// points: (B, C, N)
float
*
out
,
cudaStream_t
stream
)
{
// idx: (B, npoints, nsample)
// points: (B, C, N)
// output:
// idx: (B, npoints, nsample)
// out: (B, C, npoints, nsample)
// output:
cudaError_t
err
;
// out: (B, C, npoints, nsample)
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
cudaError_t
err
;
dim3
threads
(
THREADS_PER_BLOCK
);
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
group_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
// cudaDeviceSynchronize(); // for using printf in kernel function
points
,
idx
,
out
);
err
=
cudaGetLastError
();
// cudaDeviceSynchronize(); // for using printf in kernel function
if
(
cudaSuccess
!=
err
)
{
err
=
cudaGetLastError
();
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
if
(
cudaSuccess
!=
err
)
{
exit
(
-
1
);
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
}
exit
(
-
1
);
}
}
}
mmdet3d/ops/interpolate/src/interpolate.cpp
View file @
f27d308f
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <math.h>
#include <math.h>
#include <stdio.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern
THCState
*
state
;
extern
THCState
*
state
;
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points
_tensor
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx
_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
float
*
grad_points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
);
cudaStream_t
stream
);
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
)
{
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
const
float
*
unknown
=
unknown_tensor
.
data
<
float
>
();
at
::
Tensor
idx_tensor
)
{
const
float
*
known
=
known_tensor
.
data
<
float
>
();
const
float
*
unknown
=
unknown_tensor
.
data_ptr
<
float
>
();
float
*
dist2
=
dist2_tensor
.
data
<
float
>
();
const
float
*
known
=
known_tensor
.
data_ptr
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
dist2
=
dist2_tensor
.
data_ptr
<
float
>
();
int
*
idx
=
idx_tensor
.
data_ptr
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
}
}
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
)
{
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data_ptr
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data_ptr
<
float
>
();
const
float
*
points
=
points_tensor
.
data
<
float
>
();
float
*
out
=
out_tensor
.
data_ptr
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data_ptr
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_kernel_launcher
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
stream
);
three_interpolate_kernel_launcher
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
stream
);
}
}
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
)
{
at
::
Tensor
grad_points_tensor
)
{
const
float
*
grad_out
=
grad_out_tensor
.
data_ptr
<
float
>
();
const
float
*
grad_out
=
grad_ou
t_tensor
.
data
<
float
>
();
const
float
*
weight
=
weigh
t_tensor
.
data
_ptr
<
float
>
();
const
float
*
weight
=
weight
_tensor
.
data
<
float
>
();
float
*
grad_points
=
grad_points
_tensor
.
data
_ptr
<
float
>
();
float
*
grad_points
=
grad_points
_tensor
.
data
<
floa
t
>
();
const
int
*
idx
=
idx
_tensor
.
data
_ptr
<
in
t
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_grad_kernel_launcher
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
three_interpolate_grad_kernel_launcher
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
,
stream
);
grad_points
,
stream
);
}
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper
,
"three_nn_wrapper"
);
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper
,
"three_nn_wrapper"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper
,
"three_interpolate_wrapper"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper
,
m
.
def
(
"three_interpolate_grad_wrapper"
,
&
three_interpolate_grad_wrapper
,
"three_interpolate_grad_wrapper"
);
"three_interpolate_wrapper"
);
m
.
def
(
"three_interpolate_grad_wrapper"
,
&
three_interpolate_grad_wrapper
,
"three_interpolate_grad_wrapper"
);
}
}
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
View file @
f27d308f
...
@@ -3,91 +3,103 @@
...
@@ -3,91 +3,103 @@
#include <stdlib.h>
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_interpolate_kernel
(
int
b
,
int
c
,
int
m
,
int
n
,
__global__
void
three_interpolate_kernel
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
__restrict__
points
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
out
)
{
const
int
*
__restrict__
idx
,
// points: (B, C, M)
const
float
*
__restrict__
weight
,
// idx: (B, N, 3)
float
*
__restrict__
out
)
{
// weight: (B, N, 3)
// points: (B, C, M)
// output:
// idx: (B, N, 3)
// out: (B, C, N)
// weight: (B, N, 3)
// output:
int
bs_idx
=
blockIdx
.
z
;
// out: (B, C, N)
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
}
}
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
const
float
*
points
,
const
int
*
idx
,
// points: (B, C, M)
const
float
*
weight
,
float
*
out
,
// idx: (B, N, 3)
cudaStream_t
stream
)
{
// weight: (B, N, 3)
// points: (B, C, M)
// output:
// idx: (B, N, 3)
// out: (B, C, N)
// weight: (B, N, 3)
// output:
cudaError_t
err
;
// out: (B, C, N)
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
cudaError_t
err
;
three_interpolate_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
err
=
cudaGetLastError
();
dim3
threads
(
THREADS_PER_BLOCK
);
if
(
cudaSuccess
!=
err
)
{
three_interpolate_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
idx
,
weight
,
out
);
exit
(
-
1
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
}
__global__
void
three_interpolate_grad_kernel
(
__global__
void
three_interpolate_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
grad_points
)
{
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
// grad_out: (B, C, N)
float
*
__restrict__
grad_points
)
{
//
weigh
t: (B,
N
,
3
)
//
grad_ou
t: (B,
C
,
N
)
//
output:
//
weight: (B, N, 3)
//
grad_points: (B, C, M)
//
output:
// grad_points: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c
_idx
=
blockIdx
.
y
;
int
bs
_idx
=
blockIdx
.
z
;
int
pt
_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
c
_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
grad_out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
+
pt_idx
;
weigh
t
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
grad_ou
t
+=
bs_idx
*
c
*
n
+
c
_idx
*
n
+
pt_idx
;
grad_points
+=
bs_idx
*
c
*
m
+
c
_idx
*
m
;
weight
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
grad_points
+=
bs_idx
*
c
*
m
+
c
_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
}
}
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
const
float
*
grad_out
,
// grad_out: (B, C, N)
const
int
*
idx
,
const
float
*
weight
,
// weight: (B, N, 3)
float
*
grad_points
,
// output:
cudaStream_t
stream
)
{
// grad_points: (B, C, M)
// grad_out: (B, C, N)
// weight: (B, N, 3)
cudaError_t
err
;
// output:
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
// grad_points: (B, C, M)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
err
=
cudaGetLastError
();
b
);
// blockIdx.x(col), blockIdx.y(row)
if
(
cudaSuccess
!=
err
)
{
dim3
threads
(
THREADS_PER_BLOCK
);
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
three_interpolate_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
exit
(
-
1
);
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
}
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
View file @
f27d308f
...
@@ -3,72 +3,84 @@
...
@@ -3,72 +3,84 @@
#include <stdlib.h>
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m,
n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_nn_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
__global__
void
three_nn_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
int
bs_idx
=
blockIdx
.
y
;
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// unknown: (B, N, 3)
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int
bs_idx
=
blockIdx
.
y
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
known
+=
bs_idx
*
m
*
3
;
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
float
ux
=
unknown
[
0
];
known
+=
bs_idx
*
m
*
3
;
float
uy
=
unknown
[
1
];
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
float
uz
=
unknown
[
2
];
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
float
ux
=
unknown
[
0
];
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
float
uy
=
unknown
[
1
];
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
float
uz
=
unknown
[
2
];
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
x
=
known
[
k
*
3
+
0
];
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
float
y
=
known
[
k
*
3
+
1
];
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
float
z
=
known
[
k
*
3
+
2
];
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
float
x
=
known
[
k
*
3
+
0
];
if
(
d
<
best1
)
{
float
y
=
known
[
k
*
3
+
1
];
best3
=
best2
;
float
z
=
known
[
k
*
3
+
2
];
besti3
=
besti2
;
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
best2
=
best1
;
if
(
d
<
best1
)
{
besti2
=
besti1
;
best3
=
best2
;
besti3
=
besti2
;
best1
=
d
;
best2
=
best1
;
besti2
=
besti1
;
besti1
=
k
;
best1
=
d
;
besti1
=
k
;
}
else
if
(
d
<
best2
)
{
}
best3
=
best2
;
else
if
(
d
<
best2
)
{
besti3
=
besti2
;
best3
=
best2
;
besti3
=
besti2
;
best2
=
d
;
best2
=
d
;
besti2
=
k
;
besti2
=
k
;
}
}
else
if
(
d
<
best3
)
{
else
if
(
d
<
best3
)
{
best3
=
d
;
best3
=
d
;
besti3
=
k
;
besti3
=
k
;
}
}
}
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
}
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
}
}
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
// unknown: (B, N, 3)
cudaStream_t
stream
)
{
// known: (B, M, 3)
// unknown: (B, N, 3)
// output:
// known: (B, M, 3)
// dist2: (B, N, 3)
// output:
// idx: (B, N, 3)
// dist2: (B, N, 3)
// idx: (B, N, 3)
cudaError_t
err
;
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
dim3
threads
(
THREADS_PER_BLOCK
);
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_nn_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
three_nn_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
err
=
cudaGetLastError
();
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
exit
(
-
1
);
}
}
}
}
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
View file @
f27d308f
...
@@ -6,376 +6,425 @@
...
@@ -6,376 +6,425 @@
const
int
THREADS_PER_BLOCK_NMS
=
sizeof
(
unsigned
long
long
)
*
8
;
const
int
THREADS_PER_BLOCK_NMS
=
sizeof
(
unsigned
long
long
)
*
8
;
const
float
EPS
=
1e-8
;
const
float
EPS
=
1e-8
;
struct
Point
{
struct
Point
{
float
x
,
y
;
float
x
,
y
;
__device__
Point
()
{}
__device__
Point
()
{}
__device__
Point
(
double
_x
,
double
_y
){
__device__
Point
(
double
_x
,
double
_y
)
{
x
=
_x
,
y
=
_y
;
}
x
=
_x
,
y
=
_y
;
}
__device__
void
set
(
float
_x
,
float
_y
)
{
x
=
_x
;
__device__
void
set
(
float
_x
,
float
_y
){
y
=
_y
;
x
=
_x
;
y
=
_y
;
}
}
__device__
Point
operator
+
(
const
Point
&
b
)
const
{
__device__
Point
operator
+
(
const
Point
&
b
)
const
{
return
Point
(
x
+
b
.
x
,
y
+
b
.
y
);
return
Point
(
x
+
b
.
x
,
y
+
b
.
y
);
}
}
__device__
Point
operator
-
(
const
Point
&
b
)
const
{
__device__
Point
operator
-
(
const
Point
&
b
)
const
{
return
Point
(
x
-
b
.
x
,
y
-
b
.
y
);
return
Point
(
x
-
b
.
x
,
y
-
b
.
y
);
}
}
};
};
__device__
inline
float
cross
(
const
Point
&
a
,
const
Point
&
b
){
__device__
inline
float
cross
(
const
Point
&
a
,
const
Point
&
b
)
{
return
a
.
x
*
b
.
y
-
a
.
y
*
b
.
x
;
return
a
.
x
*
b
.
y
-
a
.
y
*
b
.
x
;
}
}
__device__
inline
float
cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
p0
){
__device__
inline
float
cross
(
const
Point
&
p1
,
const
Point
&
p2
,
return
(
p1
.
x
-
p0
.
x
)
*
(
p2
.
y
-
p0
.
y
)
-
(
p2
.
x
-
p0
.
x
)
*
(
p1
.
y
-
p0
.
y
);
const
Point
&
p0
)
{
return
(
p1
.
x
-
p0
.
x
)
*
(
p2
.
y
-
p0
.
y
)
-
(
p2
.
x
-
p0
.
x
)
*
(
p1
.
y
-
p0
.
y
);
}
}
__device__
int
check_rect_cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
q1
,
const
Point
&
q2
){
__device__
int
check_rect_cross
(
const
Point
&
p1
,
const
Point
&
p2
,
int
ret
=
min
(
p1
.
x
,
p2
.
x
)
<=
max
(
q1
.
x
,
q2
.
x
)
&&
const
Point
&
q1
,
const
Point
&
q2
)
{
min
(
q1
.
x
,
q2
.
x
)
<=
max
(
p1
.
x
,
p2
.
x
)
&&
int
ret
=
min
(
p1
.
x
,
p2
.
x
)
<=
max
(
q1
.
x
,
q2
.
x
)
&&
min
(
p1
.
y
,
p2
.
y
)
<=
max
(
q1
.
y
,
q2
.
y
)
&&
min
(
q1
.
x
,
q2
.
x
)
<=
max
(
p1
.
x
,
p2
.
x
)
&&
min
(
q1
.
y
,
q2
.
y
)
<=
max
(
p1
.
y
,
p2
.
y
);
min
(
p1
.
y
,
p2
.
y
)
<=
max
(
q1
.
y
,
q2
.
y
)
&&
return
ret
;
min
(
q1
.
y
,
q2
.
y
)
<=
max
(
p1
.
y
,
p2
.
y
);
return
ret
;
}
}
__device__
inline
int
check_in_box2d
(
const
float
*
box
,
const
Point
&
p
){
__device__
inline
int
check_in_box2d
(
const
float
*
box
,
const
Point
&
p
)
{
//params: box (5) [x1, y1, x2, y2, angle]
// params: box (5) [x1, y1, x2, y2, angle]
const
float
MARGIN
=
1e-5
;
const
float
MARGIN
=
1e-5
;
float
center_x
=
(
box
[
0
]
+
box
[
2
])
/
2
;
float
center_x
=
(
box
[
0
]
+
box
[
2
])
/
2
;
float
center_y
=
(
box
[
1
]
+
box
[
3
])
/
2
;
float
center_y
=
(
box
[
1
]
+
box
[
3
])
/
2
;
float
angle_cos
=
cos
(
-
box
[
4
]),
angle_sin
=
sin
(
-
box
[
4
]);
// rotate the point in the opposite direction of box
float
angle_cos
=
cos
(
-
box
[
4
]),
float
rot_x
=
(
p
.
x
-
center_x
)
*
angle_cos
+
(
p
.
y
-
center_y
)
*
angle_sin
+
center_x
;
angle_sin
=
float
rot_y
=
-
(
p
.
x
-
center_x
)
*
angle_sin
+
(
p
.
y
-
center_y
)
*
angle_cos
+
center_y
;
sin
(
-
box
[
4
]);
// rotate the point in the opposite direction of box
float
rot_x
=
(
p
.
x
-
center_x
)
*
angle_cos
+
(
p
.
y
-
center_y
)
*
angle_sin
+
center_x
;
float
rot_y
=
-
(
p
.
x
-
center_x
)
*
angle_sin
+
(
p
.
y
-
center_y
)
*
angle_cos
+
center_y
;
#ifdef DEBUG
#ifdef DEBUG
printf
(
"box: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
box
[
0
],
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
]);
printf
(
"box: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
box
[
0
],
box
[
1
],
box
[
2
],
printf
(
"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, %.3f)
\n
"
,
center_x
,
center_y
,
box
[
3
],
box
[
4
]);
angle_cos
,
angle_sin
,
p
.
x
,
p
.
y
,
rot_x
,
rot_y
);
printf
(
"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, "
"%.3f)
\n
"
,
center_x
,
center_y
,
angle_cos
,
angle_sin
,
p
.
x
,
p
.
y
,
rot_x
,
rot_y
);
#endif
#endif
return
(
rot_x
>
box
[
0
]
-
MARGIN
&&
rot_x
<
box
[
2
]
+
MARGIN
&&
rot_y
>
box
[
1
]
-
MARGIN
&&
rot_y
<
box
[
3
]
+
MARGIN
);
return
(
rot_x
>
box
[
0
]
-
MARGIN
&&
rot_x
<
box
[
2
]
+
MARGIN
&&
rot_y
>
box
[
1
]
-
MARGIN
&&
rot_y
<
box
[
3
]
+
MARGIN
);
}
}
__device__
inline
int
intersection
(
const
Point
&
p1
,
const
Point
&
p0
,
const
Point
&
q1
,
const
Point
&
q0
,
Point
&
ans
){
__device__
inline
int
intersection
(
const
Point
&
p1
,
const
Point
&
p0
,
// fast exclusion
const
Point
&
q1
,
const
Point
&
q0
,
if
(
check_rect_cross
(
p0
,
p1
,
q0
,
q1
)
==
0
)
return
0
;
Point
&
ans
)
{
// fast exclusion
if
(
check_rect_cross
(
p0
,
p1
,
q0
,
q1
)
==
0
)
return
0
;
// check cross standing
// check cross standing
float
s1
=
cross
(
q0
,
p1
,
p0
);
float
s1
=
cross
(
q0
,
p1
,
p0
);
float
s2
=
cross
(
p1
,
q1
,
p0
);
float
s2
=
cross
(
p1
,
q1
,
p0
);
float
s3
=
cross
(
p0
,
q1
,
q0
);
float
s3
=
cross
(
p0
,
q1
,
q0
);
float
s4
=
cross
(
q1
,
p1
,
q0
);
float
s4
=
cross
(
q1
,
p1
,
q0
);
if
(
!
(
s1
*
s2
>
0
&&
s3
*
s4
>
0
))
return
0
;
if
(
!
(
s1
*
s2
>
0
&&
s3
*
s4
>
0
))
return
0
;
// calculate intersection of two lines
// calculate intersection of two lines
float
s5
=
cross
(
q1
,
p1
,
p0
);
float
s5
=
cross
(
q1
,
p1
,
p0
);
if
(
fabs
(
s5
-
s1
)
>
EPS
){
if
(
fabs
(
s5
-
s1
)
>
EPS
)
{
ans
.
x
=
(
s5
*
q0
.
x
-
s1
*
q1
.
x
)
/
(
s5
-
s1
);
ans
.
x
=
(
s5
*
q0
.
x
-
s1
*
q1
.
x
)
/
(
s5
-
s1
);
ans
.
y
=
(
s5
*
q0
.
y
-
s1
*
q1
.
y
)
/
(
s5
-
s1
);
ans
.
y
=
(
s5
*
q0
.
y
-
s1
*
q1
.
y
)
/
(
s5
-
s1
);
}
}
else
{
else
{
float
a0
=
p0
.
y
-
p1
.
y
,
b0
=
p1
.
x
-
p0
.
x
,
c0
=
p0
.
x
*
p1
.
y
-
p1
.
x
*
p0
.
y
;
float
a0
=
p0
.
y
-
p1
.
y
,
b0
=
p1
.
x
-
p0
.
x
,
c0
=
p0
.
x
*
p1
.
y
-
p1
.
x
*
p0
.
y
;
float
a1
=
q0
.
y
-
q1
.
y
,
b1
=
q1
.
x
-
q0
.
x
,
c1
=
q0
.
x
*
q1
.
y
-
q1
.
x
*
q0
.
y
;
float
a1
=
q0
.
y
-
q1
.
y
,
b1
=
q1
.
x
-
q0
.
x
,
c1
=
q0
.
x
*
q1
.
y
-
q1
.
x
*
q0
.
y
;
float
D
=
a0
*
b1
-
a1
*
b0
;
float
D
=
a0
*
b1
-
a1
*
b0
;
ans
.
x
=
(
b0
*
c1
-
b1
*
c0
)
/
D
;
ans
.
x
=
(
b0
*
c1
-
b1
*
c0
)
/
D
;
ans
.
y
=
(
a1
*
c0
-
a0
*
c1
)
/
D
;
ans
.
y
=
(
a1
*
c0
-
a0
*
c1
)
/
D
;
}
}
return
1
;
return
1
;
}
}
__device__
inline
void
rotate_around_center
(
const
Point
&
center
,
const
float
angle_cos
,
const
float
angle_sin
,
Point
&
p
){
__device__
inline
void
rotate_around_center
(
const
Point
&
center
,
float
new_x
=
(
p
.
x
-
center
.
x
)
*
angle_cos
+
(
p
.
y
-
center
.
y
)
*
angle_sin
+
center
.
x
;
const
float
angle_cos
,
float
new_y
=
-
(
p
.
x
-
center
.
x
)
*
angle_sin
+
(
p
.
y
-
center
.
y
)
*
angle_cos
+
center
.
y
;
const
float
angle_sin
,
Point
&
p
)
{
p
.
set
(
new_x
,
new_y
);
float
new_x
=
(
p
.
x
-
center
.
x
)
*
angle_cos
+
(
p
.
y
-
center
.
y
)
*
angle_sin
+
center
.
x
;
float
new_y
=
-
(
p
.
x
-
center
.
x
)
*
angle_sin
+
(
p
.
y
-
center
.
y
)
*
angle_cos
+
center
.
y
;
p
.
set
(
new_x
,
new_y
);
}
}
__device__
inline
int
point_cmp
(
const
Point
&
a
,
const
Point
&
b
,
const
Point
&
center
){
__device__
inline
int
point_cmp
(
const
Point
&
a
,
const
Point
&
b
,
return
atan2
(
a
.
y
-
center
.
y
,
a
.
x
-
center
.
x
)
>
atan2
(
b
.
y
-
center
.
y
,
b
.
x
-
center
.
x
);
const
Point
&
center
)
{
return
atan2
(
a
.
y
-
center
.
y
,
a
.
x
-
center
.
x
)
>
atan2
(
b
.
y
-
center
.
y
,
b
.
x
-
center
.
x
);
}
}
__device__
inline
float
box_overlap
(
const
float
*
box_a
,
const
float
*
box_b
){
__device__
inline
float
box_overlap
(
const
float
*
box_a
,
const
float
*
box_b
)
{
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
a_x1
=
box_a
[
0
],
a_y1
=
box_a
[
1
],
a_x2
=
box_a
[
2
],
a_y2
=
box_a
[
3
],
a_angle
=
box_a
[
4
];
float
a_x1
=
box_a
[
0
],
a_y1
=
box_a
[
1
],
a_x2
=
box_a
[
2
],
a_y2
=
box_a
[
3
],
float
b_x1
=
box_b
[
0
],
b_y1
=
box_b
[
1
],
b_x2
=
box_b
[
2
],
b_y2
=
box_b
[
3
],
b_angle
=
box_b
[
4
];
a_angle
=
box_a
[
4
];
float
b_x1
=
box_b
[
0
],
b_y1
=
box_b
[
1
],
b_x2
=
box_b
[
2
],
b_y2
=
box_b
[
3
],
b_angle
=
box_b
[
4
];
Point
center_a
((
a_x1
+
a_x2
)
/
2
,
(
a_y1
+
a_y2
)
/
2
);
Point
center_a
((
a_x1
+
a_x2
)
/
2
,
(
a_y1
+
a_y2
)
/
2
);
Point
center_b
((
b_x1
+
b_x2
)
/
2
,
(
b_y1
+
b_y2
)
/
2
);
Point
center_b
((
b_x1
+
b_x2
)
/
2
,
(
b_y1
+
b_y2
)
/
2
);
#ifdef DEBUG
#ifdef DEBUG
printf
(
"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
a_x1
,
a_y1
,
a_x2
,
a_y2
,
a_angle
,
printf
(
b_x1
,
b_y1
,
b_x2
,
b_y2
,
b_angle
);
"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
printf
(
"center a: (%.3f, %.3f), b: (%.3f, %.3f)
\n
"
,
center_a
.
x
,
center_a
.
y
,
center_b
.
x
,
center_b
.
y
);
a_x1
,
a_y1
,
a_x2
,
a_y2
,
a_angle
,
b_x1
,
b_y1
,
b_x2
,
b_y2
,
b_angle
);
printf
(
"center a: (%.3f, %.3f), b: (%.3f, %.3f)
\n
"
,
center_a
.
x
,
center_a
.
y
,
center_b
.
x
,
center_b
.
y
);
#endif
#endif
Point
box_a_corners
[
5
];
Point
box_a_corners
[
5
];
box_a_corners
[
0
].
set
(
a_x1
,
a_y1
);
box_a_corners
[
0
].
set
(
a_x1
,
a_y1
);
box_a_corners
[
1
].
set
(
a_x2
,
a_y1
);
box_a_corners
[
1
].
set
(
a_x2
,
a_y1
);
box_a_corners
[
2
].
set
(
a_x2
,
a_y2
);
box_a_corners
[
2
].
set
(
a_x2
,
a_y2
);
box_a_corners
[
3
].
set
(
a_x1
,
a_y2
);
box_a_corners
[
3
].
set
(
a_x1
,
a_y2
);
Point
box_b_corners
[
5
];
Point
box_b_corners
[
5
];
box_b_corners
[
0
].
set
(
b_x1
,
b_y1
);
box_b_corners
[
0
].
set
(
b_x1
,
b_y1
);
box_b_corners
[
1
].
set
(
b_x2
,
b_y1
);
box_b_corners
[
1
].
set
(
b_x2
,
b_y1
);
box_b_corners
[
2
].
set
(
b_x2
,
b_y2
);
box_b_corners
[
2
].
set
(
b_x2
,
b_y2
);
box_b_corners
[
3
].
set
(
b_x1
,
b_y2
);
box_b_corners
[
3
].
set
(
b_x1
,
b_y2
);
// get oriented corners
// get oriented corners
float
a_angle_cos
=
cos
(
a_angle
),
a_angle_sin
=
sin
(
a_angle
);
float
a_angle_cos
=
cos
(
a_angle
),
a_angle_sin
=
sin
(
a_angle
);
float
b_angle_cos
=
cos
(
b_angle
),
b_angle_sin
=
sin
(
b_angle
);
float
b_angle_cos
=
cos
(
b_angle
),
b_angle_sin
=
sin
(
b_angle
);
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
#ifdef DEBUG
#ifdef DEBUG
printf
(
"before corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
printf
(
"before corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
#endif
#endif
rotate_around_center
(
center_a
,
a_angle_cos
,
a_angle_sin
,
box_a_corners
[
k
]);
rotate_around_center
(
center_a
,
a_angle_cos
,
a_angle_sin
,
box_a_corners
[
k
]);
rotate_around_center
(
center_b
,
b_angle_cos
,
b_angle_sin
,
box_b_corners
[
k
]);
rotate_around_center
(
center_b
,
b_angle_cos
,
b_angle_sin
,
box_b_corners
[
k
]);
#ifdef DEBUG
#ifdef DEBUG
printf
(
"corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
printf
(
"corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
#endif
#endif
}
box_a_corners
[
4
]
=
box_a_corners
[
0
];
box_b_corners
[
4
]
=
box_b_corners
[
0
];
// get intersection of lines
Point
cross_points
[
16
];
Point
poly_center
;
int
cnt
=
0
,
flag
=
0
;
poly_center
.
set
(
0
,
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
flag
=
intersection
(
box_a_corners
[
i
+
1
],
box_a_corners
[
i
],
box_b_corners
[
j
+
1
],
box_b_corners
[
j
],
cross_points
[
cnt
]);
if
(
flag
)
{
poly_center
=
poly_center
+
cross_points
[
cnt
];
cnt
++
;
}
}
}
}
box_a_corners
[
4
]
=
box_a_corners
[
0
];
box_b_corners
[
4
]
=
box_b_corners
[
0
];
// check corners
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
// get intersection of lines
if
(
check_in_box2d
(
box_a
,
box_b_corners
[
k
]))
{
Point
cross_points
[
16
];
poly_center
=
poly_center
+
box_b_corners
[
k
];
Point
poly_center
;
cross_points
[
cnt
]
=
box_b_corners
[
k
];
int
cnt
=
0
,
flag
=
0
;
cnt
++
;
poly_center
.
set
(
0
,
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
for
(
int
j
=
0
;
j
<
4
;
j
++
){
flag
=
intersection
(
box_a_corners
[
i
+
1
],
box_a_corners
[
i
],
box_b_corners
[
j
+
1
],
box_b_corners
[
j
],
cross_points
[
cnt
]);
if
(
flag
){
poly_center
=
poly_center
+
cross_points
[
cnt
];
cnt
++
;
}
}
}
}
if
(
check_in_box2d
(
box_b
,
box_a_corners
[
k
]))
{
// check corners
poly_center
=
poly_center
+
box_a_corners
[
k
];
for
(
int
k
=
0
;
k
<
4
;
k
++
){
cross_points
[
cnt
]
=
box_a_corners
[
k
];
if
(
check_in_box2d
(
box_a
,
box_b_corners
[
k
])){
cnt
++
;
poly_center
=
poly_center
+
box_b_corners
[
k
];
cross_points
[
cnt
]
=
box_b_corners
[
k
];
cnt
++
;
}
if
(
check_in_box2d
(
box_b
,
box_a_corners
[
k
])){
poly_center
=
poly_center
+
box_a_corners
[
k
];
cross_points
[
cnt
]
=
box_a_corners
[
k
];
cnt
++
;
}
}
}
}
poly_center
.
x
/=
cnt
;
poly_center
.
y
/=
cnt
;
poly_center
.
x
/=
cnt
;
poly_center
.
y
/=
cnt
;
// sort the points of polygon
Point
temp
;
// sort the points of polygon
for
(
int
j
=
0
;
j
<
cnt
-
1
;
j
++
){
Point
temp
;
for
(
int
i
=
0
;
i
<
cnt
-
j
-
1
;
i
++
){
for
(
int
j
=
0
;
j
<
cnt
-
1
;
j
++
)
{
if
(
point_cmp
(
cross_points
[
i
],
cross_points
[
i
+
1
],
poly_center
))
{
for
(
int
i
=
0
;
i
<
cnt
-
j
-
1
;
i
++
)
{
temp
=
cross_points
[
i
];
if
(
point_cmp
(
cross_points
[
i
],
cross_points
[
i
+
1
],
poly_center
))
{
cross_points
[
i
]
=
cross_points
[
i
+
1
];
temp
=
cross_points
[
i
];
cross_points
[
i
+
1
]
=
temp
;
cross_points
[
i
]
=
cross_points
[
i
+
1
];
}
cross_points
[
i
+
1
]
=
temp
;
}
}
}
}
}
#ifdef DEBUG
#ifdef DEBUG
printf
(
"cnt=%d
\n
"
,
cnt
);
printf
(
"cnt=%d
\n
"
,
cnt
);
for
(
int
i
=
0
;
i
<
cnt
;
i
++
){
for
(
int
i
=
0
;
i
<
cnt
;
i
++
)
{
printf
(
"All cross point %d: (%.3f, %.3f)
\n
"
,
i
,
cross_points
[
i
].
x
,
cross_points
[
i
].
y
);
printf
(
"All cross point %d: (%.3f, %.3f)
\n
"
,
i
,
cross_points
[
i
].
x
,
}
cross_points
[
i
].
y
);
}
#endif
#endif
// get the overlap areas
// get the overlap areas
float
area
=
0
;
float
area
=
0
;
for
(
int
k
=
0
;
k
<
cnt
-
1
;
k
++
){
for
(
int
k
=
0
;
k
<
cnt
-
1
;
k
++
)
{
area
+=
cross
(
cross_points
[
k
]
-
cross_points
[
0
],
cross_points
[
k
+
1
]
-
cross_points
[
0
]);
area
+=
cross
(
cross_points
[
k
]
-
cross_points
[
0
],
}
cross_points
[
k
+
1
]
-
cross_points
[
0
]);
}
return
fabs
(
area
)
/
2.0
;
return
fabs
(
area
)
/
2.0
;
}
}
__device__
inline
float
iou_bev
(
const
float
*
box_a
,
const
float
*
box_b
){
__device__
inline
float
iou_bev
(
const
float
*
box_a
,
const
float
*
box_b
)
{
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
sa
=
(
box_a
[
2
]
-
box_a
[
0
])
*
(
box_a
[
3
]
-
box_a
[
1
]);
float
sa
=
(
box_a
[
2
]
-
box_a
[
0
])
*
(
box_a
[
3
]
-
box_a
[
1
]);
float
sb
=
(
box_b
[
2
]
-
box_b
[
0
])
*
(
box_b
[
3
]
-
box_b
[
1
]);
float
sb
=
(
box_b
[
2
]
-
box_b
[
0
])
*
(
box_b
[
3
]
-
box_b
[
1
]);
float
s_overlap
=
box_overlap
(
box_a
,
box_b
);
float
s_overlap
=
box_overlap
(
box_a
,
box_b
);
return
s_overlap
/
fmaxf
(
sa
+
sb
-
s_overlap
,
EPS
);
return
s_overlap
/
fmaxf
(
sa
+
sb
-
s_overlap
,
EPS
);
}
}
__global__
void
boxes_overlap_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
__global__
void
boxes_overlap_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
num_b
,
const
float
*
boxes_b
,
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
float
*
ans_overlap
)
{
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
){
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
return
;
}
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
)
{
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
return
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
}
float
s_overlap
=
box_overlap
(
cur_box_a
,
cur_box_b
);
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
ans_overlap
[
a_idx
*
num_b
+
b_idx
]
=
s_overlap
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
s_overlap
=
box_overlap
(
cur_box_a
,
cur_box_b
);
ans_overlap
[
a_idx
*
num_b
+
b_idx
]
=
s_overlap
;
}
}
__global__
void
boxes_iou_bev_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
){
__global__
void
boxes_iou_bev_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
num_b
,
const
float
*
boxes_b
,
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
float
*
ans_iou
)
{
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
){
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
)
{
return
;
return
;
}
}
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
cur_iou_bev
=
iou_bev
(
cur_box_a
,
cur_box_b
);
float
cur_iou_bev
=
iou_bev
(
cur_box_a
,
cur_box_b
);
ans_iou
[
a_idx
*
num_b
+
b_idx
]
=
cur_iou_bev
;
ans_iou
[
a_idx
*
num_b
+
b_idx
]
=
cur_iou_bev
;
}
}
__global__
void
nms_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
__global__
void
nms_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
){
const
float
*
boxes
,
unsigned
long
long
*
mask
)
{
//params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
//params: mask (N, N/THREADS_PER_BLOCK_NMS)
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
THREADS_PER_BLOCK_NMS
);
if
(
threadIdx
.
x
<
col_size
)
{
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
}
__syncthreads
();
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_bev
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
)
{
if
(
threadIdx
.
x
<
row_size
)
{
t
|=
1ULL
<<
i
;
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
}
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_bev
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
){
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
}
__device__
inline
float
iou_normal
(
float
const
*
const
a
,
float
const
*
const
b
)
{
__device__
inline
float
iou_normal
(
float
const
*
const
a
,
float
const
*
const
b
)
{
float
left
=
fmaxf
(
a
[
0
],
b
[
0
]),
right
=
fminf
(
a
[
2
],
b
[
2
]);
float
left
=
fmaxf
(
a
[
0
],
b
[
0
]),
right
=
fminf
(
a
[
2
],
b
[
2
]);
float
top
=
fmaxf
(
a
[
1
],
b
[
1
]),
bottom
=
fminf
(
a
[
3
],
b
[
3
]);
float
top
=
fmaxf
(
a
[
1
],
b
[
1
]),
bottom
=
fminf
(
a
[
3
],
b
[
3
]);
float
width
=
fmaxf
(
right
-
left
,
0.
f
),
height
=
fmaxf
(
bottom
-
top
,
0.
f
);
float
width
=
fmaxf
(
right
-
left
,
0.
f
),
height
=
fmaxf
(
bottom
-
top
,
0.
f
);
float
interS
=
width
*
height
;
float
interS
=
width
*
height
;
float
Sa
=
(
a
[
2
]
-
a
[
0
])
*
(
a
[
3
]
-
a
[
1
]);
float
Sa
=
(
a
[
2
]
-
a
[
0
])
*
(
a
[
3
]
-
a
[
1
]);
float
Sb
=
(
b
[
2
]
-
b
[
0
])
*
(
b
[
3
]
-
b
[
1
]);
float
Sb
=
(
b
[
2
]
-
b
[
0
])
*
(
b
[
3
]
-
b
[
1
]);
return
interS
/
fmaxf
(
Sa
+
Sb
-
interS
,
EPS
);
return
interS
/
fmaxf
(
Sa
+
Sb
-
interS
,
EPS
);
}
}
__global__
void
nms_normal_kernel
(
const
int
boxes_num
,
__global__
void
nms_normal_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
){
const
float
*
boxes
,
//params: boxes (N, 5) [x1, y1, x2, y2, ry]
unsigned
long
long
*
mask
)
{
//params: mask (N, N/THREADS_PER_BLOCK_NMS)
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
}
__syncthreads
();
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_normal
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
)
{
if
(
threadIdx
.
x
<
row_size
)
{
t
|=
1ULL
<<
i
;
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
}
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_normal
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
){
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
}
void
boxesoverlapLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
)
{
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
boxes_overlap_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
void
boxesoverlapLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
boxes_overlap_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
#ifdef DEBUG
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
#endif
}
}
void
boxesioubevLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
){
void
boxesioubevLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
)
{
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
blocks
(
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
boxes_iou_bev_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
boxes_iou_bev_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
}
}
void
nmsLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
void
nmsLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
){
float
nms_overlap_thresh
)
{
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
nms_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
}
}
void
nmsNormalLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
void
nmsNormalLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
){
int
boxes_num
,
float
nms_overlap_thresh
)
{
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_normal_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
nms_normal_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
}
}
mmdet3d/ops/roiaware_pool3d/__init__.py
View file @
f27d308f
from
.points_in_boxes
import
points_in_boxes_cpu
,
points_in_boxes_gpu
from
.points_in_boxes
import
(
points_in_boxes_batch
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
from
.roiaware_pool3d
import
RoIAwarePool3d
from
.roiaware_pool3d
import
RoIAwarePool3d
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
]
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
,
'points_in_boxes_batch'
]
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
View file @
f27d308f
...
@@ -53,3 +53,29 @@ def points_in_boxes_cpu(points, boxes):
...
@@ -53,3 +53,29 @@ def points_in_boxes_cpu(points, boxes):
point_indices
)
point_indices
)
return
point_indices
return
point_indices
def
points_in_boxes_batch
(
points
,
boxes
):
"""Find points that are in boxes (CUDA)
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0
"""
assert
boxes
.
shape
[
0
]
==
points
.
shape
[
0
]
assert
boxes
.
shape
[
2
]
==
7
batch_size
,
num_points
,
_
=
points
.
shape
num_boxes
=
boxes
.
shape
[
1
]
box_idxs_of_pts
=
points
.
new_zeros
((
batch_size
,
num_points
,
num_boxes
),
dtype
=
torch
.
int
).
fill_
(
0
)
roiaware_pool3d_ext
.
points_in_boxes_batch
(
boxes
.
contiguous
(),
points
.
contiguous
(),
box_idxs_of_pts
)
return
box_idxs_of_pts
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
View file @
f27d308f
...
@@ -77,6 +77,34 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num,
...
@@ -77,6 +77,34 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num,
}
}
}
}
__global__
void
points_in_boxes_batch_kernel
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
batch_size
||
pt_idx
>=
pts_num
)
return
;
boxes
+=
bs_idx
*
boxes_num
*
7
;
pts
+=
bs_idx
*
pts_num
*
3
+
pt_idx
*
3
;
box_idx_of_points
+=
bs_idx
*
pts_num
*
boxes_num
+
pt_idx
*
boxes_num
;
float
local_x
=
0
,
local_y
=
0
;
int
cur_in_flag
=
0
;
for
(
int
k
=
0
;
k
<
boxes_num
;
k
++
)
{
cur_in_flag
=
check_pt_in_box3d
(
pts
,
boxes
+
k
*
7
,
local_x
,
local_y
);
if
(
cur_in_flag
)
{
box_idx_of_points
[
k
]
=
1
;
}
cur_in_flag
=
0
;
}
}
void
points_in_boxes_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
void
points_in_boxes_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
int
*
box_idx_of_points
)
{
...
@@ -102,6 +130,30 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
...
@@ -102,6 +130,30 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
#endif
#endif
}
}
void
points_in_boxes_batch_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in
// LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
pts_num
,
THREADS_PER_BLOCK
),
batch_size
);
dim3
threads
(
THREADS_PER_BLOCK
);
points_in_boxes_batch_kernel
<<<
blocks
,
threads
>>>
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
...
@@ -126,3 +178,27 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
...
@@ -126,3 +178,27 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
return
1
;
return
1
;
}
}
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center. params pts: (B, npoints, 3) [x, y, z] in LiDAR
// coordinate params boxes_idx_of_points: (B, npoints), default -1
CHECK_INPUT
(
boxes_tensor
);
CHECK_INPUT
(
pts_tensor
);
CHECK_INPUT
(
box_idx_of_points_tensor
);
int
batch_size
=
boxes_tensor
.
size
(
0
);
int
boxes_num
=
boxes_tensor
.
size
(
1
);
int
pts_num
=
pts_tensor
.
size
(
1
);
const
float
*
boxes
=
boxes_tensor
.
data_ptr
<
float
>
();
const
float
*
pts
=
pts_tensor
.
data_ptr
<
float
>
();
int
*
box_idx_of_points
=
box_idx_of_points_tensor
.
data_ptr
<
int
>
();
points_in_boxes_batch_launcher
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
return
1
;
}
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
View file @
f27d308f
...
@@ -44,6 +44,9 @@ int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
...
@@ -44,6 +44,9 @@ int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
at
::
Tensor
box_idx_of_points_tensor
);
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
pooled_features
,
int
pool_method
)
{
at
::
Tensor
pooled_features
,
int
pool_method
)
{
...
@@ -127,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -127,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"roiaware pool3d backward (CUDA)"
);
"roiaware pool3d backward (CUDA)"
);
m
.
def
(
"points_in_boxes_gpu"
,
&
points_in_boxes_gpu
,
m
.
def
(
"points_in_boxes_gpu"
,
&
points_in_boxes_gpu
,
"points_in_boxes_gpu forward (CUDA)"
);
"points_in_boxes_gpu forward (CUDA)"
);
m
.
def
(
"points_in_boxes_batch"
,
&
points_in_boxes_batch
,
"points_in_boxes_batch forward (CUDA)"
);
m
.
def
(
"points_in_boxes_cpu"
,
&
points_in_boxes_cpu
,
m
.
def
(
"points_in_boxes_cpu"
,
&
points_in_boxes_cpu
,
"points_in_boxes_cpu forward (CPU)"
);
"points_in_boxes_cpu forward (CPU)"
);
}
}
mmdet3d/ops/sparse_block.py
View file @
f27d308f
...
@@ -6,6 +6,21 @@ from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
...
@@ -6,6 +6,21 @@ from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
class
SparseBottleneck
(
Bottleneck
,
spconv
.
SparseModule
):
class
SparseBottleneck
(
Bottleneck
,
spconv
.
SparseModule
):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
stride (int): stride of the first block. Default: 1
downsample (None | Module): down sample module for block.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
expansion
=
4
expansion
=
4
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -15,10 +30,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
...
@@ -15,10 +30,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
downsample
=
None
,
downsample
=
None
,
conv_cfg
=
None
,
conv_cfg
=
None
,
norm_cfg
=
None
):
norm_cfg
=
None
):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
"""
spconv
.
SparseModule
.
__init__
(
self
)
spconv
.
SparseModule
.
__init__
(
self
)
Bottleneck
.
__init__
(
Bottleneck
.
__init__
(
self
,
self
,
...
@@ -53,6 +65,21 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
...
@@ -53,6 +65,21 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
class
SparseBasicBlock
(
BasicBlock
,
spconv
.
SparseModule
):
class
SparseBasicBlock
(
BasicBlock
,
spconv
.
SparseModule
):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
stride (int): stride of the first block. Default: 1
downsample (None | Module): down sample module for block.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
expansion
=
1
expansion
=
1
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -62,10 +89,6 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
...
@@ -62,10 +89,6 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
downsample
=
None
,
downsample
=
None
,
conv_cfg
=
None
,
conv_cfg
=
None
,
norm_cfg
=
None
):
norm_cfg
=
None
):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
"""
spconv
.
SparseModule
.
__init__
(
self
)
spconv
.
SparseModule
.
__init__
(
self
)
BasicBlock
.
__init__
(
BasicBlock
.
__init__
(
self
,
self
,
...
@@ -125,6 +148,7 @@ def make_sparse_convmodule(in_channels,
...
@@ -125,6 +148,7 @@ def make_sparse_convmodule(in_channels,
spconv.SparseSequential: sparse convolution module.
spconv.SparseSequential: sparse convolution module.
"""
"""
assert
isinstance
(
order
,
tuple
)
and
len
(
order
)
<=
3
assert
isinstance
(
order
,
tuple
)
and
len
(
order
)
<=
3
assert
set
(
order
)
|
{
'conv'
,
'norm'
,
'act'
}
==
{
'conv'
,
'norm'
,
'act'
}
conv_cfg
=
dict
(
type
=
conv_type
,
indice_key
=
indice_key
)
conv_cfg
=
dict
(
type
=
conv_type
,
indice_key
=
indice_key
)
...
...
mmdet3d/ops/spconv/include/paramsgrid.h
View file @
f27d308f
...
@@ -18,13 +18,19 @@
...
@@ -18,13 +18,19 @@
#include <vector>
#include <vector>
namespace
detail
{
namespace
detail
{
template
<
class
T
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
class
T
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
class
T
,
class
...
TArgs
>
template
<
class
T
,
class
...
TArgs
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
,
std
::
vector
<
TArgs
>
...
args
)
{
int
getTotalSize
(
std
::
vector
<
T
>
arg
,
std
::
vector
<
TArgs
>
...
args
)
{
return
arg
.
size
()
*
getTotalSize
(
args
...);
return
arg
.
size
()
*
getTotalSize
(
args
...);
}
}
template
<
typename
T
>
int
getSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
typename
T
>
int
getSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
int
Idx
,
class
TT
,
class
T
>
template
<
int
Idx
,
class
TT
,
class
T
>
void
assigner
(
TT
&
src
,
std
::
vector
<
int
>
counter
,
std
::
vector
<
T
>
&
arg
)
{
void
assigner
(
TT
&
src
,
std
::
vector
<
int
>
counter
,
std
::
vector
<
T
>
&
arg
)
{
...
@@ -37,7 +43,7 @@ void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg,
...
@@ -37,7 +43,7 @@ void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg,
std
::
get
<
Idx
>
(
src
)
=
arg
[
counter
[
Idx
]];
std
::
get
<
Idx
>
(
src
)
=
arg
[
counter
[
Idx
]];
assigner
<
Idx
+
1
>
(
src
,
counter
,
args
...);
assigner
<
Idx
+
1
>
(
src
,
counter
,
args
...);
}
}
}
// namespace detail
}
// namespace detail
template
<
class
...
TArgs
>
template
<
class
...
TArgs
>
std
::
vector
<
std
::
tuple
<
TArgs
...
>>
paramsGrid
(
std
::
vector
<
TArgs
>
...
args
)
{
std
::
vector
<
std
::
tuple
<
TArgs
...
>>
paramsGrid
(
std
::
vector
<
TArgs
>
...
args
)
{
int
length
=
detail
::
getTotalSize
(
args
...);
int
length
=
detail
::
getTotalSize
(
args
...);
...
...
mmdet3d/ops/spconv/include/prettyprint.h
View file @
f27d308f
...
@@ -22,424 +22,472 @@
...
@@ -22,424 +22,472 @@
#include <utility>
#include <utility>
#include <valarray>
#include <valarray>
namespace
pretty_print
namespace
pretty_print
{
{
namespace
detail
{
namespace
detail
// SFINAE type trait to detect whether T::const_iterator exists.
{
// SFINAE type trait to detect whether T::const_iterator exists.
struct
sfinae_base
{
using
yes
=
char
;
struct
sfinae_base
using
no
=
yes
[
2
];
{
};
using
yes
=
char
;
using
no
=
yes
[
2
];
template
<
typename
T
>
};
struct
has_const_iterator
:
private
sfinae_base
{
private:
template
<
typename
T
>
template
<
typename
C
>
struct
has_const_iterator
:
private
sfinae_base
static
yes
&
test
(
typename
C
::
const_iterator
*
);
{
template
<
typename
C
>
private:
static
no
&
test
(...);
template
<
typename
C
>
static
yes
&
test
(
typename
C
::
const_iterator
*
);
template
<
typename
C
>
static
no
&
test
(...);
public:
public:
static
const
bool
value
=
sizeof
(
test
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
const
bool
value
=
sizeof
(
test
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
using
type
=
T
;
using
type
=
T
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
has_begin_end
:
private
sfinae_base
{
struct
has_begin_end
:
private
sfinae_base
private:
{
template
<
typename
C
>
private:
static
yes
&
template
<
typename
C
>
f
(
typename
std
::
enable_if
<
static
yes
&
f
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
begin
)),
const
>
(
&
C
::
begin
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
>::
type
*
);
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
>::
type
*
);
template
<
typename
C
>
static
no
&
f
(...);
template
<
typename
C
>
static
no
&
f
(...);
template
<
typename
C
>
static
yes
&
g
(
typename
std
::
enable_if
<
template
<
typename
C
>
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
end
)),
static
yes
&
g
(
typename
std
::
enable_if
<
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
,
void
>::
type
*
);
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
end
)),
template
<
typename
C
>
static
no
&
g
(...);
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
,
void
>::
type
*
);
public:
static
bool
const
beg_value
=
sizeof
(
f
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
template
<
typename
C
>
static
bool
const
end_value
=
sizeof
(
g
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
no
&
g
(...);
};
public:
}
// namespace detail
static
bool
const
beg_value
=
sizeof
(
f
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
bool
const
end_value
=
sizeof
(
g
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
};
// Holds the delimiter values for a specific character type
}
// namespace detail
template
<
typename
TChar
>
struct
delimiters_values
// Holds the delimiter values for a specific character type
{
using
char_type
=
TChar
;
template
<
typename
TChar
>
const
char_type
*
prefix
;
struct
delimiters_values
{
const
char_type
*
delimiter
;
using
char_type
=
TChar
;
const
char_type
*
postfix
;
const
char_type
*
prefix
;
};
const
char_type
*
delimiter
;
const
char_type
*
postfix
;
};
// Defines the delimiter values for a specific container and character type
// Defines the delimiter values for a specific container and character type
template
<
typename
T
,
typename
TChar
>
struct
delimiters
template
<
typename
T
,
typename
TChar
>
{
struct
delimiters
{
using
type
=
delimiters_values
<
TChar
>
;
using
type
=
delimiters_values
<
TChar
>
;
static
const
type
values
;
static
const
type
values
;
};
};
// Functor to print containers. You can use this directly if you want
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
// be customized by specializing the nested template.
template
<
typename
T
,
typename
TChar
=
char
,
template
<
typename
T
,
typename
TCharTraits
=
::
std
::
char_traits
<
TChar
>,
typename
TChar
=
char
,
typename
TDelimiters
=
delimiters
<
T
,
TChar
>>
typename
TCharTraits
=
::
std
::
char_traits
<
TChar
>,
struct
print_container_helper
{
typename
TDelimiters
=
delimiters
<
T
,
TChar
>>
using
delimiters_type
=
TDelimiters
;
struct
print_container_helper
using
ostream_type
=
std
::
basic_ostream
<
TChar
,
TCharTraits
>
;
{
using
delimiters_type
=
TDelimiters
;
template
<
typename
U
>
using
ostream_type
=
std
::
basic_ostream
<
TChar
,
TCharTraits
>
;
struct
printer
{
static
void
print_body
(
const
U
&
c
,
ostream_type
&
stream
)
{
template
<
typename
U
>
using
std
::
begin
;
struct
printer
using
std
::
end
;
{
static
void
print_body
(
const
U
&
c
,
ostream_type
&
stream
)
auto
it
=
begin
(
c
);
{
const
auto
the_end
=
end
(
c
);
using
std
::
begin
;
using
std
::
end
;
if
(
it
!=
the_end
)
{
for
(;;)
{
auto
it
=
begin
(
c
);
stream
<<
*
it
;
const
auto
the_end
=
end
(
c
);
if
(
++
it
==
the_end
)
break
;
if
(
it
!=
the_end
)
{
if
(
delimiters_type
::
values
.
delimiter
!=
NULL
)
for
(
;
;
)
stream
<<
delimiters_type
::
values
.
delimiter
;
{
stream
<<
*
it
;
if
(
++
it
==
the_end
)
break
;
if
(
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
delimiters_type
::
values
.
delimiter
;
}
}
}
};
print_container_helper
(
const
T
&
container
)
:
container_
(
container
)
{
}
inline
void
operator
()(
ostream_type
&
stream
)
const
{
if
(
delimiters_type
::
values
.
prefix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
prefix
;
printer
<
T
>::
print_body
(
container_
,
stream
);
if
(
delimiters_type
::
values
.
postfix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
postfix
;
}
}
}
private:
const
T
&
container_
;
};
// Specialization for pairs
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
T1
,
typename
T2
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
pair
<
T1
,
T2
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
static
void
print_body
(
const
std
::
pair
<
T1
,
T2
>
&
c
,
ostream_type
&
stream
)
{
stream
<<
c
.
first
;
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
c
.
second
;
}
};
// Specialization for tuples
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
...
Args
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
tuple
<
Args
...
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
using
element_type
=
std
::
tuple
<
Args
...
>
;
template
<
std
::
size_t
I
>
struct
Int
{
};
static
void
print_body
(
const
element_type
&
c
,
ostream_type
&
stream
)
{
tuple_print
(
c
,
stream
,
Int
<
0
>
());
}
static
void
tuple_print
(
const
element_type
&
,
ostream_type
&
,
Int
<
sizeof
...(
Args
)
>
)
{
}
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
typename
std
::
conditional
<
sizeof
...(
Args
)
!=
0
,
Int
<
0
>
,
std
::
nullptr_t
>::
type
)
{
stream
<<
std
::
get
<
0
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
1
>
());
}
template
<
std
::
size_t
N
>
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
Int
<
N
>
)
{
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
std
::
get
<
N
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
N
+
1
>
());
}
};
// Prints a print_container_helper to the specified stream.
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>
&
helper
)
{
helper
(
stream
);
return
stream
;
}
// Basic is_container template; specialize to derive from std::true_type for all desired container types
template
<
typename
T
>
struct
is_container
:
public
std
::
integral_constant
<
bool
,
detail
::
has_const_iterator
<
T
>::
value
&&
detail
::
has_begin_end
<
T
>::
beg_value
&&
detail
::
has_begin_end
<
T
>::
end_value
>
{
};
template
<
typename
T
,
std
::
size_t
N
>
struct
is_container
<
T
[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
is_container
<
char
[
N
]
>
:
std
::
false_type
{
};
template
<
typename
T
>
struct
is_container
<
std
::
valarray
<
T
>>
:
std
::
true_type
{
};
template
<
typename
T1
,
typename
T2
>
struct
is_container
<
std
::
pair
<
T1
,
T2
>>
:
std
::
true_type
{
};
template
<
typename
...
Args
>
struct
is_container
<
std
::
tuple
<
Args
...
>>
:
std
::
true_type
{
};
// Default delimiters
template
<
typename
T
>
struct
delimiters
<
T
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
char
>
delimiters
<
T
,
char
>::
values
=
{
"["
,
", "
,
"]"
};
template
<
typename
T
>
struct
delimiters
<
T
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
wchar_t
>
delimiters
<
T
,
wchar_t
>::
values
=
{
L"["
,
L", "
,
L"]"
};
// Delimiters for (multi)set and unordered_(multi)set
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
// Delimiters for pair and tuple
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
char
>
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
template
<
typename
...
Args
>
struct
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
char
>
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
...
Args
>
struct
delimiters
<
::
std
::
tuple
<
Args
...
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
tuple
<
Args
...
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t, and MyDelims needs to be defined for TChar.
// Usage: "cout << pretty_print::custom_delims<MyDelims>(x)".
struct
custom_delims_base
{
virtual
~
custom_delims_base
()
{
}
virtual
std
::
ostream
&
stream
(
::
std
::
ostream
&
)
=
0
;
virtual
std
::
wostream
&
stream
(
::
std
::
wostream
&
)
=
0
;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{
}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
template
<
typename
Delims
>
struct
custom_delims
{
template
<
typename
Container
>
custom_delims
(
const
Container
&
c
)
:
base
(
new
custom_delims_wrapper
<
Container
,
Delims
>
(
c
))
{
}
std
::
unique_ptr
<
custom_delims_base
>
base
;
};
template
<
typename
TChar
,
typename
TCharTraits
,
typename
Delims
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
s
,
const
custom_delims
<
Delims
>
&
p
)
{
return
p
.
base
->
stream
(
s
);
}
}
};
print_container_helper
(
const
T
&
container
)
:
container_
(
container
)
{}
inline
void
operator
()(
ostream_type
&
stream
)
const
{
if
(
delimiters_type
::
values
.
prefix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
prefix
;
printer
<
T
>::
print_body
(
container_
,
stream
);
if
(
delimiters_type
::
values
.
postfix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
postfix
;
}
private:
const
T
&
container_
;
};
// Specialization for pairs
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
T1
,
typename
T2
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
pair
<
T1
,
T2
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
static
void
print_body
(
const
std
::
pair
<
T1
,
T2
>
&
c
,
ostream_type
&
stream
)
{
stream
<<
c
.
first
;
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
c
.
second
;
}
};
// Specialization for tuples
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
...
Args
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
tuple
<
Args
...
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
using
element_type
=
std
::
tuple
<
Args
...
>
;
template
<
std
::
size_t
I
>
struct
Int
{};
static
void
print_body
(
const
element_type
&
c
,
ostream_type
&
stream
)
{
tuple_print
(
c
,
stream
,
Int
<
0
>
());
}
static
void
tuple_print
(
const
element_type
&
,
ostream_type
&
,
Int
<
sizeof
...(
Args
)
>
)
{}
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
typename
std
::
conditional
<
sizeof
...(
Args
)
!=
0
,
Int
<
0
>
,
std
::
nullptr_t
>::
type
)
{
stream
<<
std
::
get
<
0
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
1
>
());
}
template
<
std
::
size_t
N
>
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
Int
<
N
>
)
{
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
std
::
get
<
N
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
N
+
1
>
());
}
};
// Prints a print_container_helper to the specified stream.
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>
&
helper
)
{
helper
(
stream
);
return
stream
;
}
// Basic is_container template; specialize to derive from std::true_type for all
// desired container types
template
<
typename
T
>
struct
is_container
:
public
std
::
integral_constant
<
bool
,
detail
::
has_const_iterator
<
T
>::
value
&&
detail
::
has_begin_end
<
T
>::
beg_value
&&
detail
::
has_begin_end
<
T
>::
end_value
>
{};
template
<
typename
T
,
std
::
size_t
N
>
struct
is_container
<
T
[
N
]
>
:
std
::
true_type
{};
template
<
std
::
size_t
N
>
struct
is_container
<
char
[
N
]
>
:
std
::
false_type
{};
template
<
typename
T
>
struct
is_container
<
std
::
valarray
<
T
>>
:
std
::
true_type
{};
template
<
typename
T1
,
typename
T2
>
struct
is_container
<
std
::
pair
<
T1
,
T2
>>
:
std
::
true_type
{};
template
<
typename
...
Args
>
struct
is_container
<
std
::
tuple
<
Args
...
>>
:
std
::
true_type
{};
// Default delimiters
template
<
typename
T
>
struct
delimiters
<
T
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
char
>
delimiters
<
T
,
char
>::
values
=
{
"["
,
", "
,
"]"
};
template
<
typename
T
>
struct
delimiters
<
T
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
wchar_t
>
delimiters
<
T
,
wchar_t
>::
values
=
{
L"["
,
L", "
,
L"]"
};
// Delimiters for (multi)set and unordered_(multi)set
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
// Delimiters for pair and tuple
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
char
>
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
template
<
typename
...
Args
>
struct
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
char
>
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
...
Args
>
struct
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t,
// and MyDelims needs to be defined for TChar. Usage: "cout <<
// pretty_print::custom_delims<MyDelims>(x)".
struct
custom_delims_base
{
virtual
~
custom_delims_base
()
{}
virtual
std
::
ostream
&
stream
(
::
std
::
ostream
&
)
=
0
;
virtual
std
::
wostream
&
stream
(
::
std
::
wostream
&
)
=
0
;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
template
<
typename
Delims
>
struct
custom_delims
{
template
<
typename
Container
>
custom_delims
(
const
Container
&
c
)
:
base
(
new
custom_delims_wrapper
<
Container
,
Delims
>
(
c
))
{}
std
::
unique_ptr
<
custom_delims_base
>
base
;
};
template
<
typename
TChar
,
typename
TCharTraits
,
typename
Delims
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
s
,
const
custom_delims
<
Delims
>
&
p
)
{
return
p
.
base
->
stream
(
s
);
}
// A wrapper for a C-style array given as pointer-plus-size.
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template
<
typename
T
>
struct
array_wrapper_n
{
typedef
const
T
*
const_iterator
;
typedef
T
value_type
;
array_wrapper_n
(
const
T
*
const
a
,
size_t
n
)
:
_array
(
a
),
_n
(
n
)
{
}
template
<
typename
T
>
inline
const_iterator
begin
()
const
{
return
_array
;
}
struct
array_wrapper_n
{
inline
const_iterator
end
()
const
{
return
_array
+
_n
;
}
typedef
const
T
*
const_iterator
;
typedef
T
value_type
;
private:
array_wrapper_n
(
const
T
*
const
a
,
size_t
n
)
:
_array
(
a
),
_n
(
n
)
{}
const
T
*
const
_array
;
inline
const_iterator
begin
()
const
{
return
_array
;
}
size_t
_n
;
inline
const_iterator
end
()
const
{
return
_array
+
_n
;
}
};
private:
const
T
*
const
_array
;
size_t
_n
;
};
// A wrapper for hash-table based containers that offer local iterators to each bucket.
// A wrapper for hash-table based containers that offer local iterators to each
// Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket 5 of container m.)
// bucket. Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket
// 5 of container m.)
template
<
typename
T
>
template
<
typename
T
>
struct
bucket_print_wrapper
struct
bucket_print_wrapper
{
{
typedef
typename
T
::
const_local_iterator
const_iterator
;
typedef
typename
T
::
const_local_iterator
const_iterator
;
typedef
typename
T
::
size_type
size_type
;
typedef
typename
T
::
size_type
size_type
;
const_iterator
begin
()
const
const_iterator
begin
()
const
{
return
m_map
.
cbegin
(
n
);
}
{
return
m_map
.
cbegin
(
n
);
}
const_iterator
end
()
const
{
return
m_map
.
cend
(
n
);
}
bucket_print_wrapper
(
const
T
&
m
,
size_type
bucket
)
:
m_map
(
m
),
n
(
bucket
)
{
}
const_iterator
end
()
const
{
return
m_map
.
cend
(
n
);
}
private:
bucket_print_wrapper
(
const
T
&
m
,
size_type
bucket
)
:
m_map
(
m
),
n
(
bucket
)
{}
const
T
&
m_map
;
const
size_type
n
;
};
}
// namespace pretty_print
private:
const
T
&
m_map
;
const
size_type
n
;
};
}
// namespace pretty_print
// Global accessor functions for the convenience wrappers
// Global accessor functions for the convenience wrappers
template
<
typename
T
>
template
<
typename
T
>
inline
pretty_print
::
array_wrapper_n
<
T
>
pretty_print_array
(
const
T
*
const
a
,
size_t
n
)
inline
pretty_print
::
array_wrapper_n
<
T
>
pretty_print_array
(
const
T
*
const
a
,
{
size_t
n
)
{
return
pretty_print
::
array_wrapper_n
<
T
>
(
a
,
n
);
return
pretty_print
::
array_wrapper_n
<
T
>
(
a
,
n
);
}
}
template
<
typename
T
>
pretty_print
::
bucket_print_wrapper
<
T
>
template
<
typename
T
>
bucket_print
(
const
T
&
m
,
typename
T
::
size_type
n
)
pretty_print
::
bucket_print_wrapper
<
T
>
bucket_print
(
const
T
&
m
,
{
typename
T
::
size_type
n
)
{
return
pretty_print
::
bucket_print_wrapper
<
T
>
(
m
,
n
);
return
pretty_print
::
bucket_print_wrapper
<
T
>
(
m
,
n
);
}
}
// Main magic entry point: An overload snuck into namespace std.
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
// Can we do better?
namespace
std
namespace
std
{
{
// Prints a container to the stream using default delimiters
// Prints a container to the stream using default delimiters
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
>
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
>
inline
typename
enable_if
<
::
pretty_print
::
is_container
<
T
>::
value
,
inline
typename
enable_if
<::
pretty_print
::
is_container
<
T
>::
value
,
basic_ostream
<
TChar
,
TCharTraits
>
&>::
type
basic_ostream
<
TChar
,
TCharTraits
>
&>::
type
operator
<<
(
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
T
&
container
)
operator
<<
(
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
T
&
container
)
{
{
return
stream
return
stream
<<
::
pretty_print
::
print_container_helper
<
T
,
TChar
,
TCharTraits
>
(
container
);
<<
::
pretty_print
::
print_container_helper
<
T
,
TChar
,
TCharTraits
>
(
}
container
);
}
}
}
// namespace std
#endif // H_PRETTY_PRINT
#endif // H_PRETTY_PRINT
mmdet3d/ops/spconv/include/spconv/box_iou.h
View file @
f27d308f
...
@@ -12,15 +12,15 @@
...
@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#ifndef BOX_IOU_H
#ifndef BOX_IOU_H
#define BOX_IOU_H
#define BOX_IOU_H
#include <pybind11/pybind11.h>
#include <pybind11/pybind11.h>
// must include pybind11/eigen.h if using eigen matrix as arguments.
// must include pybind11/eigen.h if using eigen matrix as arguments.
#include <pybind11/numpy.h>
#include <algorithm>
#include <algorithm>
#include <boost/geometry.hpp>
#include <boost/geometry.hpp>
#include <pybind11/numpy.h>
namespace
spconv
{
namespace
spconv
{
// #include "voxelnet/core/cc/pybind11_helper.h"
// #include "voxelnet/core/cc/pybind11_helper.h"
...
@@ -40,9 +40,10 @@ inline py::array_t<DType> zeros(std::vector<long int> shape) {
...
@@ -40,9 +40,10 @@ inline py::array_t<DType> zeros(std::vector<long int> shape) {
}
}
template
<
typename
DType
>
template
<
typename
DType
>
py
::
array_t
<
DType
>
py
::
array_t
<
DType
>
rbbox_iou
(
py
::
array_t
<
DType
>
box_corners
,
rbbox_iou
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
namespace
bg
=
boost
::
geometry
;
namespace
bg
=
boost
::
geometry
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
...
@@ -61,8 +62,7 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
...
@@ -61,8 +62,7 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
}
}
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
continue
;
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
...
@@ -99,9 +99,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
...
@@ -99,9 +99,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
}
}
template
<
typename
DType
>
template
<
typename
DType
>
py
::
array_t
<
DType
>
py
::
array_t
<
DType
>
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
namespace
bg
=
boost
::
geometry
;
namespace
bg
=
boost
::
geometry
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
...
@@ -120,8 +121,7 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
...
@@ -120,8 +121,7 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
}
}
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
continue
;
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
...
@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
...
@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
return
overlaps
;
return
overlaps
;
}
}
}
// namespace spconv
}
// namespace spconv
#endif
#endif
mmdet3d/ops/spconv/include/spconv/geometry.h
View file @
f27d308f
...
@@ -15,9 +15,10 @@
...
@@ -15,9 +15,10 @@
#ifndef SPCONV_GEOMETRY_H_
#ifndef SPCONV_GEOMETRY_H_
#define SPCONV_GEOMETRY_H_
#define SPCONV_GEOMETRY_H_
#include <tensorview/tensorview.h>
#include <iostream>
#include <iostream>
#include <limits>
#include <limits>
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
spconv
{
template
<
typename
Index
,
unsigned
NDim
>
template
<
typename
Index
,
unsigned
NDim
>
...
@@ -70,8 +71,7 @@ TV_HOST_DEVICE Index getValidOutPos(const Index *input_pos,
...
@@ -70,8 +71,7 @@ TV_HOST_DEVICE Index getValidOutPos(const Index *input_pos,
}
}
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
if
(
valid
)
if
(
valid
)
++
pointCounter
;
++
pointCounter
;
counter
[
NDim
-
1
]
+=
1
;
counter
[
NDim
-
1
]
+=
1
;
#pragma unroll
#pragma unroll
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
...
@@ -128,8 +128,7 @@ TV_HOST_DEVICE Index getValidOutPosTranspose(
...
@@ -128,8 +128,7 @@ TV_HOST_DEVICE Index getValidOutPosTranspose(
m
*=
kernelSize
[
j
];
m
*=
kernelSize
[
j
];
}
}
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
if
(
valid
)
if
(
valid
)
++
pointCounter
;
++
pointCounter
;
counter
[
NDim
-
1
]
+=
1
;
counter
[
NDim
-
1
]
+=
1
;
#pragma unroll
#pragma unroll
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
...
@@ -167,7 +166,7 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn,
...
@@ -167,7 +166,7 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn,
}
}
Index
numValidPoints
=
0
;
Index
numValidPoints
=
0
;
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
Index
*
pointPtr
=
nullptr
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
batchIdx
=
indicesIn
(
j
,
0
);
batchIdx
=
indicesIn
(
j
,
0
);
...
@@ -218,7 +217,7 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn,
...
@@ -218,7 +217,7 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn,
}
}
Index
numValidPoints
=
0
;
Index
numValidPoints
=
0
;
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
Index
*
pointPtr
=
nullptr
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
batchIdx
=
indicesIn
(
j
,
0
);
batchIdx
=
indicesIn
(
j
,
0
);
...
@@ -252,7 +251,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
...
@@ -252,7 +251,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
Index
*
const
kernelSize
,
const
Index
*
const
kernelSize
,
const
Index
*
const
stride
,
const
Index
*
const
padding
,
const
Index
*
const
stride
,
const
Index
*
const
padding
,
const
Index
*
dilation
,
const
Index
*
const
outSpatialShape
)
{
const
Index
*
dilation
,
const
Index
*
const
outSpatialShape
)
{
Index
numAct
=
0
;
Index
numAct
=
0
;
auto
numActIn
=
indicesIn
.
dim
(
0
);
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
batchIdx
=
0
;
Index
batchIdx
=
0
;
...
@@ -269,7 +269,7 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
...
@@ -269,7 +269,7 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
Index
numValidPoints
=
0
;
Index
numValidPoints
=
0
;
// Index validPoints[kernelVolume * (NDim + 1)];
// Index validPoints[kernelVolume * (NDim + 1)];
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
Index
*
pointPtr
=
nullptr
;
Index
index
=
0
;
Index
index
=
0
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
...
@@ -296,6 +296,6 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
...
@@ -296,6 +296,6 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
return
numActIn
;
return
numActIn
;
}
}
}
// namespace spconv
}
// namespace spconv
#endif
#endif
mmdet3d/ops/spconv/include/spconv/indice.cu.h
View file @
f27d308f
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#ifndef INDICE_CU_H_
#ifndef INDICE_CU_H_
#define INDICE_CU_H_
#define INDICE_CU_H_
#include <tensorview/tensorview.h>
#include <tensorview/helper_kernel.cu.h>
#include <spconv/geometry.h>
#include <spconv/geometry.h>
#include <tensorview/helper_kernel.cu.h>
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
spconv
{
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
...
@@ -115,7 +115,6 @@ __global__ void assignGridAndIndiceOutKernel(
...
@@ -115,7 +115,6 @@ __global__ void assignGridAndIndiceOutKernel(
int
numAct
,
tv
::
TensorView
<
Index
>
indicePairs
,
int
numAct
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
batchSize
)
{
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
batchSize
)
{
Index
index
;
Index
index
;
auto
indicesOutPtr
=
indicesOut
.
data
();
auto
indicesOutPtr
=
indicesOut
.
data
();
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numAct
))
{
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numAct
))
{
...
@@ -128,13 +127,11 @@ __global__ void assignGridAndIndiceOutKernel(
...
@@ -128,13 +127,11 @@ __global__ void assignGridAndIndiceOutKernel(
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
__global__
void
assignIndicePairsKernel
(
assignIndicePairsKernel
(
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
int
numActIn
,
int
numActIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
Index
index
;
Index
index
;
int
kernelVolume
=
indicePairs
.
dim
(
0
);
int
kernelVolume
=
indicePairs
.
dim
(
0
);
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numActIn
))
{
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numActIn
))
{
...
@@ -148,10 +145,9 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut,
...
@@ -148,10 +145,9 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut,
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
__global__
void
prepareSubMGridKernel
(
prepareSubMGridKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
auto
numActIn
=
indicesIn
.
dim
(
0
);
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
spatialVolume
=
1
;
Index
spatialVolume
=
1
;
#pragma unroll
#pragma unroll
...
@@ -216,10 +212,9 @@ __global__ void resetGridKernel(const Index *indicePairUnique,
...
@@ -216,10 +212,9 @@ __global__ void resetGridKernel(const Index *indicePairUnique,
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
__global__
void
resetGridSubMKernel
(
resetGridSubMKernel
(
const
Index
*
indices
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
Index
*
indices
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
numAct
)
{
int
numAct
)
{
int
outSpatialShapeReg
[
NDim
];
int
outSpatialShapeReg
[
NDim
];
for
(
int
i
=
0
;
i
<
NDim
;
++
i
)
{
for
(
int
i
=
0
;
i
<
NDim
;
++
i
)
{
outSpatialShapeReg
[
i
]
=
outSpatialShape
[
i
];
outSpatialShapeReg
[
i
]
=
outSpatialShape
[
i
];
...
@@ -238,6 +233,6 @@ resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut,
...
@@ -238,6 +233,6 @@ resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut,
}
}
}
}
}
// namespace spconv
}
// namespace spconv
#endif
#endif
mmdet3d/ops/spconv/include/spconv/indice.h
View file @
f27d308f
...
@@ -16,64 +16,65 @@
...
@@ -16,64 +16,65 @@
#define SPARSE_CONV_INDICE_FUNCTOR_H_
#define SPARSE_CONV_INDICE_FUNCTOR_H_
#include <tensorview/tensorview.h>
#include <tensorview/tensorview.h>
namespace
spconv
namespace
spconv
{
{
namespace
functor
{
namespace
functor
{
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP1
struct
CreateConvIndicePairFunctorP1
{
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
Index
operator
()(
tv
::
TensorView
<
Index
>
indicesOut
,
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
};
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP2
struct
CreateConvIndicePairFunctorP2
{
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
Index
operator
()(
tv
::
TensorView
<
Index
>
indicesOut
,
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indice
sOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indice
Pairs
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
resetGrid
=
false
);
bool
transpose
,
bool
resetGrid
=
false
);
};
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctor
struct
CreateConvIndicePairFunctor
{
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
Index
operator
()(
tv
::
TensorView
<
Index
>
indicesOut
,
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateSubMIndicePairFunctor
struct
CreateSubMIndicePairFunctor
{
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
Index
operator
()(
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
};
}
// namespace functor
}
// namespace functor
}
// namespace spconv
}
// namespace spconv
#endif
#endif
mmdet3d/ops/spconv/include/spconv/maxpool.h
View file @
f27d308f
...
@@ -16,29 +16,24 @@
...
@@ -16,29 +16,24 @@
#define SPARSE_MAXPOOL_FUNCTOR_H_
#define SPARSE_MAXPOOL_FUNCTOR_H_
#include <tensorview/tensorview.h>
#include <tensorview/tensorview.h>
namespace
spconv
namespace
spconv
{
{
namespace
functor
{
namespace
functor
{
template
<
typename
Device
,
typename
T
,
typename
Index
>
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolForwardFunctor
struct
SparseMaxPoolForwardFunctor
{
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
};
template
<
typename
Device
,
typename
T
,
typename
Index
>
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolBackwardFunctor
struct
SparseMaxPoolBackwardFunctor
{
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
};
}
// namespace functor
}
// namespace functor
}
// namespace spconv
}
// namespace spconv
#endif
#endif
mmdet3d/ops/spconv/include/spconv/mp_helper.h
View file @
f27d308f
...
@@ -4,7 +4,8 @@
...
@@ -4,7 +4,8 @@
#include <utility>
#include <utility>
namespace
spconv
{
namespace
spconv
{
template
<
class
...
T
>
struct
mp_list
{};
template
<
class
...
T
>
struct
mp_list
{};
template
<
class
T
,
T
...
I
>
template
<
class
T
,
T
...
I
>
using
mp_list_c
=
mp_list
<
std
::
integral_constant
<
T
,
I
>
...
>
;
using
mp_list_c
=
mp_list
<
std
::
integral_constant
<
T
,
I
>
...
>
;
...
@@ -16,15 +17,17 @@ constexpr F mp_for_each_impl(mp_list<T...>, F &&f) {
...
@@ -16,15 +17,17 @@ constexpr F mp_for_each_impl(mp_list<T...>, F &&f) {
return
std
::
initializer_list
<
int
>
{(
f
(
T
()),
0
)...},
std
::
forward
<
F
>
(
f
);
return
std
::
initializer_list
<
int
>
{(
f
(
T
()),
0
)...},
std
::
forward
<
F
>
(
f
);
}
}
template
<
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<>
,
F
&&
f
)
{
template
<
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<>
,
F
&&
f
)
{
return
std
::
forward
<
F
>
(
f
);
return
std
::
forward
<
F
>
(
f
);
}
}
}
// namespace detail
}
// namespace detail
namespace
detail
{
namespace
detail
{
template
<
class
A
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
{
template
<
class
A
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
{
// An error "no type named 'type'" here means that the first argument to
// An error "no type named 'type'" here means that the first argument to
// mp_rename is not a list
// mp_rename is not a list
};
};
...
@@ -34,14 +37,15 @@ struct mp_rename_impl<A<T...>, B> {
...
@@ -34,14 +37,15 @@ struct mp_rename_impl<A<T...>, B> {
using
type
=
B
<
T
...
>
;
using
type
=
B
<
T
...
>
;
};
};
}
// namespace detail
}
// namespace detail
template
<
class
A
,
template
<
class
...
>
class
B
>
template
<
class
A
,
template
<
class
...
>
class
B
>
using
mp_rename
=
typename
detail
::
mp_rename_impl
<
A
,
B
>::
type
;
using
mp_rename
=
typename
detail
::
mp_rename_impl
<
A
,
B
>::
type
;
template
<
class
L
,
class
F
>
constexpr
F
mp_for_each
(
F
&&
f
)
{
template
<
class
L
,
class
F
>
constexpr
F
mp_for_each
(
F
&&
f
)
{
return
detail
::
mp_for_each_impl
(
mp_rename
<
L
,
mp_list
>
(),
std
::
forward
<
F
>
(
f
));
return
detail
::
mp_for_each_impl
(
mp_rename
<
L
,
mp_list
>
(),
std
::
forward
<
F
>
(
f
));
}
}
}
// namespace spconv
}
// namespace spconv
#endif
#endif
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment