Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
f27d308f
"vscode:/vscode.git/clone" did not exist on "ce9e07562bceea2741083dca4813cd6d30b5ec4b"
Commit
f27d308f
authored
Jun 07, 2020
by
yinchimaoliang
Browse files
merge master
parents
c66ae813
27ebcfac
Changes
80
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1457 additions
and
1163 deletions
+1457
-1163
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
+78
-68
mmdet3d/ops/group_points/src/group_points.cpp
mmdet3d/ops/group_points/src/group_points.cpp
+34
-28
mmdet3d/ops/group_points/src/group_points_cuda.cu
mmdet3d/ops/group_points/src/group_points_cuda.cu
+77
-64
mmdet3d/ops/interpolate/src/interpolate.cpp
mmdet3d/ops/interpolate/src/interpolate.cpp
+62
-53
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
+92
-80
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
+68
-56
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
+345
-296
mmdet3d/ops/roiaware_pool3d/__init__.py
mmdet3d/ops/roiaware_pool3d/__init__.py
+6
-2
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
+26
-0
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
+76
-0
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
+5
-0
mmdet3d/ops/sparse_block.py
mmdet3d/ops/sparse_block.py
+31
-7
mmdet3d/ops/spconv/include/paramsgrid.h
mmdet3d/ops/spconv/include/paramsgrid.h
+9
-3
mmdet3d/ops/spconv/include/prettyprint.h
mmdet3d/ops/spconv/include/prettyprint.h
+442
-394
mmdet3d/ops/spconv/include/spconv/box_iou.h
mmdet3d/ops/spconv/include/spconv/box_iou.h
+13
-14
mmdet3d/ops/spconv/include/spconv/geometry.h
mmdet3d/ops/spconv/include/spconv/geometry.h
+10
-10
mmdet3d/ops/spconv/include/spconv/indice.cu.h
mmdet3d/ops/spconv/include/spconv/indice.cu.h
+14
-19
mmdet3d/ops/spconv/include/spconv/indice.h
mmdet3d/ops/spconv/include/spconv/indice.h
+49
-48
mmdet3d/ops/spconv/include/spconv/maxpool.h
mmdet3d/ops/spconv/include/spconv/maxpool.h
+9
-14
mmdet3d/ops/spconv/include/spconv/mp_helper.h
mmdet3d/ops/spconv/include/spconv/mp_helper.h
+11
-7
No files found.
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
View file @
f27d308f
...
...
@@ -3,82 +3,92 @@
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
gather_points_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, M)
// output:
// out: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
0
]
=
points
[
idx
[
0
]];
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, M)
// output:
// out: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
0
]
=
points
[
idx
[
0
]];
}
void
gather_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
// out: (B, C, npoints)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
// out: (B, C, npoints)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
gather_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, M)
// idx: (B, M)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
grad_out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
grad_points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]);
__global__
void
gather_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, M)
// idx: (B, M)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
grad_out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
grad_points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]);
}
void
gather_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/group_points/src/group_points.cpp
View file @
f27d308f
#include <
torch/serialize/tensor
.h>
#include <
THC/THC
.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern
THCState
*
state
;
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
<
float
>
();
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
float
*
grad_points
=
grad_points_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
_ptr
<
int
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
_ptr
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_grad_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
,
stream
);
return
1
;
group_points_grad_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
,
stream
);
return
1
;
}
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
_ptr
<
int
>
();
float
*
out
=
out_tensor
.
data
_ptr
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
stream
);
return
1
;
group_points_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
stream
);
return
1
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
group_points_wrapper
,
"group_points_wrapper"
);
m
.
def
(
"backward"
,
&
group_points_grad_wrapper
,
"group_points_grad_wrapper"
);
m
.
def
(
"forward"
,
&
group_points_wrapper
,
"group_points_wrapper"
);
m
.
def
(
"backward"
,
&
group_points_grad_wrapper
,
"group_points_grad_wrapper"
);
}
mmdet3d/ops/group_points/src/group_points_cuda.cu
View file @
f27d308f
...
...
@@ -2,84 +2,97 @@
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m,
n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
group_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
__global__
void
group_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
grad_out
+=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
sample_idx
=
index
%
nsample
;
grad_out
+=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
atomicAdd
(
grad_points
+
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
]
,
grad_out
[
0
]);
atomicAdd
(
grad_points
+
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
],
grad_out
[
0
]);
}
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
group_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
group_points_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
__global__
void
group_points_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
int
sample_idx
=
index
%
nsample
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
in_idx
=
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
];
int
out_idx
=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
in_idx
=
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
];
int
out_idx
=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
out
[
out_idx
]
=
points
[
in_idx
];
out
[
out_idx
]
=
points
[
in_idx
];
}
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
group_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/interpolate/src/interpolate.cpp
View file @
f27d308f
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern
THCState
*
state
;
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points
_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx
_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
);
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
);
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
)
{
const
float
*
unknown
=
unknown_tensor
.
data
<
float
>
();
const
float
*
known
=
known_tensor
.
data
<
float
>
();
float
*
dist2
=
dist2_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
)
{
const
float
*
unknown
=
unknown_tensor
.
data_ptr
<
float
>
();
const
float
*
known
=
known_tensor
.
data_ptr
<
float
>
();
float
*
dist2
=
dist2_tensor
.
data_ptr
<
float
>
();
int
*
idx
=
idx_tensor
.
data_ptr
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
}
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data
<
float
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_kernel_launcher
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
stream
);
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data_ptr
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data_ptr
<
float
>
();
float
*
out
=
out_tensor
.
data_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data_ptr
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_kernel_launcher
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
stream
);
}
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
)
{
const
float
*
grad_out
=
grad_ou
t_tensor
.
data
<
float
>
();
const
float
*
weight
=
weight
_tensor
.
data
<
float
>
();
float
*
grad_points
=
grad_points
_tensor
.
data
<
floa
t
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_grad_kernel_launcher
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
,
stream
);
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
)
{
const
float
*
grad_out
=
grad_out_tensor
.
data_ptr
<
float
>
();
const
float
*
weight
=
weigh
t_tensor
.
data
_ptr
<
float
>
();
float
*
grad_points
=
grad_points
_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx
_tensor
.
data
_ptr
<
in
t
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_grad_kernel_launcher
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
,
stream
);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper
,
"three_nn_wrapper"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper
,
"three_interpolate_wrapper"
);
m
.
def
(
"three_interpolate_grad_wrapper"
,
&
three_interpolate_grad_wrapper
,
"three_interpolate_grad_wrapper"
);
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper
,
"three_nn_wrapper"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper
,
"three_interpolate_wrapper"
);
m
.
def
(
"three_interpolate_grad_wrapper"
,
&
three_interpolate_grad_wrapper
,
"three_interpolate_grad_wrapper"
);
}
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
View file @
f27d308f
...
...
@@ -3,91 +3,103 @@
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_interpolate_kernel
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
out
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_interpolate_kernel
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
out
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
}
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
three_interpolate_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, N)
//
weigh
t: (B,
N
,
3
)
//
output:
//
grad_points: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c
_idx
=
blockIdx
.
y
;
int
pt
_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
grad_out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
+
pt_idx
;
weigh
t
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
grad_points
+=
bs_idx
*
c
*
m
+
c
_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
__global__
void
three_interpolate_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
grad_points
)
{
//
grad_ou
t: (B,
C
,
N
)
//
weight: (B, N, 3)
//
output:
// grad_points: (B, C, M)
int
bs
_idx
=
blockIdx
.
z
;
int
c
_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
grad_ou
t
+=
bs_idx
*
c
*
n
+
c
_idx
*
n
+
pt_idx
;
weight
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
grad_points
+=
bs_idx
*
c
*
m
+
c
_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
}
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
View file @
f27d308f
...
...
@@ -3,72 +3,84 @@
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m,
n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_nn_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
__global__
void
three_nn_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
known
+=
bs_idx
*
m
*
3
;
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
known
+=
bs_idx
*
m
*
3
;
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
float
ux
=
unknown
[
0
];
float
uy
=
unknown
[
1
];
float
uz
=
unknown
[
2
];
float
ux
=
unknown
[
0
];
float
uy
=
unknown
[
1
];
float
uz
=
unknown
[
2
];
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
x
=
known
[
k
*
3
+
0
];
float
y
=
known
[
k
*
3
+
1
];
float
z
=
known
[
k
*
3
+
2
];
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
if
(
d
<
best1
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
best1
;
besti2
=
besti1
;
best1
=
d
;
besti1
=
k
;
}
else
if
(
d
<
best2
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
d
;
besti2
=
k
;
}
else
if
(
d
<
best3
)
{
best3
=
d
;
besti3
=
k
;
}
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
x
=
known
[
k
*
3
+
0
];
float
y
=
known
[
k
*
3
+
1
];
float
z
=
known
[
k
*
3
+
2
];
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
if
(
d
<
best1
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
best1
;
besti2
=
besti1
;
best1
=
d
;
besti1
=
k
;
}
else
if
(
d
<
best2
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
d
;
besti2
=
k
;
}
else
if
(
d
<
best3
)
{
best3
=
d
;
besti3
=
k
;
}
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
}
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
}
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_nn_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
three_nn_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
View file @
f27d308f
...
...
@@ -6,376 +6,425 @@
const
int
THREADS_PER_BLOCK_NMS
=
sizeof
(
unsigned
long
long
)
*
8
;
const
float
EPS
=
1e-8
;
struct
Point
{
float
x
,
y
;
__device__
Point
()
{}
__device__
Point
(
double
_x
,
double
_y
){
x
=
_x
,
y
=
_y
;
}
__device__
void
set
(
float
_x
,
float
_y
){
x
=
_x
;
y
=
_y
;
}
__device__
Point
operator
+
(
const
Point
&
b
)
const
{
return
Point
(
x
+
b
.
x
,
y
+
b
.
y
);
}
__device__
Point
operator
-
(
const
Point
&
b
)
const
{
return
Point
(
x
-
b
.
x
,
y
-
b
.
y
);
}
float
x
,
y
;
__device__
Point
()
{}
__device__
Point
(
double
_x
,
double
_y
)
{
x
=
_x
,
y
=
_y
;
}
__device__
void
set
(
float
_x
,
float
_y
)
{
x
=
_x
;
y
=
_y
;
}
__device__
Point
operator
+
(
const
Point
&
b
)
const
{
return
Point
(
x
+
b
.
x
,
y
+
b
.
y
);
}
__device__
Point
operator
-
(
const
Point
&
b
)
const
{
return
Point
(
x
-
b
.
x
,
y
-
b
.
y
);
}
};
__device__
inline
float
cross
(
const
Point
&
a
,
const
Point
&
b
){
return
a
.
x
*
b
.
y
-
a
.
y
*
b
.
x
;
__device__
inline
float
cross
(
const
Point
&
a
,
const
Point
&
b
)
{
return
a
.
x
*
b
.
y
-
a
.
y
*
b
.
x
;
}
__device__
inline
float
cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
p0
){
return
(
p1
.
x
-
p0
.
x
)
*
(
p2
.
y
-
p0
.
y
)
-
(
p2
.
x
-
p0
.
x
)
*
(
p1
.
y
-
p0
.
y
);
__device__
inline
float
cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
p0
)
{
return
(
p1
.
x
-
p0
.
x
)
*
(
p2
.
y
-
p0
.
y
)
-
(
p2
.
x
-
p0
.
x
)
*
(
p1
.
y
-
p0
.
y
);
}
__device__
int
check_rect_cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
q1
,
const
Point
&
q2
){
int
ret
=
min
(
p1
.
x
,
p2
.
x
)
<=
max
(
q1
.
x
,
q2
.
x
)
&&
min
(
q1
.
x
,
q2
.
x
)
<=
max
(
p1
.
x
,
p2
.
x
)
&&
min
(
p1
.
y
,
p2
.
y
)
<=
max
(
q1
.
y
,
q2
.
y
)
&&
min
(
q1
.
y
,
q2
.
y
)
<=
max
(
p1
.
y
,
p2
.
y
);
return
ret
;
__device__
int
check_rect_cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
q1
,
const
Point
&
q2
)
{
int
ret
=
min
(
p1
.
x
,
p2
.
x
)
<=
max
(
q1
.
x
,
q2
.
x
)
&&
min
(
q1
.
x
,
q2
.
x
)
<=
max
(
p1
.
x
,
p2
.
x
)
&&
min
(
p1
.
y
,
p2
.
y
)
<=
max
(
q1
.
y
,
q2
.
y
)
&&
min
(
q1
.
y
,
q2
.
y
)
<=
max
(
p1
.
y
,
p2
.
y
);
return
ret
;
}
__device__
inline
int
check_in_box2d
(
const
float
*
box
,
const
Point
&
p
){
//params: box (5) [x1, y1, x2, y2, angle]
const
float
MARGIN
=
1e-5
;
float
center_x
=
(
box
[
0
]
+
box
[
2
])
/
2
;
float
center_y
=
(
box
[
1
]
+
box
[
3
])
/
2
;
float
angle_cos
=
cos
(
-
box
[
4
]),
angle_sin
=
sin
(
-
box
[
4
]);
// rotate the point in the opposite direction of box
float
rot_x
=
(
p
.
x
-
center_x
)
*
angle_cos
+
(
p
.
y
-
center_y
)
*
angle_sin
+
center_x
;
float
rot_y
=
-
(
p
.
x
-
center_x
)
*
angle_sin
+
(
p
.
y
-
center_y
)
*
angle_cos
+
center_y
;
__device__
inline
int
check_in_box2d
(
const
float
*
box
,
const
Point
&
p
)
{
// params: box (5) [x1, y1, x2, y2, angle]
const
float
MARGIN
=
1e-5
;
float
center_x
=
(
box
[
0
]
+
box
[
2
])
/
2
;
float
center_y
=
(
box
[
1
]
+
box
[
3
])
/
2
;
float
angle_cos
=
cos
(
-
box
[
4
]),
angle_sin
=
sin
(
-
box
[
4
]);
// rotate the point in the opposite direction of box
float
rot_x
=
(
p
.
x
-
center_x
)
*
angle_cos
+
(
p
.
y
-
center_y
)
*
angle_sin
+
center_x
;
float
rot_y
=
-
(
p
.
x
-
center_x
)
*
angle_sin
+
(
p
.
y
-
center_y
)
*
angle_cos
+
center_y
;
#ifdef DEBUG
printf
(
"box: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
box
[
0
],
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
]);
printf
(
"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, %.3f)
\n
"
,
center_x
,
center_y
,
angle_cos
,
angle_sin
,
p
.
x
,
p
.
y
,
rot_x
,
rot_y
);
printf
(
"box: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
box
[
0
],
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
]);
printf
(
"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, "
"%.3f)
\n
"
,
center_x
,
center_y
,
angle_cos
,
angle_sin
,
p
.
x
,
p
.
y
,
rot_x
,
rot_y
);
#endif
return
(
rot_x
>
box
[
0
]
-
MARGIN
&&
rot_x
<
box
[
2
]
+
MARGIN
&&
rot_y
>
box
[
1
]
-
MARGIN
&&
rot_y
<
box
[
3
]
+
MARGIN
);
return
(
rot_x
>
box
[
0
]
-
MARGIN
&&
rot_x
<
box
[
2
]
+
MARGIN
&&
rot_y
>
box
[
1
]
-
MARGIN
&&
rot_y
<
box
[
3
]
+
MARGIN
);
}
__device__
inline
int
intersection
(
const
Point
&
p1
,
const
Point
&
p0
,
const
Point
&
q1
,
const
Point
&
q0
,
Point
&
ans
){
// fast exclusion
if
(
check_rect_cross
(
p0
,
p1
,
q0
,
q1
)
==
0
)
return
0
;
__device__
inline
int
intersection
(
const
Point
&
p1
,
const
Point
&
p0
,
const
Point
&
q1
,
const
Point
&
q0
,
Point
&
ans
)
{
// fast exclusion
if
(
check_rect_cross
(
p0
,
p1
,
q0
,
q1
)
==
0
)
return
0
;
// check cross standing
float
s1
=
cross
(
q0
,
p1
,
p0
);
float
s2
=
cross
(
p1
,
q1
,
p0
);
float
s3
=
cross
(
p0
,
q1
,
q0
);
float
s4
=
cross
(
q1
,
p1
,
q0
);
// check cross standing
float
s1
=
cross
(
q0
,
p1
,
p0
);
float
s2
=
cross
(
p1
,
q1
,
p0
);
float
s3
=
cross
(
p0
,
q1
,
q0
);
float
s4
=
cross
(
q1
,
p1
,
q0
);
if
(
!
(
s1
*
s2
>
0
&&
s3
*
s4
>
0
))
return
0
;
if
(
!
(
s1
*
s2
>
0
&&
s3
*
s4
>
0
))
return
0
;
// calculate intersection of two lines
float
s5
=
cross
(
q1
,
p1
,
p0
);
if
(
fabs
(
s5
-
s1
)
>
EPS
){
ans
.
x
=
(
s5
*
q0
.
x
-
s1
*
q1
.
x
)
/
(
s5
-
s1
);
ans
.
y
=
(
s5
*
q0
.
y
-
s1
*
q1
.
y
)
/
(
s5
-
s1
);
// calculate intersection of two lines
float
s5
=
cross
(
q1
,
p1
,
p0
);
if
(
fabs
(
s5
-
s1
)
>
EPS
)
{
ans
.
x
=
(
s5
*
q0
.
x
-
s1
*
q1
.
x
)
/
(
s5
-
s1
);
ans
.
y
=
(
s5
*
q0
.
y
-
s1
*
q1
.
y
)
/
(
s5
-
s1
);
}
else
{
float
a0
=
p0
.
y
-
p1
.
y
,
b0
=
p1
.
x
-
p0
.
x
,
c0
=
p0
.
x
*
p1
.
y
-
p1
.
x
*
p0
.
y
;
float
a1
=
q0
.
y
-
q1
.
y
,
b1
=
q1
.
x
-
q0
.
x
,
c1
=
q0
.
x
*
q1
.
y
-
q1
.
x
*
q0
.
y
;
float
D
=
a0
*
b1
-
a1
*
b0
;
}
else
{
float
a0
=
p0
.
y
-
p1
.
y
,
b0
=
p1
.
x
-
p0
.
x
,
c0
=
p0
.
x
*
p1
.
y
-
p1
.
x
*
p0
.
y
;
float
a1
=
q0
.
y
-
q1
.
y
,
b1
=
q1
.
x
-
q0
.
x
,
c1
=
q0
.
x
*
q1
.
y
-
q1
.
x
*
q0
.
y
;
float
D
=
a0
*
b1
-
a1
*
b0
;
ans
.
x
=
(
b0
*
c1
-
b1
*
c0
)
/
D
;
ans
.
y
=
(
a1
*
c0
-
a0
*
c1
)
/
D
;
}
ans
.
x
=
(
b0
*
c1
-
b1
*
c0
)
/
D
;
ans
.
y
=
(
a1
*
c0
-
a0
*
c1
)
/
D
;
}
return
1
;
return
1
;
}
__device__
inline
void
rotate_around_center
(
const
Point
&
center
,
const
float
angle_cos
,
const
float
angle_sin
,
Point
&
p
){
float
new_x
=
(
p
.
x
-
center
.
x
)
*
angle_cos
+
(
p
.
y
-
center
.
y
)
*
angle_sin
+
center
.
x
;
float
new_y
=
-
(
p
.
x
-
center
.
x
)
*
angle_sin
+
(
p
.
y
-
center
.
y
)
*
angle_cos
+
center
.
y
;
p
.
set
(
new_x
,
new_y
);
__device__
inline
void
rotate_around_center
(
const
Point
&
center
,
const
float
angle_cos
,
const
float
angle_sin
,
Point
&
p
)
{
float
new_x
=
(
p
.
x
-
center
.
x
)
*
angle_cos
+
(
p
.
y
-
center
.
y
)
*
angle_sin
+
center
.
x
;
float
new_y
=
-
(
p
.
x
-
center
.
x
)
*
angle_sin
+
(
p
.
y
-
center
.
y
)
*
angle_cos
+
center
.
y
;
p
.
set
(
new_x
,
new_y
);
}
__device__
inline
int
point_cmp
(
const
Point
&
a
,
const
Point
&
b
,
const
Point
&
center
){
return
atan2
(
a
.
y
-
center
.
y
,
a
.
x
-
center
.
x
)
>
atan2
(
b
.
y
-
center
.
y
,
b
.
x
-
center
.
x
);
__device__
inline
int
point_cmp
(
const
Point
&
a
,
const
Point
&
b
,
const
Point
&
center
)
{
return
atan2
(
a
.
y
-
center
.
y
,
a
.
x
-
center
.
x
)
>
atan2
(
b
.
y
-
center
.
y
,
b
.
x
-
center
.
x
);
}
__device__
inline
float
box_overlap
(
const
float
*
box_a
,
const
float
*
box_b
){
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
__device__
inline
float
box_overlap
(
const
float
*
box_a
,
const
float
*
box_b
)
{
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
a_x1
=
box_a
[
0
],
a_y1
=
box_a
[
1
],
a_x2
=
box_a
[
2
],
a_y2
=
box_a
[
3
],
a_angle
=
box_a
[
4
];
float
b_x1
=
box_b
[
0
],
b_y1
=
box_b
[
1
],
b_x2
=
box_b
[
2
],
b_y2
=
box_b
[
3
],
b_angle
=
box_b
[
4
];
float
a_x1
=
box_a
[
0
],
a_y1
=
box_a
[
1
],
a_x2
=
box_a
[
2
],
a_y2
=
box_a
[
3
],
a_angle
=
box_a
[
4
];
float
b_x1
=
box_b
[
0
],
b_y1
=
box_b
[
1
],
b_x2
=
box_b
[
2
],
b_y2
=
box_b
[
3
],
b_angle
=
box_b
[
4
];
Point
center_a
((
a_x1
+
a_x2
)
/
2
,
(
a_y1
+
a_y2
)
/
2
);
Point
center_b
((
b_x1
+
b_x2
)
/
2
,
(
b_y1
+
b_y2
)
/
2
);
Point
center_a
((
a_x1
+
a_x2
)
/
2
,
(
a_y1
+
a_y2
)
/
2
);
Point
center_b
((
b_x1
+
b_x2
)
/
2
,
(
b_y1
+
b_y2
)
/
2
);
#ifdef DEBUG
printf
(
"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
a_x1
,
a_y1
,
a_x2
,
a_y2
,
a_angle
,
b_x1
,
b_y1
,
b_x2
,
b_y2
,
b_angle
);
printf
(
"center a: (%.3f, %.3f), b: (%.3f, %.3f)
\n
"
,
center_a
.
x
,
center_a
.
y
,
center_b
.
x
,
center_b
.
y
);
printf
(
"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
a_x1
,
a_y1
,
a_x2
,
a_y2
,
a_angle
,
b_x1
,
b_y1
,
b_x2
,
b_y2
,
b_angle
);
printf
(
"center a: (%.3f, %.3f), b: (%.3f, %.3f)
\n
"
,
center_a
.
x
,
center_a
.
y
,
center_b
.
x
,
center_b
.
y
);
#endif
Point
box_a_corners
[
5
];
box_a_corners
[
0
].
set
(
a_x1
,
a_y1
);
box_a_corners
[
1
].
set
(
a_x2
,
a_y1
);
box_a_corners
[
2
].
set
(
a_x2
,
a_y2
);
box_a_corners
[
3
].
set
(
a_x1
,
a_y2
);
Point
box_a_corners
[
5
];
box_a_corners
[
0
].
set
(
a_x1
,
a_y1
);
box_a_corners
[
1
].
set
(
a_x2
,
a_y1
);
box_a_corners
[
2
].
set
(
a_x2
,
a_y2
);
box_a_corners
[
3
].
set
(
a_x1
,
a_y2
);
Point
box_b_corners
[
5
];
box_b_corners
[
0
].
set
(
b_x1
,
b_y1
);
box_b_corners
[
1
].
set
(
b_x2
,
b_y1
);
box_b_corners
[
2
].
set
(
b_x2
,
b_y2
);
box_b_corners
[
3
].
set
(
b_x1
,
b_y2
);
Point
box_b_corners
[
5
];
box_b_corners
[
0
].
set
(
b_x1
,
b_y1
);
box_b_corners
[
1
].
set
(
b_x2
,
b_y1
);
box_b_corners
[
2
].
set
(
b_x2
,
b_y2
);
box_b_corners
[
3
].
set
(
b_x1
,
b_y2
);
// get oriented corners
float
a_angle_cos
=
cos
(
a_angle
),
a_angle_sin
=
sin
(
a_angle
);
float
b_angle_cos
=
cos
(
b_angle
),
b_angle_sin
=
sin
(
b_angle
);
// get oriented corners
float
a_angle_cos
=
cos
(
a_angle
),
a_angle_sin
=
sin
(
a_angle
);
float
b_angle_cos
=
cos
(
b_angle
),
b_angle_sin
=
sin
(
b_angle
);
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
#ifdef DEBUG
printf
(
"before corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
printf
(
"before corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
#endif
rotate_around_center
(
center_a
,
a_angle_cos
,
a_angle_sin
,
box_a_corners
[
k
]);
rotate_around_center
(
center_b
,
b_angle_cos
,
b_angle_sin
,
box_b_corners
[
k
]);
rotate_around_center
(
center_a
,
a_angle_cos
,
a_angle_sin
,
box_a_corners
[
k
]);
rotate_around_center
(
center_b
,
b_angle_cos
,
b_angle_sin
,
box_b_corners
[
k
]);
#ifdef DEBUG
printf
(
"corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
printf
(
"corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
#endif
}
box_a_corners
[
4
]
=
box_a_corners
[
0
];
box_b_corners
[
4
]
=
box_b_corners
[
0
];
// get intersection of lines
Point
cross_points
[
16
];
Point
poly_center
;
int
cnt
=
0
,
flag
=
0
;
poly_center
.
set
(
0
,
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
flag
=
intersection
(
box_a_corners
[
i
+
1
],
box_a_corners
[
i
],
box_b_corners
[
j
+
1
],
box_b_corners
[
j
],
cross_points
[
cnt
]);
if
(
flag
)
{
poly_center
=
poly_center
+
cross_points
[
cnt
];
cnt
++
;
}
}
box_a_corners
[
4
]
=
box_a_corners
[
0
];
box_b_corners
[
4
]
=
box_b_corners
[
0
];
// get intersection of lines
Point
cross_points
[
16
];
Point
poly_center
;
int
cnt
=
0
,
flag
=
0
;
poly_center
.
set
(
0
,
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
for
(
int
j
=
0
;
j
<
4
;
j
++
){
flag
=
intersection
(
box_a_corners
[
i
+
1
],
box_a_corners
[
i
],
box_b_corners
[
j
+
1
],
box_b_corners
[
j
],
cross_points
[
cnt
]);
if
(
flag
){
poly_center
=
poly_center
+
cross_points
[
cnt
];
cnt
++
;
}
}
}
// check corners
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
if
(
check_in_box2d
(
box_a
,
box_b_corners
[
k
]))
{
poly_center
=
poly_center
+
box_b_corners
[
k
];
cross_points
[
cnt
]
=
box_b_corners
[
k
];
cnt
++
;
}
// check corners
for
(
int
k
=
0
;
k
<
4
;
k
++
){
if
(
check_in_box2d
(
box_a
,
box_b_corners
[
k
])){
poly_center
=
poly_center
+
box_b_corners
[
k
];
cross_points
[
cnt
]
=
box_b_corners
[
k
];
cnt
++
;
}
if
(
check_in_box2d
(
box_b
,
box_a_corners
[
k
])){
poly_center
=
poly_center
+
box_a_corners
[
k
];
cross_points
[
cnt
]
=
box_a_corners
[
k
];
cnt
++
;
}
if
(
check_in_box2d
(
box_b
,
box_a_corners
[
k
]))
{
poly_center
=
poly_center
+
box_a_corners
[
k
];
cross_points
[
cnt
]
=
box_a_corners
[
k
];
cnt
++
;
}
poly_center
.
x
/=
cnt
;
poly_center
.
y
/=
cnt
;
// sort the points of polygon
Point
temp
;
for
(
int
j
=
0
;
j
<
cnt
-
1
;
j
++
){
for
(
int
i
=
0
;
i
<
cnt
-
j
-
1
;
i
++
){
if
(
point_cmp
(
cross_points
[
i
],
cross_points
[
i
+
1
],
poly_center
))
{
temp
=
cross_points
[
i
];
cross_points
[
i
]
=
cross_points
[
i
+
1
];
cross_points
[
i
+
1
]
=
temp
;
}
}
}
poly_center
.
x
/=
cnt
;
poly_center
.
y
/=
cnt
;
// sort the points of polygon
Point
temp
;
for
(
int
j
=
0
;
j
<
cnt
-
1
;
j
++
)
{
for
(
int
i
=
0
;
i
<
cnt
-
j
-
1
;
i
++
)
{
if
(
point_cmp
(
cross_points
[
i
],
cross_points
[
i
+
1
],
poly_center
))
{
temp
=
cross_points
[
i
];
cross_points
[
i
]
=
cross_points
[
i
+
1
];
cross_points
[
i
+
1
]
=
temp
;
}
}
}
#ifdef DEBUG
printf
(
"cnt=%d
\n
"
,
cnt
);
for
(
int
i
=
0
;
i
<
cnt
;
i
++
){
printf
(
"All cross point %d: (%.3f, %.3f)
\n
"
,
i
,
cross_points
[
i
].
x
,
cross_points
[
i
].
y
);
}
printf
(
"cnt=%d
\n
"
,
cnt
);
for
(
int
i
=
0
;
i
<
cnt
;
i
++
)
{
printf
(
"All cross point %d: (%.3f, %.3f)
\n
"
,
i
,
cross_points
[
i
].
x
,
cross_points
[
i
].
y
);
}
#endif
// get the overlap areas
float
area
=
0
;
for
(
int
k
=
0
;
k
<
cnt
-
1
;
k
++
){
area
+=
cross
(
cross_points
[
k
]
-
cross_points
[
0
],
cross_points
[
k
+
1
]
-
cross_points
[
0
]);
}
// get the overlap areas
float
area
=
0
;
for
(
int
k
=
0
;
k
<
cnt
-
1
;
k
++
)
{
area
+=
cross
(
cross_points
[
k
]
-
cross_points
[
0
],
cross_points
[
k
+
1
]
-
cross_points
[
0
]);
}
return
fabs
(
area
)
/
2.0
;
return
fabs
(
area
)
/
2.0
;
}
__device__
inline
float
iou_bev
(
const
float
*
box_a
,
const
float
*
box_b
){
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
sa
=
(
box_a
[
2
]
-
box_a
[
0
])
*
(
box_a
[
3
]
-
box_a
[
1
]);
float
sb
=
(
box_b
[
2
]
-
box_b
[
0
])
*
(
box_b
[
3
]
-
box_b
[
1
]);
float
s_overlap
=
box_overlap
(
box_a
,
box_b
);
return
s_overlap
/
fmaxf
(
sa
+
sb
-
s_overlap
,
EPS
);
__device__
inline
float
iou_bev
(
const
float
*
box_a
,
const
float
*
box_b
)
{
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
sa
=
(
box_a
[
2
]
-
box_a
[
0
])
*
(
box_a
[
3
]
-
box_a
[
1
]);
float
sb
=
(
box_b
[
2
]
-
box_b
[
0
])
*
(
box_b
[
3
]
-
box_b
[
1
]);
float
s_overlap
=
box_overlap
(
box_a
,
box_b
);
return
s_overlap
/
fmaxf
(
sa
+
sb
-
s_overlap
,
EPS
);
}
__global__
void
boxes_overlap_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
){
return
;
}
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
s_overlap
=
box_overlap
(
cur_box_a
,
cur_box_b
);
ans_overlap
[
a_idx
*
num_b
+
b_idx
]
=
s_overlap
;
__global__
void
boxes_overlap_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
)
{
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
)
{
return
;
}
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
s_overlap
=
box_overlap
(
cur_box_a
,
cur_box_b
);
ans_overlap
[
a_idx
*
num_b
+
b_idx
]
=
s_overlap
;
}
__global__
void
boxes_iou_bev_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
){
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
__global__
void
boxes_iou_bev_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
)
{
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
){
return
;
}
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
)
{
return
;
}
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
cur_iou_bev
=
iou_bev
(
cur_box_a
,
cur_box_b
);
ans_iou
[
a_idx
*
num_b
+
b_idx
]
=
cur_iou_bev
;
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
cur_iou_bev
=
iou_bev
(
cur_box_a
,
cur_box_b
);
ans_iou
[
a_idx
*
num_b
+
b_idx
]
=
cur_iou_bev
;
}
__global__
void
nms_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
){
//params: boxes (N, 5) [x1, y1, x2, y2, ry]
//params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
const
float
*
boxes
,
unsigned
long
long
*
mask
)
{
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_bev
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
){
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_bev
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
)
{
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
__device__
inline
float
iou_normal
(
float
const
*
const
a
,
float
const
*
const
b
)
{
float
left
=
fmaxf
(
a
[
0
],
b
[
0
]),
right
=
fminf
(
a
[
2
],
b
[
2
]);
float
top
=
fmaxf
(
a
[
1
],
b
[
1
]),
bottom
=
fminf
(
a
[
3
],
b
[
3
]);
float
width
=
fmaxf
(
right
-
left
,
0.
f
),
height
=
fmaxf
(
bottom
-
top
,
0.
f
);
float
interS
=
width
*
height
;
float
Sa
=
(
a
[
2
]
-
a
[
0
])
*
(
a
[
3
]
-
a
[
1
]);
float
Sb
=
(
b
[
2
]
-
b
[
0
])
*
(
b
[
3
]
-
b
[
1
]);
return
interS
/
fmaxf
(
Sa
+
Sb
-
interS
,
EPS
);
__device__
inline
float
iou_normal
(
float
const
*
const
a
,
float
const
*
const
b
)
{
float
left
=
fmaxf
(
a
[
0
],
b
[
0
]),
right
=
fminf
(
a
[
2
],
b
[
2
]);
float
top
=
fmaxf
(
a
[
1
],
b
[
1
]),
bottom
=
fminf
(
a
[
3
],
b
[
3
]);
float
width
=
fmaxf
(
right
-
left
,
0.
f
),
height
=
fmaxf
(
bottom
-
top
,
0.
f
);
float
interS
=
width
*
height
;
float
Sa
=
(
a
[
2
]
-
a
[
0
])
*
(
a
[
3
]
-
a
[
1
]);
float
Sb
=
(
b
[
2
]
-
b
[
0
])
*
(
b
[
3
]
-
b
[
1
]);
return
interS
/
fmaxf
(
Sa
+
Sb
-
interS
,
EPS
);
}
__global__
void
nms_normal_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
){
//params: boxes (N, 5) [x1, y1, x2, y2, ry]
//params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
__global__
void
nms_normal_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
)
{
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_normal
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
){
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_normal
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
)
{
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
void
boxesoverlapLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
)
{
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
void
boxesoverlapLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
boxes_overlap_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
boxes_overlap_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
void
boxesioubevLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
){
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
void
boxesioubevLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
)
{
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
boxes_iou_bev_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
boxes_iou_bev_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
}
void
nmsLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
){
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
void
nmsLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
}
void
nmsNormalLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
){
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_normal_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
void
nmsNormalLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_normal_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
}
mmdet3d/ops/roiaware_pool3d/__init__.py
View file @
f27d308f
from
.points_in_boxes
import
points_in_boxes_cpu
,
points_in_boxes_gpu
from
.points_in_boxes
import
(
points_in_boxes_batch
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
from
.roiaware_pool3d
import
RoIAwarePool3d
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
]
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
,
'points_in_boxes_batch'
]
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
View file @
f27d308f
...
...
@@ -53,3 +53,29 @@ def points_in_boxes_cpu(points, boxes):
point_indices
)
return
point_indices
def
points_in_boxes_batch
(
points
,
boxes
):
"""Find points that are in boxes (CUDA)
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0
"""
assert
boxes
.
shape
[
0
]
==
points
.
shape
[
0
]
assert
boxes
.
shape
[
2
]
==
7
batch_size
,
num_points
,
_
=
points
.
shape
num_boxes
=
boxes
.
shape
[
1
]
box_idxs_of_pts
=
points
.
new_zeros
((
batch_size
,
num_points
,
num_boxes
),
dtype
=
torch
.
int
).
fill_
(
0
)
roiaware_pool3d_ext
.
points_in_boxes_batch
(
boxes
.
contiguous
(),
points
.
contiguous
(),
box_idxs_of_pts
)
return
box_idxs_of_pts
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
View file @
f27d308f
...
...
@@ -77,6 +77,34 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num,
}
}
__global__
void
points_in_boxes_batch_kernel
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
batch_size
||
pt_idx
>=
pts_num
)
return
;
boxes
+=
bs_idx
*
boxes_num
*
7
;
pts
+=
bs_idx
*
pts_num
*
3
+
pt_idx
*
3
;
box_idx_of_points
+=
bs_idx
*
pts_num
*
boxes_num
+
pt_idx
*
boxes_num
;
float
local_x
=
0
,
local_y
=
0
;
int
cur_in_flag
=
0
;
for
(
int
k
=
0
;
k
<
boxes_num
;
k
++
)
{
cur_in_flag
=
check_pt_in_box3d
(
pts
,
boxes
+
k
*
7
,
local_x
,
local_y
);
if
(
cur_in_flag
)
{
box_idx_of_points
[
k
]
=
1
;
}
cur_in_flag
=
0
;
}
}
void
points_in_boxes_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
...
...
@@ -102,6 +130,30 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
#endif
}
void
points_in_boxes_batch_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in
// LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
pts_num
,
THREADS_PER_BLOCK
),
batch_size
);
dim3
threads
(
THREADS_PER_BLOCK
);
points_in_boxes_batch_kernel
<<<
blocks
,
threads
>>>
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
...
...
@@ -126,3 +178,27 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
return
1
;
}
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center. params pts: (B, npoints, 3) [x, y, z] in LiDAR
// coordinate params boxes_idx_of_points: (B, npoints), default -1
CHECK_INPUT
(
boxes_tensor
);
CHECK_INPUT
(
pts_tensor
);
CHECK_INPUT
(
box_idx_of_points_tensor
);
int
batch_size
=
boxes_tensor
.
size
(
0
);
int
boxes_num
=
boxes_tensor
.
size
(
1
);
int
pts_num
=
pts_tensor
.
size
(
1
);
const
float
*
boxes
=
boxes_tensor
.
data_ptr
<
float
>
();
const
float
*
pts
=
pts_tensor
.
data_ptr
<
float
>
();
int
*
box_idx_of_points
=
box_idx_of_points_tensor
.
data_ptr
<
int
>
();
points_in_boxes_batch_launcher
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
return
1
;
}
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
View file @
f27d308f
...
...
@@ -44,6 +44,9 @@ int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
pooled_features
,
int
pool_method
)
{
...
...
@@ -127,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"roiaware pool3d backward (CUDA)"
);
m
.
def
(
"points_in_boxes_gpu"
,
&
points_in_boxes_gpu
,
"points_in_boxes_gpu forward (CUDA)"
);
m
.
def
(
"points_in_boxes_batch"
,
&
points_in_boxes_batch
,
"points_in_boxes_batch forward (CUDA)"
);
m
.
def
(
"points_in_boxes_cpu"
,
&
points_in_boxes_cpu
,
"points_in_boxes_cpu forward (CPU)"
);
}
mmdet3d/ops/sparse_block.py
View file @
f27d308f
...
...
@@ -6,6 +6,21 @@ from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
class
SparseBottleneck
(
Bottleneck
,
spconv
.
SparseModule
):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
stride (int): stride of the first block. Default: 1
downsample (None | Module): down sample module for block.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
expansion
=
4
def
__init__
(
self
,
...
...
@@ -15,10 +30,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
downsample
=
None
,
conv_cfg
=
None
,
norm_cfg
=
None
):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
"""
spconv
.
SparseModule
.
__init__
(
self
)
Bottleneck
.
__init__
(
self
,
...
...
@@ -53,6 +65,21 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
class
SparseBasicBlock
(
BasicBlock
,
spconv
.
SparseModule
):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
stride (int): stride of the first block. Default: 1
downsample (None | Module): down sample module for block.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
expansion
=
1
def
__init__
(
self
,
...
...
@@ -62,10 +89,6 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
downsample
=
None
,
conv_cfg
=
None
,
norm_cfg
=
None
):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
"""
spconv
.
SparseModule
.
__init__
(
self
)
BasicBlock
.
__init__
(
self
,
...
...
@@ -125,6 +148,7 @@ def make_sparse_convmodule(in_channels,
spconv.SparseSequential: sparse convolution module.
"""
assert
isinstance
(
order
,
tuple
)
and
len
(
order
)
<=
3
assert
set
(
order
)
|
{
'conv'
,
'norm'
,
'act'
}
==
{
'conv'
,
'norm'
,
'act'
}
conv_cfg
=
dict
(
type
=
conv_type
,
indice_key
=
indice_key
)
...
...
mmdet3d/ops/spconv/include/paramsgrid.h
View file @
f27d308f
...
...
@@ -18,13 +18,19 @@
#include <vector>
namespace
detail
{
template
<
class
T
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
class
T
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
class
T
,
class
...
TArgs
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
,
std
::
vector
<
TArgs
>
...
args
)
{
return
arg
.
size
()
*
getTotalSize
(
args
...);
}
template
<
typename
T
>
int
getSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
typename
T
>
int
getSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
int
Idx
,
class
TT
,
class
T
>
void
assigner
(
TT
&
src
,
std
::
vector
<
int
>
counter
,
std
::
vector
<
T
>
&
arg
)
{
...
...
@@ -37,7 +43,7 @@ void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg,
std
::
get
<
Idx
>
(
src
)
=
arg
[
counter
[
Idx
]];
assigner
<
Idx
+
1
>
(
src
,
counter
,
args
...);
}
}
// namespace detail
}
// namespace detail
template
<
class
...
TArgs
>
std
::
vector
<
std
::
tuple
<
TArgs
...
>>
paramsGrid
(
std
::
vector
<
TArgs
>
...
args
)
{
int
length
=
detail
::
getTotalSize
(
args
...);
...
...
mmdet3d/ops/spconv/include/prettyprint.h
View file @
f27d308f
...
...
@@ -22,424 +22,472 @@
#include <utility>
#include <valarray>
namespace
pretty_print
{
namespace
detail
{
// SFINAE type trait to detect whether T::const_iterator exists.
struct
sfinae_base
{
using
yes
=
char
;
using
no
=
yes
[
2
];
};
template
<
typename
T
>
struct
has_const_iterator
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
test
(
typename
C
::
const_iterator
*
);
template
<
typename
C
>
static
no
&
test
(...);
public:
static
const
bool
value
=
sizeof
(
test
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
using
type
=
T
;
};
template
<
typename
T
>
struct
has_begin_end
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
f
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
begin
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
>::
type
*
);
template
<
typename
C
>
static
no
&
f
(...);
template
<
typename
C
>
static
yes
&
g
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
end
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
,
void
>::
type
*
);
template
<
typename
C
>
static
no
&
g
(...);
public:
static
bool
const
beg_value
=
sizeof
(
f
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
bool
const
end_value
=
sizeof
(
g
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
};
}
// namespace detail
// Holds the delimiter values for a specific character type
template
<
typename
TChar
>
struct
delimiters_values
{
using
char_type
=
TChar
;
const
char_type
*
prefix
;
const
char_type
*
delimiter
;
const
char_type
*
postfix
;
};
// Defines the delimiter values for a specific container and character type
template
<
typename
T
,
typename
TChar
>
struct
delimiters
{
using
type
=
delimiters_values
<
TChar
>
;
static
const
type
values
;
};
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template
<
typename
T
,
typename
TChar
=
char
,
typename
TCharTraits
=
::
std
::
char_traits
<
TChar
>,
typename
TDelimiters
=
delimiters
<
T
,
TChar
>>
struct
print_container_helper
{
using
delimiters_type
=
TDelimiters
;
using
ostream_type
=
std
::
basic_ostream
<
TChar
,
TCharTraits
>
;
template
<
typename
U
>
struct
printer
{
static
void
print_body
(
const
U
&
c
,
ostream_type
&
stream
)
{
using
std
::
begin
;
using
std
::
end
;
auto
it
=
begin
(
c
);
const
auto
the_end
=
end
(
c
);
if
(
it
!=
the_end
)
{
for
(
;
;
)
{
stream
<<
*
it
;
if
(
++
it
==
the_end
)
break
;
if
(
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
delimiters_type
::
values
.
delimiter
;
}
}
}
};
print_container_helper
(
const
T
&
container
)
:
container_
(
container
)
{
}
inline
void
operator
()(
ostream_type
&
stream
)
const
{
if
(
delimiters_type
::
values
.
prefix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
prefix
;
printer
<
T
>::
print_body
(
container_
,
stream
);
if
(
delimiters_type
::
values
.
postfix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
postfix
;
namespace
pretty_print
{
namespace
detail
{
// SFINAE type trait to detect whether T::const_iterator exists.
struct
sfinae_base
{
using
yes
=
char
;
using
no
=
yes
[
2
];
};
template
<
typename
T
>
struct
has_const_iterator
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
test
(
typename
C
::
const_iterator
*
);
template
<
typename
C
>
static
no
&
test
(...);
public:
static
const
bool
value
=
sizeof
(
test
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
using
type
=
T
;
};
template
<
typename
T
>
struct
has_begin_end
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
f
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
begin
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
>::
type
*
);
template
<
typename
C
>
static
no
&
f
(...);
template
<
typename
C
>
static
yes
&
g
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
end
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
,
void
>::
type
*
);
template
<
typename
C
>
static
no
&
g
(...);
public:
static
bool
const
beg_value
=
sizeof
(
f
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
bool
const
end_value
=
sizeof
(
g
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
};
}
// namespace detail
// Holds the delimiter values for a specific character type
template
<
typename
TChar
>
struct
delimiters_values
{
using
char_type
=
TChar
;
const
char_type
*
prefix
;
const
char_type
*
delimiter
;
const
char_type
*
postfix
;
};
// Defines the delimiter values for a specific container and character type
template
<
typename
T
,
typename
TChar
>
struct
delimiters
{
using
type
=
delimiters_values
<
TChar
>
;
static
const
type
values
;
};
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template
<
typename
T
,
typename
TChar
=
char
,
typename
TCharTraits
=
::
std
::
char_traits
<
TChar
>,
typename
TDelimiters
=
delimiters
<
T
,
TChar
>>
struct
print_container_helper
{
using
delimiters_type
=
TDelimiters
;
using
ostream_type
=
std
::
basic_ostream
<
TChar
,
TCharTraits
>
;
template
<
typename
U
>
struct
printer
{
static
void
print_body
(
const
U
&
c
,
ostream_type
&
stream
)
{
using
std
::
begin
;
using
std
::
end
;
auto
it
=
begin
(
c
);
const
auto
the_end
=
end
(
c
);
if
(
it
!=
the_end
)
{
for
(;;)
{
stream
<<
*
it
;
if
(
++
it
==
the_end
)
break
;
if
(
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
delimiters_type
::
values
.
delimiter
;
}
private:
const
T
&
container_
;
};
// Specialization for pairs
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
T1
,
typename
T2
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
pair
<
T1
,
T2
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
static
void
print_body
(
const
std
::
pair
<
T1
,
T2
>
&
c
,
ostream_type
&
stream
)
{
stream
<<
c
.
first
;
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
c
.
second
;
}
};
// Specialization for tuples
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
...
Args
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
tuple
<
Args
...
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
using
element_type
=
std
::
tuple
<
Args
...
>
;
template
<
std
::
size_t
I
>
struct
Int
{
};
static
void
print_body
(
const
element_type
&
c
,
ostream_type
&
stream
)
{
tuple_print
(
c
,
stream
,
Int
<
0
>
());
}
static
void
tuple_print
(
const
element_type
&
,
ostream_type
&
,
Int
<
sizeof
...(
Args
)
>
)
{
}
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
typename
std
::
conditional
<
sizeof
...(
Args
)
!=
0
,
Int
<
0
>
,
std
::
nullptr_t
>::
type
)
{
stream
<<
std
::
get
<
0
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
1
>
());
}
template
<
std
::
size_t
N
>
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
Int
<
N
>
)
{
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
std
::
get
<
N
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
N
+
1
>
());
}
};
// Prints a print_container_helper to the specified stream.
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>
&
helper
)
{
helper
(
stream
);
return
stream
;
}
// Basic is_container template; specialize to derive from std::true_type for all desired container types
template
<
typename
T
>
struct
is_container
:
public
std
::
integral_constant
<
bool
,
detail
::
has_const_iterator
<
T
>::
value
&&
detail
::
has_begin_end
<
T
>::
beg_value
&&
detail
::
has_begin_end
<
T
>::
end_value
>
{
};
template
<
typename
T
,
std
::
size_t
N
>
struct
is_container
<
T
[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
is_container
<
char
[
N
]
>
:
std
::
false_type
{
};
template
<
typename
T
>
struct
is_container
<
std
::
valarray
<
T
>>
:
std
::
true_type
{
};
template
<
typename
T1
,
typename
T2
>
struct
is_container
<
std
::
pair
<
T1
,
T2
>>
:
std
::
true_type
{
};
template
<
typename
...
Args
>
struct
is_container
<
std
::
tuple
<
Args
...
>>
:
std
::
true_type
{
};
// Default delimiters
template
<
typename
T
>
struct
delimiters
<
T
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
char
>
delimiters
<
T
,
char
>::
values
=
{
"["
,
", "
,
"]"
};
template
<
typename
T
>
struct
delimiters
<
T
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
wchar_t
>
delimiters
<
T
,
wchar_t
>::
values
=
{
L"["
,
L", "
,
L"]"
};
// Delimiters for (multi)set and unordered_(multi)set
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
// Delimiters for pair and tuple
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
char
>
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
template
<
typename
...
Args
>
struct
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
char
>
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
...
Args
>
struct
delimiters
<
::
std
::
tuple
<
Args
...
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
tuple
<
Args
...
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t, and MyDelims needs to be defined for TChar.
// Usage: "cout << pretty_print::custom_delims<MyDelims>(x)".
struct
custom_delims_base
{
virtual
~
custom_delims_base
()
{
}
virtual
std
::
ostream
&
stream
(
::
std
::
ostream
&
)
=
0
;
virtual
std
::
wostream
&
stream
(
::
std
::
wostream
&
)
=
0
;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{
}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
template
<
typename
Delims
>
struct
custom_delims
{
template
<
typename
Container
>
custom_delims
(
const
Container
&
c
)
:
base
(
new
custom_delims_wrapper
<
Container
,
Delims
>
(
c
))
{
}
std
::
unique_ptr
<
custom_delims_base
>
base
;
};
template
<
typename
TChar
,
typename
TCharTraits
,
typename
Delims
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
s
,
const
custom_delims
<
Delims
>
&
p
)
{
return
p
.
base
->
stream
(
s
);
}
}
};
print_container_helper
(
const
T
&
container
)
:
container_
(
container
)
{}
inline
void
operator
()(
ostream_type
&
stream
)
const
{
if
(
delimiters_type
::
values
.
prefix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
prefix
;
printer
<
T
>::
print_body
(
container_
,
stream
);
if
(
delimiters_type
::
values
.
postfix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
postfix
;
}
private:
const
T
&
container_
;
};
// Specialization for pairs
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
T1
,
typename
T2
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
pair
<
T1
,
T2
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
static
void
print_body
(
const
std
::
pair
<
T1
,
T2
>
&
c
,
ostream_type
&
stream
)
{
stream
<<
c
.
first
;
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
c
.
second
;
}
};
// Specialization for tuples
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
...
Args
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
tuple
<
Args
...
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
using
element_type
=
std
::
tuple
<
Args
...
>
;
template
<
std
::
size_t
I
>
struct
Int
{};
static
void
print_body
(
const
element_type
&
c
,
ostream_type
&
stream
)
{
tuple_print
(
c
,
stream
,
Int
<
0
>
());
}
static
void
tuple_print
(
const
element_type
&
,
ostream_type
&
,
Int
<
sizeof
...(
Args
)
>
)
{}
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
typename
std
::
conditional
<
sizeof
...(
Args
)
!=
0
,
Int
<
0
>
,
std
::
nullptr_t
>::
type
)
{
stream
<<
std
::
get
<
0
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
1
>
());
}
template
<
std
::
size_t
N
>
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
Int
<
N
>
)
{
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
std
::
get
<
N
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
N
+
1
>
());
}
};
// Prints a print_container_helper to the specified stream.
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>
&
helper
)
{
helper
(
stream
);
return
stream
;
}
// Basic is_container template; specialize to derive from std::true_type for all
// desired container types
template
<
typename
T
>
struct
is_container
:
public
std
::
integral_constant
<
bool
,
detail
::
has_const_iterator
<
T
>::
value
&&
detail
::
has_begin_end
<
T
>::
beg_value
&&
detail
::
has_begin_end
<
T
>::
end_value
>
{};
template
<
typename
T
,
std
::
size_t
N
>
struct
is_container
<
T
[
N
]
>
:
std
::
true_type
{};
template
<
std
::
size_t
N
>
struct
is_container
<
char
[
N
]
>
:
std
::
false_type
{};
template
<
typename
T
>
struct
is_container
<
std
::
valarray
<
T
>>
:
std
::
true_type
{};
template
<
typename
T1
,
typename
T2
>
struct
is_container
<
std
::
pair
<
T1
,
T2
>>
:
std
::
true_type
{};
template
<
typename
...
Args
>
struct
is_container
<
std
::
tuple
<
Args
...
>>
:
std
::
true_type
{};
// Default delimiters
template
<
typename
T
>
struct
delimiters
<
T
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
char
>
delimiters
<
T
,
char
>::
values
=
{
"["
,
", "
,
"]"
};
template
<
typename
T
>
struct
delimiters
<
T
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
wchar_t
>
delimiters
<
T
,
wchar_t
>::
values
=
{
L"["
,
L", "
,
L"]"
};
// Delimiters for (multi)set and unordered_(multi)set
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
// Delimiters for pair and tuple
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
char
>
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
template
<
typename
...
Args
>
struct
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
char
>
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
...
Args
>
struct
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t,
// and MyDelims needs to be defined for TChar. Usage: "cout <<
// pretty_print::custom_delims<MyDelims>(x)".
struct
custom_delims_base
{
virtual
~
custom_delims_base
()
{}
virtual
std
::
ostream
&
stream
(
::
std
::
ostream
&
)
=
0
;
virtual
std
::
wostream
&
stream
(
::
std
::
wostream
&
)
=
0
;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
template
<
typename
Delims
>
struct
custom_delims
{
template
<
typename
Container
>
custom_delims
(
const
Container
&
c
)
:
base
(
new
custom_delims_wrapper
<
Container
,
Delims
>
(
c
))
{}
std
::
unique_ptr
<
custom_delims_base
>
base
;
};
template
<
typename
TChar
,
typename
TCharTraits
,
typename
Delims
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
s
,
const
custom_delims
<
Delims
>
&
p
)
{
return
p
.
base
->
stream
(
s
);
}
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template
<
typename
T
>
struct
array_wrapper_n
{
typedef
const
T
*
const_iterator
;
typedef
T
value_type
;
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
array_wrapper_n
(
const
T
*
const
a
,
size_t
n
)
:
_array
(
a
),
_n
(
n
)
{
}
inline
const_iterator
begin
()
const
{
return
_array
;
}
inline
const_iterator
end
()
const
{
return
_array
+
_n
;
}
template
<
typename
T
>
struct
array_wrapper_n
{
typedef
const
T
*
const_iterator
;
typedef
T
value_type
;
private:
const
T
*
const
_array
;
size_t
_n
;
};
array_wrapper_n
(
const
T
*
const
a
,
size_t
n
)
:
_array
(
a
),
_n
(
n
)
{}
inline
const_iterator
begin
()
const
{
return
_array
;
}
inline
const_iterator
end
()
const
{
return
_array
+
_n
;
}
private:
const
T
*
const
_array
;
size_t
_n
;
};
// A wrapper for hash-table based containers that offer local iterators to each bucket.
// Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket 5 of container m.)
// A wrapper for hash-table based containers that offer local iterators to each
// bucket. Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket
// 5 of container m.)
template
<
typename
T
>
struct
bucket_print_wrapper
{
typedef
typename
T
::
const_local_iterator
const_iterator
;
typedef
typename
T
::
size_type
size_type
;
template
<
typename
T
>
struct
bucket_print_wrapper
{
typedef
typename
T
::
const_local_iterator
const_iterator
;
typedef
typename
T
::
size_type
size_type
;
const_iterator
begin
()
const
{
return
m_map
.
cbegin
(
n
);
}
const_iterator
end
()
const
{
return
m_map
.
cend
(
n
);
}
const_iterator
begin
()
const
{
return
m_map
.
cbegin
(
n
);
}
bucket_print_wrapper
(
const
T
&
m
,
size_type
bucket
)
:
m_map
(
m
),
n
(
bucket
)
{
}
const_iterator
end
()
const
{
return
m_map
.
cend
(
n
);
}
private:
const
T
&
m_map
;
const
size_type
n
;
};
bucket_print_wrapper
(
const
T
&
m
,
size_type
bucket
)
:
m_map
(
m
),
n
(
bucket
)
{}
}
// namespace pretty_print
private:
const
T
&
m_map
;
const
size_type
n
;
};
}
// namespace pretty_print
// Global accessor functions for the convenience wrappers
template
<
typename
T
>
inline
pretty_print
::
array_wrapper_n
<
T
>
pretty_print_array
(
const
T
*
const
a
,
size_t
n
)
{
return
pretty_print
::
array_wrapper_n
<
T
>
(
a
,
n
);
template
<
typename
T
>
inline
pretty_print
::
array_wrapper_n
<
T
>
pretty_print_array
(
const
T
*
const
a
,
size_t
n
)
{
return
pretty_print
::
array_wrapper_n
<
T
>
(
a
,
n
);
}
template
<
typename
T
>
pretty_print
::
bucket_print_wrapper
<
T
>
bucket_print
(
const
T
&
m
,
typename
T
::
size_type
n
)
{
return
pretty_print
::
bucket_print_wrapper
<
T
>
(
m
,
n
);
template
<
typename
T
>
pretty_print
::
bucket_print_wrapper
<
T
>
bucket_print
(
const
T
&
m
,
typename
T
::
size_type
n
)
{
return
pretty_print
::
bucket_print_wrapper
<
T
>
(
m
,
n
);
}
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
namespace
std
{
// Prints a container to the stream using default delimiters
namespace
std
{
// Prints a container to the stream using default delimiters
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
>
inline
typename
enable_if
<
::
pretty_print
::
is_container
<
T
>::
value
,
basic_ostream
<
TChar
,
TCharTraits
>
&>::
type
operator
<<
(
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
T
&
container
)
{
return
stream
<<
::
pretty_print
::
print_container_helper
<
T
,
TChar
,
TCharTraits
>
(
container
);
}
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
>
inline
typename
enable_if
<::
pretty_print
::
is_container
<
T
>::
value
,
basic_ostream
<
TChar
,
TCharTraits
>
&>::
type
operator
<<
(
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
T
&
container
)
{
return
stream
<<
::
pretty_print
::
print_container_helper
<
T
,
TChar
,
TCharTraits
>
(
container
);
}
}
// namespace std
#endif // H_PRETTY_PRINT
#endif // H_PRETTY_PRINT
mmdet3d/ops/spconv/include/spconv/box_iou.h
View file @
f27d308f
...
...
@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BOX_IOU_H
#define BOX_IOU_H
#include <pybind11/pybind11.h>
// must include pybind11/eigen.h if using eigen matrix as arguments.
#include <pybind11/numpy.h>
#include <algorithm>
#include <boost/geometry.hpp>
#include <pybind11/numpy.h>
namespace
spconv
{
// #include "voxelnet/core/cc/pybind11_helper.h"
...
...
@@ -40,9 +40,10 @@ inline py::array_t<DType> zeros(std::vector<long int> shape) {
}
template
<
typename
DType
>
py
::
array_t
<
DType
>
rbbox_iou
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
py
::
array_t
<
DType
>
rbbox_iou
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
namespace
bg
=
boost
::
geometry
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
...
...
@@ -61,8 +62,7 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
}
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
...
...
@@ -99,9 +99,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
}
template
<
typename
DType
>
py
::
array_t
<
DType
>
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
py
::
array_t
<
DType
>
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
namespace
bg
=
boost
::
geometry
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
...
...
@@ -120,8 +121,7 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
}
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
...
...
@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
return
overlaps
;
}
}
// namespace spconv
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/geometry.h
View file @
f27d308f
...
...
@@ -15,9 +15,10 @@
#ifndef SPCONV_GEOMETRY_H_
#define SPCONV_GEOMETRY_H_
#include <tensorview/tensorview.h>
#include <iostream>
#include <limits>
#include <tensorview/tensorview.h>
namespace
spconv
{
template
<
typename
Index
,
unsigned
NDim
>
...
...
@@ -70,8 +71,7 @@ TV_HOST_DEVICE Index getValidOutPos(const Index *input_pos,
}
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
if
(
valid
)
++
pointCounter
;
if
(
valid
)
++
pointCounter
;
counter
[
NDim
-
1
]
+=
1
;
#pragma unroll
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
...
...
@@ -128,8 +128,7 @@ TV_HOST_DEVICE Index getValidOutPosTranspose(
m
*=
kernelSize
[
j
];
}
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
if
(
valid
)
++
pointCounter
;
if
(
valid
)
++
pointCounter
;
counter
[
NDim
-
1
]
+=
1
;
#pragma unroll
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
...
...
@@ -167,7 +166,7 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn,
}
Index
numValidPoints
=
0
;
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
batchIdx
=
indicesIn
(
j
,
0
);
...
...
@@ -218,7 +217,7 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn,
}
Index
numValidPoints
=
0
;
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
batchIdx
=
indicesIn
(
j
,
0
);
...
...
@@ -252,7 +251,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
tv
::
TensorView
<
Index
>
indiceNum
,
const
Index
*
const
kernelSize
,
const
Index
*
const
stride
,
const
Index
*
const
padding
,
const
Index
*
dilation
,
const
Index
*
const
outSpatialShape
)
{
const
Index
*
dilation
,
const
Index
*
const
outSpatialShape
)
{
Index
numAct
=
0
;
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
batchIdx
=
0
;
...
...
@@ -269,7 +269,7 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
Index
numValidPoints
=
0
;
// Index validPoints[kernelVolume * (NDim + 1)];
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
Index
index
=
0
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
...
...
@@ -296,6 +296,6 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
return
numActIn
;
}
}
// namespace spconv
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/indice.cu.h
View file @
f27d308f
...
...
@@ -14,9 +14,9 @@
#ifndef INDICE_CU_H_
#define INDICE_CU_H_
#include <tensorview/tensorview.h>
#include <tensorview/helper_kernel.cu.h>
#include <spconv/geometry.h>
#include <tensorview/helper_kernel.cu.h>
#include <tensorview/tensorview.h>
namespace
spconv
{
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
...
...
@@ -115,7 +115,6 @@ __global__ void assignGridAndIndiceOutKernel(
int
numAct
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
batchSize
)
{
Index
index
;
auto
indicesOutPtr
=
indicesOut
.
data
();
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numAct
))
{
...
...
@@ -128,13 +127,11 @@ __global__ void assignGridAndIndiceOutKernel(
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
assignIndicePairsKernel
(
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
int
numActIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
__global__
void
assignIndicePairsKernel
(
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
int
numActIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
Index
index
;
int
kernelVolume
=
indicePairs
.
dim
(
0
);
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numActIn
))
{
...
...
@@ -148,10 +145,9 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut,
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
prepareSubMGridKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
__global__
void
prepareSubMGridKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
spatialVolume
=
1
;
#pragma unroll
...
...
@@ -216,10 +212,9 @@ __global__ void resetGridKernel(const Index *indicePairUnique,
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
resetGridSubMKernel
(
const
Index
*
indices
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
numAct
)
{
__global__
void
resetGridSubMKernel
(
const
Index
*
indices
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
numAct
)
{
int
outSpatialShapeReg
[
NDim
];
for
(
int
i
=
0
;
i
<
NDim
;
++
i
)
{
outSpatialShapeReg
[
i
]
=
outSpatialShape
[
i
];
...
...
@@ -238,6 +233,6 @@ resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut,
}
}
}
// namespace spconv
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/indice.h
View file @
f27d308f
...
...
@@ -16,64 +16,65 @@
#define SPARSE_CONV_INDICE_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
functor
{
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP1
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
struct
CreateConvIndicePairFunctorP1
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP2
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indice
sOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
struct
CreateConvIndicePairFunctorP2
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indice
Pairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
struct
CreateConvIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateSubMIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
struct
CreateSubMIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
}
// namespace functor
}
// namespace spconv
}
// namespace functor
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/maxpool.h
View file @
f27d308f
...
...
@@ -16,29 +16,24 @@
#define SPARSE_MAXPOOL_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
functor
{
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolForwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
struct
SparseMaxPoolForwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolBackwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
struct
SparseMaxPoolBackwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
}
// namespace functor
}
// namespace spconv
}
// namespace functor
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/mp_helper.h
View file @
f27d308f
...
...
@@ -4,7 +4,8 @@
#include <utility>
namespace
spconv
{
template
<
class
...
T
>
struct
mp_list
{};
template
<
class
...
T
>
struct
mp_list
{};
template
<
class
T
,
T
...
I
>
using
mp_list_c
=
mp_list
<
std
::
integral_constant
<
T
,
I
>
...
>
;
...
...
@@ -16,15 +17,17 @@ constexpr F mp_for_each_impl(mp_list<T...>, F &&f) {
return
std
::
initializer_list
<
int
>
{(
f
(
T
()),
0
)...},
std
::
forward
<
F
>
(
f
);
}
template
<
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<>
,
F
&&
f
)
{
template
<
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<>
,
F
&&
f
)
{
return
std
::
forward
<
F
>
(
f
);
}
}
// namespace detail
}
// namespace detail
namespace
detail
{
template
<
class
A
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
{
template
<
class
A
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
{
// An error "no type named 'type'" here means that the first argument to
// mp_rename is not a list
};
...
...
@@ -34,14 +37,15 @@ struct mp_rename_impl<A<T...>, B> {
using
type
=
B
<
T
...
>
;
};
}
// namespace detail
}
// namespace detail
template
<
class
A
,
template
<
class
...
>
class
B
>
using
mp_rename
=
typename
detail
::
mp_rename_impl
<
A
,
B
>::
type
;
template
<
class
L
,
class
F
>
constexpr
F
mp_for_each
(
F
&&
f
)
{
template
<
class
L
,
class
F
>
constexpr
F
mp_for_each
(
F
&&
f
)
{
return
detail
::
mp_for_each_impl
(
mp_rename
<
L
,
mp_list
>
(),
std
::
forward
<
F
>
(
f
));
}
}
// namespace spconv
}
// namespace spconv
#endif
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment