Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
f27d308f
Commit
f27d308f
authored
Jun 07, 2020
by
yinchimaoliang
Browse files
merge master
parents
c66ae813
27ebcfac
Changes
80
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1457 additions
and
1163 deletions
+1457
-1163
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
+78
-68
mmdet3d/ops/group_points/src/group_points.cpp
mmdet3d/ops/group_points/src/group_points.cpp
+34
-28
mmdet3d/ops/group_points/src/group_points_cuda.cu
mmdet3d/ops/group_points/src/group_points_cuda.cu
+77
-64
mmdet3d/ops/interpolate/src/interpolate.cpp
mmdet3d/ops/interpolate/src/interpolate.cpp
+62
-53
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
+92
-80
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
+68
-56
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
+345
-296
mmdet3d/ops/roiaware_pool3d/__init__.py
mmdet3d/ops/roiaware_pool3d/__init__.py
+6
-2
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
+26
-0
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
+76
-0
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
+5
-0
mmdet3d/ops/sparse_block.py
mmdet3d/ops/sparse_block.py
+31
-7
mmdet3d/ops/spconv/include/paramsgrid.h
mmdet3d/ops/spconv/include/paramsgrid.h
+9
-3
mmdet3d/ops/spconv/include/prettyprint.h
mmdet3d/ops/spconv/include/prettyprint.h
+442
-394
mmdet3d/ops/spconv/include/spconv/box_iou.h
mmdet3d/ops/spconv/include/spconv/box_iou.h
+13
-14
mmdet3d/ops/spconv/include/spconv/geometry.h
mmdet3d/ops/spconv/include/spconv/geometry.h
+10
-10
mmdet3d/ops/spconv/include/spconv/indice.cu.h
mmdet3d/ops/spconv/include/spconv/indice.cu.h
+14
-19
mmdet3d/ops/spconv/include/spconv/indice.h
mmdet3d/ops/spconv/include/spconv/indice.h
+49
-48
mmdet3d/ops/spconv/include/spconv/maxpool.h
mmdet3d/ops/spconv/include/spconv/maxpool.h
+9
-14
mmdet3d/ops/spconv/include/spconv/mp_helper.h
mmdet3d/ops/spconv/include/spconv/mp_helper.h
+11
-7
No files found.
mmdet3d/ops/gather_points/src/gather_points_cuda.cu
View file @
f27d308f
...
...
@@ -3,82 +3,92 @@
#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
gather_points_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, M)
// output:
// out: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
0
]
=
points
[
idx
[
0
]];
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, M)
// output:
// out: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
0
]
=
points
[
idx
[
0
]];
}
void
gather_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
// out: (B, C, npoints)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints)
// output:
// out: (B, C, npoints)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
points
,
idx
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
gather_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, M)
// idx: (B, M)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
grad_out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
grad_points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]);
__global__
void
gather_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, M)
// idx: (B, M)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
m
)
return
;
grad_out
+=
bs_idx
*
c
*
m
+
c_idx
*
m
+
pt_idx
;
idx
+=
bs_idx
*
m
+
pt_idx
;
grad_points
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]);
}
void
gather_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
gather_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/group_points/src/group_points.cpp
View file @
f27d308f
#include <
torch/serialize/tensor
.h>
#include <
THC/THC
.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern
THCState
*
state
;
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
);
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
);
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
);
int
group_points_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
float
*
grad_points
=
grad_points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
<
float
>
();
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
grad_points_tensor
)
{
float
*
grad_points
=
grad_points_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
_ptr
<
int
>
();
const
float
*
grad_out
=
grad_out_tensor
.
data
_ptr
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_grad_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
,
stream
);
return
1
;
group_points_grad_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
,
stream
);
return
1
;
}
int
group_points_wrapper
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
_ptr
<
int
>
();
float
*
out
=
out_tensor
.
data
_ptr
<
float
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
group_points_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
stream
);
return
1
;
group_points_kernel_launcher
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
,
stream
);
return
1
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
group_points_wrapper
,
"group_points_wrapper"
);
m
.
def
(
"backward"
,
&
group_points_grad_wrapper
,
"group_points_grad_wrapper"
);
m
.
def
(
"forward"
,
&
group_points_wrapper
,
"group_points_wrapper"
);
m
.
def
(
"backward"
,
&
group_points_grad_wrapper
,
"group_points_grad_wrapper"
);
}
mmdet3d/ops/group_points/src/group_points_cuda.cu
View file @
f27d308f
...
...
@@ -2,84 +2,97 @@
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m,
n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
group_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
__global__
void
group_points_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
grad_out
+=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
sample_idx
=
index
%
nsample
;
grad_out
+=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
atomicAdd
(
grad_points
+
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
]
,
grad_out
[
0
]);
atomicAdd
(
grad_points
+
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
],
grad_out
[
0
]);
}
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
void
group_points_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
grad_out
,
const
int
*
idx
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
group_points_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
grad_out
,
idx
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
group_points_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
__global__
void
group_points_kernel
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
float
*
__restrict__
out
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
pt_idx
=
index
/
nsample
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
npoints
)
return
;
int
sample_idx
=
index
%
nsample
;
int
sample_idx
=
index
%
nsample
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
in_idx
=
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
];
int
out_idx
=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
idx
+=
bs_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
int
in_idx
=
bs_idx
*
c
*
n
+
c_idx
*
n
+
idx
[
0
];
int
out_idx
=
bs_idx
*
c
*
npoints
*
nsample
+
c_idx
*
npoints
*
nsample
+
pt_idx
*
nsample
+
sample_idx
;
out
[
out_idx
]
=
points
[
in_idx
];
out
[
out_idx
]
=
points
[
in_idx
];
}
void
group_points_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
npoints
,
int
nsample
,
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
const
float
*
points
,
const
int
*
idx
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
npoints
*
nsample
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
group_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
group_points_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
npoints
,
nsample
,
points
,
idx
,
out
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/interpolate/src/interpolate.cpp
View file @
f27d308f
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
extern
THCState
*
state
;
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
);
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
);
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points
_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx
_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
);
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
);
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
);
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
);
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
);
void
three_nn_wrapper
(
int
b
,
int
n
,
int
m
,
at
::
Tensor
unknown_tensor
,
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
)
{
const
float
*
unknown
=
unknown_tensor
.
data
<
float
>
();
const
float
*
known
=
known_tensor
.
data
<
float
>
();
float
*
dist2
=
dist2_tensor
.
data
<
float
>
();
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
at
::
Tensor
known_tensor
,
at
::
Tensor
dist2_tensor
,
at
::
Tensor
idx_tensor
)
{
const
float
*
unknown
=
unknown_tensor
.
data_ptr
<
float
>
();
const
float
*
known
=
known_tensor
.
data_ptr
<
float
>
();
float
*
dist2
=
dist2_tensor
.
data_ptr
<
float
>
();
int
*
idx
=
idx_tensor
.
data_ptr
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_nn_kernel_launcher
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
,
stream
);
}
void
three_interpolate_wrapper
(
int
b
,
int
c
,
int
m
,
int
n
,
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data
<
float
>
();
float
*
out
=
out_tensor
.
data
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_kernel_launcher
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
stream
);
at
::
Tensor
points_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
out_tensor
)
{
const
float
*
points
=
points_tensor
.
data_ptr
<
float
>
();
const
float
*
weight
=
weight_tensor
.
data_ptr
<
float
>
();
float
*
out
=
out_tensor
.
data_ptr
<
float
>
();
const
int
*
idx
=
idx_tensor
.
data_ptr
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_kernel_launcher
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
,
stream
);
}
void
three_interpolate_grad_wrapper
(
int
b
,
int
c
,
int
n
,
int
m
,
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
)
{
const
float
*
grad_out
=
grad_ou
t_tensor
.
data
<
float
>
();
const
float
*
weight
=
weight
_tensor
.
data
<
float
>
();
float
*
grad_points
=
grad_points
_tensor
.
data
<
floa
t
>
();
const
int
*
idx
=
idx_tensor
.
data
<
int
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_grad_kernel_launcher
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
,
stream
);
at
::
Tensor
grad_out_tensor
,
at
::
Tensor
idx_tensor
,
at
::
Tensor
weight_tensor
,
at
::
Tensor
grad_points_tensor
)
{
const
float
*
grad_out
=
grad_out_tensor
.
data_ptr
<
float
>
();
const
float
*
weight
=
weigh
t_tensor
.
data
_ptr
<
float
>
();
float
*
grad_points
=
grad_points
_tensor
.
data
_ptr
<
float
>
();
const
int
*
idx
=
idx
_tensor
.
data
_ptr
<
in
t
>
();
cudaStream_t
stream
=
THCState_getCurrentStream
(
state
);
three_interpolate_grad_kernel_launcher
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
,
stream
);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper
,
"three_nn_wrapper"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper
,
"three_interpolate_wrapper"
);
m
.
def
(
"three_interpolate_grad_wrapper"
,
&
three_interpolate_grad_wrapper
,
"three_interpolate_grad_wrapper"
);
m
.
def
(
"three_nn_wrapper"
,
&
three_nn_wrapper
,
"three_nn_wrapper"
);
m
.
def
(
"three_interpolate_wrapper"
,
&
three_interpolate_wrapper
,
"three_interpolate_wrapper"
);
m
.
def
(
"three_interpolate_grad_wrapper"
,
&
three_interpolate_grad_wrapper
,
"three_interpolate_grad_wrapper"
);
}
mmdet3d/ops/interpolate/src/three_interpolate_cuda.cu
View file @
f27d308f
...
...
@@ -3,91 +3,103 @@
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_interpolate_kernel
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
out
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_interpolate_kernel
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
__restrict__
points
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
out
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
int
bs_idx
=
blockIdx
.
z
;
int
c_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
weight
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
points
+=
bs_idx
*
c
*
m
+
c_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
;
out
[
pt_idx
]
=
weight
[
0
]
*
points
[
idx
[
0
]]
+
weight
[
1
]
*
points
[
idx
[
1
]]
+
weight
[
2
]
*
points
[
idx
[
2
]];
}
void
three_interpolate_kernel_launcher
(
int
b
,
int
c
,
int
m
,
int
n
,
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
const
float
*
points
,
const
int
*
idx
,
const
float
*
weight
,
float
*
out
,
cudaStream_t
stream
)
{
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
m
,
n
,
points
,
idx
,
weight
,
out
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
__global__
void
three_interpolate_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
grad_points
)
{
// grad_out: (B, C, N)
//
weigh
t: (B,
N
,
3
)
//
output:
//
grad_points: (B, C, M)
int
bs_idx
=
blockIdx
.
z
;
int
c
_idx
=
blockIdx
.
y
;
int
pt
_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
grad_out
+=
bs_idx
*
c
*
n
+
c_idx
*
n
+
pt_idx
;
weigh
t
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
grad_points
+=
bs_idx
*
c
*
m
+
c
_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
__global__
void
three_interpolate_grad_kernel
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
__restrict__
grad_out
,
const
int
*
__restrict__
idx
,
const
float
*
__restrict__
weight
,
float
*
__restrict__
grad_points
)
{
//
grad_ou
t: (B,
C
,
N
)
//
weight: (B, N, 3)
//
output:
// grad_points: (B, C, M)
int
bs
_idx
=
blockIdx
.
z
;
int
c
_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
c_idx
>=
c
||
pt_idx
>=
n
)
return
;
grad_ou
t
+=
bs_idx
*
c
*
n
+
c
_idx
*
n
+
pt_idx
;
weight
+=
bs_idx
*
n
*
3
+
pt
_idx
*
3
;
grad_points
+=
bs_idx
*
c
*
m
+
c
_idx
*
m
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
atomicAdd
(
grad_points
+
idx
[
0
],
grad_out
[
0
]
*
weight
[
0
]);
atomicAdd
(
grad_points
+
idx
[
1
],
grad_out
[
0
]
*
weight
[
1
]);
atomicAdd
(
grad_points
+
idx
[
2
],
grad_out
[
0
]
*
weight
[
2
]);
}
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
void
three_interpolate_grad_kernel_launcher
(
int
b
,
int
c
,
int
n
,
int
m
,
const
float
*
grad_out
,
const
int
*
idx
,
const
float
*
weight
,
float
*
grad_points
,
cudaStream_t
stream
)
{
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
c
,
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_interpolate_grad_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
c
,
n
,
m
,
grad_out
,
idx
,
weight
,
grad_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/interpolate/src/three_nn_cuda.cu
View file @
f27d308f
...
...
@@ -3,72 +3,84 @@
#include <stdlib.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
#define DIVUP(m,
n) ((m) / (n) + ((m) % (n) > 0))
__global__
void
three_nn_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
__global__
void
three_nn_kernel
(
int
b
,
int
n
,
int
m
,
const
float
*
__restrict__
unknown
,
const
float
*
__restrict__
known
,
float
*
__restrict__
dist2
,
int
*
__restrict__
idx
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
n
)
return
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
known
+=
bs_idx
*
m
*
3
;
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
unknown
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
known
+=
bs_idx
*
m
*
3
;
dist2
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
idx
+=
bs_idx
*
n
*
3
+
pt_idx
*
3
;
float
ux
=
unknown
[
0
];
float
uy
=
unknown
[
1
];
float
uz
=
unknown
[
2
];
float
ux
=
unknown
[
0
];
float
uy
=
unknown
[
1
];
float
uz
=
unknown
[
2
];
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
x
=
known
[
k
*
3
+
0
];
float
y
=
known
[
k
*
3
+
1
];
float
z
=
known
[
k
*
3
+
2
];
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
if
(
d
<
best1
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
best1
;
besti2
=
besti1
;
best1
=
d
;
besti1
=
k
;
}
else
if
(
d
<
best2
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
d
;
besti2
=
k
;
}
else
if
(
d
<
best3
)
{
best3
=
d
;
besti3
=
k
;
}
double
best1
=
1e40
,
best2
=
1e40
,
best3
=
1e40
;
int
besti1
=
0
,
besti2
=
0
,
besti3
=
0
;
for
(
int
k
=
0
;
k
<
m
;
++
k
)
{
float
x
=
known
[
k
*
3
+
0
];
float
y
=
known
[
k
*
3
+
1
];
float
z
=
known
[
k
*
3
+
2
];
float
d
=
(
ux
-
x
)
*
(
ux
-
x
)
+
(
uy
-
y
)
*
(
uy
-
y
)
+
(
uz
-
z
)
*
(
uz
-
z
);
if
(
d
<
best1
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
best1
;
besti2
=
besti1
;
best1
=
d
;
besti1
=
k
;
}
else
if
(
d
<
best2
)
{
best3
=
best2
;
besti3
=
besti2
;
best2
=
d
;
besti2
=
k
;
}
else
if
(
d
<
best3
)
{
best3
=
d
;
besti3
=
k
;
}
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
}
dist2
[
0
]
=
best1
;
dist2
[
1
]
=
best2
;
dist2
[
2
]
=
best3
;
idx
[
0
]
=
besti1
;
idx
[
1
]
=
besti2
;
idx
[
2
]
=
besti3
;
}
void
three_nn_kernel_launcher
(
int
b
,
int
n
,
int
m
,
const
float
*
unknown
,
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
const
float
*
known
,
float
*
dist2
,
int
*
idx
,
cudaStream_t
stream
)
{
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
n
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
three_nn_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
three_nn_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
unknown
,
known
,
dist2
,
idx
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
}
mmdet3d/ops/iou3d/src/iou3d_kernel.cu
View file @
f27d308f
...
...
@@ -6,376 +6,425 @@
const
int
THREADS_PER_BLOCK_NMS
=
sizeof
(
unsigned
long
long
)
*
8
;
const
float
EPS
=
1e-8
;
struct
Point
{
float
x
,
y
;
__device__
Point
()
{}
__device__
Point
(
double
_x
,
double
_y
){
x
=
_x
,
y
=
_y
;
}
__device__
void
set
(
float
_x
,
float
_y
){
x
=
_x
;
y
=
_y
;
}
__device__
Point
operator
+
(
const
Point
&
b
)
const
{
return
Point
(
x
+
b
.
x
,
y
+
b
.
y
);
}
__device__
Point
operator
-
(
const
Point
&
b
)
const
{
return
Point
(
x
-
b
.
x
,
y
-
b
.
y
);
}
float
x
,
y
;
__device__
Point
()
{}
__device__
Point
(
double
_x
,
double
_y
)
{
x
=
_x
,
y
=
_y
;
}
__device__
void
set
(
float
_x
,
float
_y
)
{
x
=
_x
;
y
=
_y
;
}
__device__
Point
operator
+
(
const
Point
&
b
)
const
{
return
Point
(
x
+
b
.
x
,
y
+
b
.
y
);
}
__device__
Point
operator
-
(
const
Point
&
b
)
const
{
return
Point
(
x
-
b
.
x
,
y
-
b
.
y
);
}
};
__device__
inline
float
cross
(
const
Point
&
a
,
const
Point
&
b
){
return
a
.
x
*
b
.
y
-
a
.
y
*
b
.
x
;
__device__
inline
float
cross
(
const
Point
&
a
,
const
Point
&
b
)
{
return
a
.
x
*
b
.
y
-
a
.
y
*
b
.
x
;
}
__device__
inline
float
cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
p0
){
return
(
p1
.
x
-
p0
.
x
)
*
(
p2
.
y
-
p0
.
y
)
-
(
p2
.
x
-
p0
.
x
)
*
(
p1
.
y
-
p0
.
y
);
__device__
inline
float
cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
p0
)
{
return
(
p1
.
x
-
p0
.
x
)
*
(
p2
.
y
-
p0
.
y
)
-
(
p2
.
x
-
p0
.
x
)
*
(
p1
.
y
-
p0
.
y
);
}
__device__
int
check_rect_cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
q1
,
const
Point
&
q2
){
int
ret
=
min
(
p1
.
x
,
p2
.
x
)
<=
max
(
q1
.
x
,
q2
.
x
)
&&
min
(
q1
.
x
,
q2
.
x
)
<=
max
(
p1
.
x
,
p2
.
x
)
&&
min
(
p1
.
y
,
p2
.
y
)
<=
max
(
q1
.
y
,
q2
.
y
)
&&
min
(
q1
.
y
,
q2
.
y
)
<=
max
(
p1
.
y
,
p2
.
y
);
return
ret
;
__device__
int
check_rect_cross
(
const
Point
&
p1
,
const
Point
&
p2
,
const
Point
&
q1
,
const
Point
&
q2
)
{
int
ret
=
min
(
p1
.
x
,
p2
.
x
)
<=
max
(
q1
.
x
,
q2
.
x
)
&&
min
(
q1
.
x
,
q2
.
x
)
<=
max
(
p1
.
x
,
p2
.
x
)
&&
min
(
p1
.
y
,
p2
.
y
)
<=
max
(
q1
.
y
,
q2
.
y
)
&&
min
(
q1
.
y
,
q2
.
y
)
<=
max
(
p1
.
y
,
p2
.
y
);
return
ret
;
}
__device__
inline
int
check_in_box2d
(
const
float
*
box
,
const
Point
&
p
){
//params: box (5) [x1, y1, x2, y2, angle]
const
float
MARGIN
=
1e-5
;
float
center_x
=
(
box
[
0
]
+
box
[
2
])
/
2
;
float
center_y
=
(
box
[
1
]
+
box
[
3
])
/
2
;
float
angle_cos
=
cos
(
-
box
[
4
]),
angle_sin
=
sin
(
-
box
[
4
]);
// rotate the point in the opposite direction of box
float
rot_x
=
(
p
.
x
-
center_x
)
*
angle_cos
+
(
p
.
y
-
center_y
)
*
angle_sin
+
center_x
;
float
rot_y
=
-
(
p
.
x
-
center_x
)
*
angle_sin
+
(
p
.
y
-
center_y
)
*
angle_cos
+
center_y
;
__device__
inline
int
check_in_box2d
(
const
float
*
box
,
const
Point
&
p
)
{
// params: box (5) [x1, y1, x2, y2, angle]
const
float
MARGIN
=
1e-5
;
float
center_x
=
(
box
[
0
]
+
box
[
2
])
/
2
;
float
center_y
=
(
box
[
1
]
+
box
[
3
])
/
2
;
float
angle_cos
=
cos
(
-
box
[
4
]),
angle_sin
=
sin
(
-
box
[
4
]);
// rotate the point in the opposite direction of box
float
rot_x
=
(
p
.
x
-
center_x
)
*
angle_cos
+
(
p
.
y
-
center_y
)
*
angle_sin
+
center_x
;
float
rot_y
=
-
(
p
.
x
-
center_x
)
*
angle_sin
+
(
p
.
y
-
center_y
)
*
angle_cos
+
center_y
;
#ifdef DEBUG
printf
(
"box: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
box
[
0
],
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
]);
printf
(
"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, %.3f)
\n
"
,
center_x
,
center_y
,
angle_cos
,
angle_sin
,
p
.
x
,
p
.
y
,
rot_x
,
rot_y
);
printf
(
"box: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
box
[
0
],
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
]);
printf
(
"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, "
"%.3f)
\n
"
,
center_x
,
center_y
,
angle_cos
,
angle_sin
,
p
.
x
,
p
.
y
,
rot_x
,
rot_y
);
#endif
return
(
rot_x
>
box
[
0
]
-
MARGIN
&&
rot_x
<
box
[
2
]
+
MARGIN
&&
rot_y
>
box
[
1
]
-
MARGIN
&&
rot_y
<
box
[
3
]
+
MARGIN
);
return
(
rot_x
>
box
[
0
]
-
MARGIN
&&
rot_x
<
box
[
2
]
+
MARGIN
&&
rot_y
>
box
[
1
]
-
MARGIN
&&
rot_y
<
box
[
3
]
+
MARGIN
);
}
__device__
inline
int
intersection
(
const
Point
&
p1
,
const
Point
&
p0
,
const
Point
&
q1
,
const
Point
&
q0
,
Point
&
ans
){
// fast exclusion
if
(
check_rect_cross
(
p0
,
p1
,
q0
,
q1
)
==
0
)
return
0
;
__device__
inline
int
intersection
(
const
Point
&
p1
,
const
Point
&
p0
,
const
Point
&
q1
,
const
Point
&
q0
,
Point
&
ans
)
{
// fast exclusion
if
(
check_rect_cross
(
p0
,
p1
,
q0
,
q1
)
==
0
)
return
0
;
// check cross standing
float
s1
=
cross
(
q0
,
p1
,
p0
);
float
s2
=
cross
(
p1
,
q1
,
p0
);
float
s3
=
cross
(
p0
,
q1
,
q0
);
float
s4
=
cross
(
q1
,
p1
,
q0
);
// check cross standing
float
s1
=
cross
(
q0
,
p1
,
p0
);
float
s2
=
cross
(
p1
,
q1
,
p0
);
float
s3
=
cross
(
p0
,
q1
,
q0
);
float
s4
=
cross
(
q1
,
p1
,
q0
);
if
(
!
(
s1
*
s2
>
0
&&
s3
*
s4
>
0
))
return
0
;
if
(
!
(
s1
*
s2
>
0
&&
s3
*
s4
>
0
))
return
0
;
// calculate intersection of two lines
float
s5
=
cross
(
q1
,
p1
,
p0
);
if
(
fabs
(
s5
-
s1
)
>
EPS
){
ans
.
x
=
(
s5
*
q0
.
x
-
s1
*
q1
.
x
)
/
(
s5
-
s1
);
ans
.
y
=
(
s5
*
q0
.
y
-
s1
*
q1
.
y
)
/
(
s5
-
s1
);
// calculate intersection of two lines
float
s5
=
cross
(
q1
,
p1
,
p0
);
if
(
fabs
(
s5
-
s1
)
>
EPS
)
{
ans
.
x
=
(
s5
*
q0
.
x
-
s1
*
q1
.
x
)
/
(
s5
-
s1
);
ans
.
y
=
(
s5
*
q0
.
y
-
s1
*
q1
.
y
)
/
(
s5
-
s1
);
}
else
{
float
a0
=
p0
.
y
-
p1
.
y
,
b0
=
p1
.
x
-
p0
.
x
,
c0
=
p0
.
x
*
p1
.
y
-
p1
.
x
*
p0
.
y
;
float
a1
=
q0
.
y
-
q1
.
y
,
b1
=
q1
.
x
-
q0
.
x
,
c1
=
q0
.
x
*
q1
.
y
-
q1
.
x
*
q0
.
y
;
float
D
=
a0
*
b1
-
a1
*
b0
;
}
else
{
float
a0
=
p0
.
y
-
p1
.
y
,
b0
=
p1
.
x
-
p0
.
x
,
c0
=
p0
.
x
*
p1
.
y
-
p1
.
x
*
p0
.
y
;
float
a1
=
q0
.
y
-
q1
.
y
,
b1
=
q1
.
x
-
q0
.
x
,
c1
=
q0
.
x
*
q1
.
y
-
q1
.
x
*
q0
.
y
;
float
D
=
a0
*
b1
-
a1
*
b0
;
ans
.
x
=
(
b0
*
c1
-
b1
*
c0
)
/
D
;
ans
.
y
=
(
a1
*
c0
-
a0
*
c1
)
/
D
;
}
ans
.
x
=
(
b0
*
c1
-
b1
*
c0
)
/
D
;
ans
.
y
=
(
a1
*
c0
-
a0
*
c1
)
/
D
;
}
return
1
;
return
1
;
}
__device__
inline
void
rotate_around_center
(
const
Point
&
center
,
const
float
angle_cos
,
const
float
angle_sin
,
Point
&
p
){
float
new_x
=
(
p
.
x
-
center
.
x
)
*
angle_cos
+
(
p
.
y
-
center
.
y
)
*
angle_sin
+
center
.
x
;
float
new_y
=
-
(
p
.
x
-
center
.
x
)
*
angle_sin
+
(
p
.
y
-
center
.
y
)
*
angle_cos
+
center
.
y
;
p
.
set
(
new_x
,
new_y
);
__device__
inline
void
rotate_around_center
(
const
Point
&
center
,
const
float
angle_cos
,
const
float
angle_sin
,
Point
&
p
)
{
float
new_x
=
(
p
.
x
-
center
.
x
)
*
angle_cos
+
(
p
.
y
-
center
.
y
)
*
angle_sin
+
center
.
x
;
float
new_y
=
-
(
p
.
x
-
center
.
x
)
*
angle_sin
+
(
p
.
y
-
center
.
y
)
*
angle_cos
+
center
.
y
;
p
.
set
(
new_x
,
new_y
);
}
__device__
inline
int
point_cmp
(
const
Point
&
a
,
const
Point
&
b
,
const
Point
&
center
){
return
atan2
(
a
.
y
-
center
.
y
,
a
.
x
-
center
.
x
)
>
atan2
(
b
.
y
-
center
.
y
,
b
.
x
-
center
.
x
);
__device__
inline
int
point_cmp
(
const
Point
&
a
,
const
Point
&
b
,
const
Point
&
center
)
{
return
atan2
(
a
.
y
-
center
.
y
,
a
.
x
-
center
.
x
)
>
atan2
(
b
.
y
-
center
.
y
,
b
.
x
-
center
.
x
);
}
__device__
inline
float
box_overlap
(
const
float
*
box_a
,
const
float
*
box_b
){
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
__device__
inline
float
box_overlap
(
const
float
*
box_a
,
const
float
*
box_b
)
{
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
a_x1
=
box_a
[
0
],
a_y1
=
box_a
[
1
],
a_x2
=
box_a
[
2
],
a_y2
=
box_a
[
3
],
a_angle
=
box_a
[
4
];
float
b_x1
=
box_b
[
0
],
b_y1
=
box_b
[
1
],
b_x2
=
box_b
[
2
],
b_y2
=
box_b
[
3
],
b_angle
=
box_b
[
4
];
float
a_x1
=
box_a
[
0
],
a_y1
=
box_a
[
1
],
a_x2
=
box_a
[
2
],
a_y2
=
box_a
[
3
],
a_angle
=
box_a
[
4
];
float
b_x1
=
box_b
[
0
],
b_y1
=
box_b
[
1
],
b_x2
=
box_b
[
2
],
b_y2
=
box_b
[
3
],
b_angle
=
box_b
[
4
];
Point
center_a
((
a_x1
+
a_x2
)
/
2
,
(
a_y1
+
a_y2
)
/
2
);
Point
center_b
((
b_x1
+
b_x2
)
/
2
,
(
b_y1
+
b_y2
)
/
2
);
Point
center_a
((
a_x1
+
a_x2
)
/
2
,
(
a_y1
+
a_y2
)
/
2
);
Point
center_b
((
b_x1
+
b_x2
)
/
2
,
(
b_y1
+
b_y2
)
/
2
);
#ifdef DEBUG
printf
(
"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
a_x1
,
a_y1
,
a_x2
,
a_y2
,
a_angle
,
b_x1
,
b_y1
,
b_x2
,
b_y2
,
b_angle
);
printf
(
"center a: (%.3f, %.3f), b: (%.3f, %.3f)
\n
"
,
center_a
.
x
,
center_a
.
y
,
center_b
.
x
,
center_b
.
y
);
printf
(
"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)
\n
"
,
a_x1
,
a_y1
,
a_x2
,
a_y2
,
a_angle
,
b_x1
,
b_y1
,
b_x2
,
b_y2
,
b_angle
);
printf
(
"center a: (%.3f, %.3f), b: (%.3f, %.3f)
\n
"
,
center_a
.
x
,
center_a
.
y
,
center_b
.
x
,
center_b
.
y
);
#endif
Point
box_a_corners
[
5
];
box_a_corners
[
0
].
set
(
a_x1
,
a_y1
);
box_a_corners
[
1
].
set
(
a_x2
,
a_y1
);
box_a_corners
[
2
].
set
(
a_x2
,
a_y2
);
box_a_corners
[
3
].
set
(
a_x1
,
a_y2
);
Point
box_a_corners
[
5
];
box_a_corners
[
0
].
set
(
a_x1
,
a_y1
);
box_a_corners
[
1
].
set
(
a_x2
,
a_y1
);
box_a_corners
[
2
].
set
(
a_x2
,
a_y2
);
box_a_corners
[
3
].
set
(
a_x1
,
a_y2
);
Point
box_b_corners
[
5
];
box_b_corners
[
0
].
set
(
b_x1
,
b_y1
);
box_b_corners
[
1
].
set
(
b_x2
,
b_y1
);
box_b_corners
[
2
].
set
(
b_x2
,
b_y2
);
box_b_corners
[
3
].
set
(
b_x1
,
b_y2
);
Point
box_b_corners
[
5
];
box_b_corners
[
0
].
set
(
b_x1
,
b_y1
);
box_b_corners
[
1
].
set
(
b_x2
,
b_y1
);
box_b_corners
[
2
].
set
(
b_x2
,
b_y2
);
box_b_corners
[
3
].
set
(
b_x1
,
b_y2
);
// get oriented corners
float
a_angle_cos
=
cos
(
a_angle
),
a_angle_sin
=
sin
(
a_angle
);
float
b_angle_cos
=
cos
(
b_angle
),
b_angle_sin
=
sin
(
b_angle
);
// get oriented corners
float
a_angle_cos
=
cos
(
a_angle
),
a_angle_sin
=
sin
(
a_angle
);
float
b_angle_cos
=
cos
(
b_angle
),
b_angle_sin
=
sin
(
b_angle
);
for
(
int
k
=
0
;
k
<
4
;
k
++
){
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
#ifdef DEBUG
printf
(
"before corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
printf
(
"before corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
#endif
rotate_around_center
(
center_a
,
a_angle_cos
,
a_angle_sin
,
box_a_corners
[
k
]);
rotate_around_center
(
center_b
,
b_angle_cos
,
b_angle_sin
,
box_b_corners
[
k
]);
rotate_around_center
(
center_a
,
a_angle_cos
,
a_angle_sin
,
box_a_corners
[
k
]);
rotate_around_center
(
center_b
,
b_angle_cos
,
b_angle_sin
,
box_b_corners
[
k
]);
#ifdef DEBUG
printf
(
"corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
printf
(
"corner %d: a(%.3f, %.3f), b(%.3f, %.3f)
\n
"
,
k
,
box_a_corners
[
k
].
x
,
box_a_corners
[
k
].
y
,
box_b_corners
[
k
].
x
,
box_b_corners
[
k
].
y
);
#endif
}
box_a_corners
[
4
]
=
box_a_corners
[
0
];
box_b_corners
[
4
]
=
box_b_corners
[
0
];
// get intersection of lines
Point
cross_points
[
16
];
Point
poly_center
;
int
cnt
=
0
,
flag
=
0
;
poly_center
.
set
(
0
,
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
flag
=
intersection
(
box_a_corners
[
i
+
1
],
box_a_corners
[
i
],
box_b_corners
[
j
+
1
],
box_b_corners
[
j
],
cross_points
[
cnt
]);
if
(
flag
)
{
poly_center
=
poly_center
+
cross_points
[
cnt
];
cnt
++
;
}
}
box_a_corners
[
4
]
=
box_a_corners
[
0
];
box_b_corners
[
4
]
=
box_b_corners
[
0
];
// get intersection of lines
Point
cross_points
[
16
];
Point
poly_center
;
int
cnt
=
0
,
flag
=
0
;
poly_center
.
set
(
0
,
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
for
(
int
j
=
0
;
j
<
4
;
j
++
){
flag
=
intersection
(
box_a_corners
[
i
+
1
],
box_a_corners
[
i
],
box_b_corners
[
j
+
1
],
box_b_corners
[
j
],
cross_points
[
cnt
]);
if
(
flag
){
poly_center
=
poly_center
+
cross_points
[
cnt
];
cnt
++
;
}
}
}
// check corners
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
if
(
check_in_box2d
(
box_a
,
box_b_corners
[
k
]))
{
poly_center
=
poly_center
+
box_b_corners
[
k
];
cross_points
[
cnt
]
=
box_b_corners
[
k
];
cnt
++
;
}
// check corners
for
(
int
k
=
0
;
k
<
4
;
k
++
){
if
(
check_in_box2d
(
box_a
,
box_b_corners
[
k
])){
poly_center
=
poly_center
+
box_b_corners
[
k
];
cross_points
[
cnt
]
=
box_b_corners
[
k
];
cnt
++
;
}
if
(
check_in_box2d
(
box_b
,
box_a_corners
[
k
])){
poly_center
=
poly_center
+
box_a_corners
[
k
];
cross_points
[
cnt
]
=
box_a_corners
[
k
];
cnt
++
;
}
if
(
check_in_box2d
(
box_b
,
box_a_corners
[
k
]))
{
poly_center
=
poly_center
+
box_a_corners
[
k
];
cross_points
[
cnt
]
=
box_a_corners
[
k
];
cnt
++
;
}
poly_center
.
x
/=
cnt
;
poly_center
.
y
/=
cnt
;
// sort the points of polygon
Point
temp
;
for
(
int
j
=
0
;
j
<
cnt
-
1
;
j
++
){
for
(
int
i
=
0
;
i
<
cnt
-
j
-
1
;
i
++
){
if
(
point_cmp
(
cross_points
[
i
],
cross_points
[
i
+
1
],
poly_center
))
{
temp
=
cross_points
[
i
];
cross_points
[
i
]
=
cross_points
[
i
+
1
];
cross_points
[
i
+
1
]
=
temp
;
}
}
}
poly_center
.
x
/=
cnt
;
poly_center
.
y
/=
cnt
;
// sort the points of polygon
Point
temp
;
for
(
int
j
=
0
;
j
<
cnt
-
1
;
j
++
)
{
for
(
int
i
=
0
;
i
<
cnt
-
j
-
1
;
i
++
)
{
if
(
point_cmp
(
cross_points
[
i
],
cross_points
[
i
+
1
],
poly_center
))
{
temp
=
cross_points
[
i
];
cross_points
[
i
]
=
cross_points
[
i
+
1
];
cross_points
[
i
+
1
]
=
temp
;
}
}
}
#ifdef DEBUG
printf
(
"cnt=%d
\n
"
,
cnt
);
for
(
int
i
=
0
;
i
<
cnt
;
i
++
){
printf
(
"All cross point %d: (%.3f, %.3f)
\n
"
,
i
,
cross_points
[
i
].
x
,
cross_points
[
i
].
y
);
}
printf
(
"cnt=%d
\n
"
,
cnt
);
for
(
int
i
=
0
;
i
<
cnt
;
i
++
)
{
printf
(
"All cross point %d: (%.3f, %.3f)
\n
"
,
i
,
cross_points
[
i
].
x
,
cross_points
[
i
].
y
);
}
#endif
// get the overlap areas
float
area
=
0
;
for
(
int
k
=
0
;
k
<
cnt
-
1
;
k
++
){
area
+=
cross
(
cross_points
[
k
]
-
cross_points
[
0
],
cross_points
[
k
+
1
]
-
cross_points
[
0
]);
}
// get the overlap areas
float
area
=
0
;
for
(
int
k
=
0
;
k
<
cnt
-
1
;
k
++
)
{
area
+=
cross
(
cross_points
[
k
]
-
cross_points
[
0
],
cross_points
[
k
+
1
]
-
cross_points
[
0
]);
}
return
fabs
(
area
)
/
2.0
;
return
fabs
(
area
)
/
2.0
;
}
__device__
inline
float
iou_bev
(
const
float
*
box_a
,
const
float
*
box_b
){
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
sa
=
(
box_a
[
2
]
-
box_a
[
0
])
*
(
box_a
[
3
]
-
box_a
[
1
]);
float
sb
=
(
box_b
[
2
]
-
box_b
[
0
])
*
(
box_b
[
3
]
-
box_b
[
1
]);
float
s_overlap
=
box_overlap
(
box_a
,
box_b
);
return
s_overlap
/
fmaxf
(
sa
+
sb
-
s_overlap
,
EPS
);
__device__
inline
float
iou_bev
(
const
float
*
box_a
,
const
float
*
box_b
)
{
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float
sa
=
(
box_a
[
2
]
-
box_a
[
0
])
*
(
box_a
[
3
]
-
box_a
[
1
]);
float
sb
=
(
box_b
[
2
]
-
box_b
[
0
])
*
(
box_b
[
3
]
-
box_b
[
1
]);
float
s_overlap
=
box_overlap
(
box_a
,
box_b
);
return
s_overlap
/
fmaxf
(
sa
+
sb
-
s_overlap
,
EPS
);
}
__global__
void
boxes_overlap_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
){
return
;
}
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
s_overlap
=
box_overlap
(
cur_box_a
,
cur_box_b
);
ans_overlap
[
a_idx
*
num_b
+
b_idx
]
=
s_overlap
;
__global__
void
boxes_overlap_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
)
{
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
)
{
return
;
}
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
s_overlap
=
box_overlap
(
cur_box_a
,
cur_box_b
);
ans_overlap
[
a_idx
*
num_b
+
b_idx
]
=
s_overlap
;
}
__global__
void
boxes_iou_bev_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
){
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
__global__
void
boxes_iou_bev_kernel
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
)
{
const
int
a_idx
=
blockIdx
.
y
*
THREADS_PER_BLOCK
+
threadIdx
.
y
;
const
int
b_idx
=
blockIdx
.
x
*
THREADS_PER_BLOCK
+
threadIdx
.
x
;
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
){
return
;
}
if
(
a_idx
>=
num_a
||
b_idx
>=
num_b
)
{
return
;
}
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
cur_iou_bev
=
iou_bev
(
cur_box_a
,
cur_box_b
);
ans_iou
[
a_idx
*
num_b
+
b_idx
]
=
cur_iou_bev
;
const
float
*
cur_box_a
=
boxes_a
+
a_idx
*
5
;
const
float
*
cur_box_b
=
boxes_b
+
b_idx
*
5
;
float
cur_iou_bev
=
iou_bev
(
cur_box_a
,
cur_box_b
);
ans_iou
[
a_idx
*
num_b
+
b_idx
]
=
cur_iou_bev
;
}
__global__
void
nms_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
){
//params: boxes (N, 5) [x1, y1, x2, y2, ry]
//params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
const
float
*
boxes
,
unsigned
long
long
*
mask
)
{
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_bev
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
){
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_bev
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
)
{
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
__device__
inline
float
iou_normal
(
float
const
*
const
a
,
float
const
*
const
b
)
{
float
left
=
fmaxf
(
a
[
0
],
b
[
0
]),
right
=
fminf
(
a
[
2
],
b
[
2
]);
float
top
=
fmaxf
(
a
[
1
],
b
[
1
]),
bottom
=
fminf
(
a
[
3
],
b
[
3
]);
float
width
=
fmaxf
(
right
-
left
,
0.
f
),
height
=
fmaxf
(
bottom
-
top
,
0.
f
);
float
interS
=
width
*
height
;
float
Sa
=
(
a
[
2
]
-
a
[
0
])
*
(
a
[
3
]
-
a
[
1
]);
float
Sb
=
(
b
[
2
]
-
b
[
0
])
*
(
b
[
3
]
-
b
[
1
]);
return
interS
/
fmaxf
(
Sa
+
Sb
-
interS
,
EPS
);
__device__
inline
float
iou_normal
(
float
const
*
const
a
,
float
const
*
const
b
)
{
float
left
=
fmaxf
(
a
[
0
],
b
[
0
]),
right
=
fminf
(
a
[
2
],
b
[
2
]);
float
top
=
fmaxf
(
a
[
1
],
b
[
1
]),
bottom
=
fminf
(
a
[
3
],
b
[
3
]);
float
width
=
fmaxf
(
right
-
left
,
0.
f
),
height
=
fmaxf
(
bottom
-
top
,
0.
f
);
float
interS
=
width
*
height
;
float
Sa
=
(
a
[
2
]
-
a
[
0
])
*
(
a
[
3
]
-
a
[
1
]);
float
Sb
=
(
b
[
2
]
-
b
[
0
])
*
(
b
[
3
]
-
b
[
1
]);
return
interS
/
fmaxf
(
Sa
+
Sb
-
interS
,
EPS
);
}
__global__
void
nms_normal_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
){
//params: boxes (N, 5) [x1, y1, x2, y2, ry]
//params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
__global__
void
nms_normal_kernel
(
const
int
boxes_num
,
const
float
nms_overlap_thresh
,
const
float
*
boxes
,
unsigned
long
long
*
mask
)
{
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
// if (row_start > col_start) return;
const
int
row_size
=
fminf
(
boxes_num
-
row_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
const
int
col_size
=
fminf
(
boxes_num
-
col_start
*
THREADS_PER_BLOCK_NMS
,
THREADS_PER_BLOCK_NMS
);
__shared__
float
block_boxes
[
THREADS_PER_BLOCK_NMS
*
5
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
5
+
0
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
0
];
block_boxes
[
threadIdx
.
x
*
5
+
1
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
1
];
block_boxes
[
threadIdx
.
x
*
5
+
2
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
2
];
block_boxes
[
threadIdx
.
x
*
5
+
3
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
3
];
block_boxes
[
threadIdx
.
x
*
5
+
4
]
=
boxes
[(
THREADS_PER_BLOCK_NMS
*
col_start
+
threadIdx
.
x
)
*
5
+
4
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
THREADS_PER_BLOCK_NMS
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
boxes
+
cur_box_idx
*
5
;
int
i
=
0
;
unsigned
long
long
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_normal
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
){
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
iou_normal
(
cur_box
,
block_boxes
+
i
*
5
)
>
nms_overlap_thresh
)
{
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
);
mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
void
boxesoverlapLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
)
{
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
void
boxesoverlapLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_overlap
){
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
boxes_overlap_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
boxes_overlap_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
void
boxesioubevLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
){
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
void
boxesioubevLauncher
(
const
int
num_a
,
const
float
*
boxes_a
,
const
int
num_b
,
const
float
*
boxes_b
,
float
*
ans_iou
)
{
dim3
blocks
(
DIVUP
(
num_b
,
THREADS_PER_BLOCK
),
DIVUP
(
num_a
,
THREADS_PER_BLOCK
));
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
);
boxes_iou_bev_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
boxes_iou_bev_kernel
<<<
blocks
,
threads
>>>
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
}
void
nmsLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
){
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
void
nmsLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
}
void
nmsNormalLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
){
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_normal_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
void
nmsNormalLauncher
(
const
float
*
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
dim3
blocks
(
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
),
DIVUP
(
boxes_num
,
THREADS_PER_BLOCK_NMS
));
dim3
threads
(
THREADS_PER_BLOCK_NMS
);
nms_normal_kernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_overlap_thresh
,
boxes
,
mask
);
}
mmdet3d/ops/roiaware_pool3d/__init__.py
View file @
f27d308f
from
.points_in_boxes
import
points_in_boxes_cpu
,
points_in_boxes_gpu
from
.points_in_boxes
import
(
points_in_boxes_batch
,
points_in_boxes_cpu
,
points_in_boxes_gpu
)
from
.roiaware_pool3d
import
RoIAwarePool3d
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
]
__all__
=
[
'RoIAwarePool3d'
,
'points_in_boxes_gpu'
,
'points_in_boxes_cpu'
,
'points_in_boxes_batch'
]
mmdet3d/ops/roiaware_pool3d/points_in_boxes.py
View file @
f27d308f
...
...
@@ -53,3 +53,29 @@ def points_in_boxes_cpu(points, boxes):
point_indices
)
return
point_indices
def
points_in_boxes_batch
(
points
,
boxes
):
"""Find points that are in boxes (CUDA)
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
(x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0
"""
assert
boxes
.
shape
[
0
]
==
points
.
shape
[
0
]
assert
boxes
.
shape
[
2
]
==
7
batch_size
,
num_points
,
_
=
points
.
shape
num_boxes
=
boxes
.
shape
[
1
]
box_idxs_of_pts
=
points
.
new_zeros
((
batch_size
,
num_points
,
num_boxes
),
dtype
=
torch
.
int
).
fill_
(
0
)
roiaware_pool3d_ext
.
points_in_boxes_batch
(
boxes
.
contiguous
(),
points
.
contiguous
(),
box_idxs_of_pts
)
return
box_idxs_of_pts
mmdet3d/ops/roiaware_pool3d/src/points_in_boxes_cuda.cu
View file @
f27d308f
...
...
@@ -77,6 +77,34 @@ __global__ void points_in_boxes_kernel(int batch_size, int boxes_num,
}
}
__global__
void
points_in_boxes_batch_kernel
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
batch_size
||
pt_idx
>=
pts_num
)
return
;
boxes
+=
bs_idx
*
boxes_num
*
7
;
pts
+=
bs_idx
*
pts_num
*
3
+
pt_idx
*
3
;
box_idx_of_points
+=
bs_idx
*
pts_num
*
boxes_num
+
pt_idx
*
boxes_num
;
float
local_x
=
0
,
local_y
=
0
;
int
cur_in_flag
=
0
;
for
(
int
k
=
0
;
k
<
boxes_num
;
k
++
)
{
cur_in_flag
=
check_pt_in_box3d
(
pts
,
boxes
+
k
*
7
,
local_x
,
local_y
);
if
(
cur_in_flag
)
{
box_idx_of_points
[
k
]
=
1
;
}
cur_in_flag
=
0
;
}
}
void
points_in_boxes_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
...
...
@@ -102,6 +130,30 @@ void points_in_boxes_launcher(int batch_size, int boxes_num, int pts_num,
#endif
}
void
points_in_boxes_batch_launcher
(
int
batch_size
,
int
boxes_num
,
int
pts_num
,
const
float
*
boxes
,
const
float
*
pts
,
int
*
box_idx_of_points
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center, each box params pts: (B, npoints, 3) [x, y, z] in
// LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
cudaError_t
err
;
dim3
blocks
(
DIVUP
(
pts_num
,
THREADS_PER_BLOCK
),
batch_size
);
dim3
threads
(
THREADS_PER_BLOCK
);
points_in_boxes_batch_kernel
<<<
blocks
,
threads
>>>
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
err
=
cudaGetLastError
();
if
(
cudaSuccess
!=
err
)
{
fprintf
(
stderr
,
"CUDA kernel failed : %s
\n
"
,
cudaGetErrorString
(
err
));
exit
(
-
1
);
}
#ifdef DEBUG
cudaDeviceSynchronize
();
// for using printf in kernel function
#endif
}
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
...
...
@@ -126,3 +178,27 @@ int points_in_boxes_gpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
return
1
;
}
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
)
{
// params boxes: (B, N, 7) [x, y, z, w, l, h, rz] in LiDAR coordinate, z is
// the bottom center. params pts: (B, npoints, 3) [x, y, z] in LiDAR
// coordinate params boxes_idx_of_points: (B, npoints), default -1
CHECK_INPUT
(
boxes_tensor
);
CHECK_INPUT
(
pts_tensor
);
CHECK_INPUT
(
box_idx_of_points_tensor
);
int
batch_size
=
boxes_tensor
.
size
(
0
);
int
boxes_num
=
boxes_tensor
.
size
(
1
);
int
pts_num
=
pts_tensor
.
size
(
1
);
const
float
*
boxes
=
boxes_tensor
.
data_ptr
<
float
>
();
const
float
*
pts
=
pts_tensor
.
data_ptr
<
float
>
();
int
*
box_idx_of_points
=
box_idx_of_points_tensor
.
data_ptr
<
int
>
();
points_in_boxes_batch_launcher
(
batch_size
,
boxes_num
,
pts_num
,
boxes
,
pts
,
box_idx_of_points
);
return
1
;
}
mmdet3d/ops/roiaware_pool3d/src/roiaware_pool3d.cpp
View file @
f27d308f
...
...
@@ -44,6 +44,9 @@ int points_in_boxes_cpu(at::Tensor boxes_tensor, at::Tensor pts_tensor,
int
points_in_boxes_gpu
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
int
points_in_boxes_batch
(
at
::
Tensor
boxes_tensor
,
at
::
Tensor
pts_tensor
,
at
::
Tensor
box_idx_of_points_tensor
);
int
roiaware_pool3d_gpu
(
at
::
Tensor
rois
,
at
::
Tensor
pts
,
at
::
Tensor
pts_feature
,
at
::
Tensor
argmax
,
at
::
Tensor
pts_idx_of_voxels
,
at
::
Tensor
pooled_features
,
int
pool_method
)
{
...
...
@@ -127,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"roiaware pool3d backward (CUDA)"
);
m
.
def
(
"points_in_boxes_gpu"
,
&
points_in_boxes_gpu
,
"points_in_boxes_gpu forward (CUDA)"
);
m
.
def
(
"points_in_boxes_batch"
,
&
points_in_boxes_batch
,
"points_in_boxes_batch forward (CUDA)"
);
m
.
def
(
"points_in_boxes_cpu"
,
&
points_in_boxes_cpu
,
"points_in_boxes_cpu forward (CPU)"
);
}
mmdet3d/ops/sparse_block.py
View file @
f27d308f
...
...
@@ -6,6 +6,21 @@ from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
class
SparseBottleneck
(
Bottleneck
,
spconv
.
SparseModule
):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
stride (int): stride of the first block. Default: 1
downsample (None | Module): down sample module for block.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
expansion
=
4
def
__init__
(
self
,
...
...
@@ -15,10 +30,7 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
downsample
=
None
,
conv_cfg
=
None
,
norm_cfg
=
None
):
"""Sparse bottleneck block for PartA^2.
Bottleneck block implemented with submanifold sparse convolution.
"""
spconv
.
SparseModule
.
__init__
(
self
)
Bottleneck
.
__init__
(
self
,
...
...
@@ -53,6 +65,21 @@ class SparseBottleneck(Bottleneck, spconv.SparseModule):
class
SparseBasicBlock
(
BasicBlock
,
spconv
.
SparseModule
):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
Args:
inplanes (int): inplanes of block.
planes (int): planes of block.
stride (int): stride of the first block. Default: 1
downsample (None | Module): down sample module for block.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
expansion
=
1
def
__init__
(
self
,
...
...
@@ -62,10 +89,6 @@ class SparseBasicBlock(BasicBlock, spconv.SparseModule):
downsample
=
None
,
conv_cfg
=
None
,
norm_cfg
=
None
):
"""Sparse basic block for PartA^2.
Sparse basic block implemented with submanifold sparse convolution.
"""
spconv
.
SparseModule
.
__init__
(
self
)
BasicBlock
.
__init__
(
self
,
...
...
@@ -125,6 +148,7 @@ def make_sparse_convmodule(in_channels,
spconv.SparseSequential: sparse convolution module.
"""
assert
isinstance
(
order
,
tuple
)
and
len
(
order
)
<=
3
assert
set
(
order
)
|
{
'conv'
,
'norm'
,
'act'
}
==
{
'conv'
,
'norm'
,
'act'
}
conv_cfg
=
dict
(
type
=
conv_type
,
indice_key
=
indice_key
)
...
...
mmdet3d/ops/spconv/include/paramsgrid.h
View file @
f27d308f
...
...
@@ -18,13 +18,19 @@
#include <vector>
namespace
detail
{
template
<
class
T
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
class
T
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
class
T
,
class
...
TArgs
>
int
getTotalSize
(
std
::
vector
<
T
>
arg
,
std
::
vector
<
TArgs
>
...
args
)
{
return
arg
.
size
()
*
getTotalSize
(
args
...);
}
template
<
typename
T
>
int
getSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
typename
T
>
int
getSize
(
std
::
vector
<
T
>
arg
)
{
return
arg
.
size
();
}
template
<
int
Idx
,
class
TT
,
class
T
>
void
assigner
(
TT
&
src
,
std
::
vector
<
int
>
counter
,
std
::
vector
<
T
>
&
arg
)
{
...
...
@@ -37,7 +43,7 @@ void assigner(TT &src, std::vector<int> counter, std::vector<T> &arg,
std
::
get
<
Idx
>
(
src
)
=
arg
[
counter
[
Idx
]];
assigner
<
Idx
+
1
>
(
src
,
counter
,
args
...);
}
}
// namespace detail
}
// namespace detail
template
<
class
...
TArgs
>
std
::
vector
<
std
::
tuple
<
TArgs
...
>>
paramsGrid
(
std
::
vector
<
TArgs
>
...
args
)
{
int
length
=
detail
::
getTotalSize
(
args
...);
...
...
mmdet3d/ops/spconv/include/prettyprint.h
View file @
f27d308f
...
...
@@ -22,424 +22,472 @@
#include <utility>
#include <valarray>
namespace
pretty_print
{
namespace
detail
{
// SFINAE type trait to detect whether T::const_iterator exists.
struct
sfinae_base
{
using
yes
=
char
;
using
no
=
yes
[
2
];
};
template
<
typename
T
>
struct
has_const_iterator
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
test
(
typename
C
::
const_iterator
*
);
template
<
typename
C
>
static
no
&
test
(...);
public:
static
const
bool
value
=
sizeof
(
test
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
using
type
=
T
;
};
template
<
typename
T
>
struct
has_begin_end
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
f
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
begin
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
>::
type
*
);
template
<
typename
C
>
static
no
&
f
(...);
template
<
typename
C
>
static
yes
&
g
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
end
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
,
void
>::
type
*
);
template
<
typename
C
>
static
no
&
g
(...);
public:
static
bool
const
beg_value
=
sizeof
(
f
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
bool
const
end_value
=
sizeof
(
g
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
};
}
// namespace detail
// Holds the delimiter values for a specific character type
template
<
typename
TChar
>
struct
delimiters_values
{
using
char_type
=
TChar
;
const
char_type
*
prefix
;
const
char_type
*
delimiter
;
const
char_type
*
postfix
;
};
// Defines the delimiter values for a specific container and character type
template
<
typename
T
,
typename
TChar
>
struct
delimiters
{
using
type
=
delimiters_values
<
TChar
>
;
static
const
type
values
;
};
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template
<
typename
T
,
typename
TChar
=
char
,
typename
TCharTraits
=
::
std
::
char_traits
<
TChar
>,
typename
TDelimiters
=
delimiters
<
T
,
TChar
>>
struct
print_container_helper
{
using
delimiters_type
=
TDelimiters
;
using
ostream_type
=
std
::
basic_ostream
<
TChar
,
TCharTraits
>
;
template
<
typename
U
>
struct
printer
{
static
void
print_body
(
const
U
&
c
,
ostream_type
&
stream
)
{
using
std
::
begin
;
using
std
::
end
;
auto
it
=
begin
(
c
);
const
auto
the_end
=
end
(
c
);
if
(
it
!=
the_end
)
{
for
(
;
;
)
{
stream
<<
*
it
;
if
(
++
it
==
the_end
)
break
;
if
(
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
delimiters_type
::
values
.
delimiter
;
}
}
}
};
print_container_helper
(
const
T
&
container
)
:
container_
(
container
)
{
}
inline
void
operator
()(
ostream_type
&
stream
)
const
{
if
(
delimiters_type
::
values
.
prefix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
prefix
;
printer
<
T
>::
print_body
(
container_
,
stream
);
if
(
delimiters_type
::
values
.
postfix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
postfix
;
namespace
pretty_print
{
namespace
detail
{
// SFINAE type trait to detect whether T::const_iterator exists.
struct
sfinae_base
{
using
yes
=
char
;
using
no
=
yes
[
2
];
};
template
<
typename
T
>
struct
has_const_iterator
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
test
(
typename
C
::
const_iterator
*
);
template
<
typename
C
>
static
no
&
test
(...);
public:
static
const
bool
value
=
sizeof
(
test
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
using
type
=
T
;
};
template
<
typename
T
>
struct
has_begin_end
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
f
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
begin
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
>::
type
*
);
template
<
typename
C
>
static
no
&
f
(...);
template
<
typename
C
>
static
yes
&
g
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
end
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
,
void
>::
type
*
);
template
<
typename
C
>
static
no
&
g
(...);
public:
static
bool
const
beg_value
=
sizeof
(
f
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
bool
const
end_value
=
sizeof
(
g
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
};
}
// namespace detail
// Holds the delimiter values for a specific character type
template
<
typename
TChar
>
struct
delimiters_values
{
using
char_type
=
TChar
;
const
char_type
*
prefix
;
const
char_type
*
delimiter
;
const
char_type
*
postfix
;
};
// Defines the delimiter values for a specific container and character type
template
<
typename
T
,
typename
TChar
>
struct
delimiters
{
using
type
=
delimiters_values
<
TChar
>
;
static
const
type
values
;
};
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template
<
typename
T
,
typename
TChar
=
char
,
typename
TCharTraits
=
::
std
::
char_traits
<
TChar
>,
typename
TDelimiters
=
delimiters
<
T
,
TChar
>>
struct
print_container_helper
{
using
delimiters_type
=
TDelimiters
;
using
ostream_type
=
std
::
basic_ostream
<
TChar
,
TCharTraits
>
;
template
<
typename
U
>
struct
printer
{
static
void
print_body
(
const
U
&
c
,
ostream_type
&
stream
)
{
using
std
::
begin
;
using
std
::
end
;
auto
it
=
begin
(
c
);
const
auto
the_end
=
end
(
c
);
if
(
it
!=
the_end
)
{
for
(;;)
{
stream
<<
*
it
;
if
(
++
it
==
the_end
)
break
;
if
(
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
delimiters_type
::
values
.
delimiter
;
}
private:
const
T
&
container_
;
};
// Specialization for pairs
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
T1
,
typename
T2
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
pair
<
T1
,
T2
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
static
void
print_body
(
const
std
::
pair
<
T1
,
T2
>
&
c
,
ostream_type
&
stream
)
{
stream
<<
c
.
first
;
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
c
.
second
;
}
};
// Specialization for tuples
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
...
Args
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
tuple
<
Args
...
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
using
element_type
=
std
::
tuple
<
Args
...
>
;
template
<
std
::
size_t
I
>
struct
Int
{
};
static
void
print_body
(
const
element_type
&
c
,
ostream_type
&
stream
)
{
tuple_print
(
c
,
stream
,
Int
<
0
>
());
}
static
void
tuple_print
(
const
element_type
&
,
ostream_type
&
,
Int
<
sizeof
...(
Args
)
>
)
{
}
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
typename
std
::
conditional
<
sizeof
...(
Args
)
!=
0
,
Int
<
0
>
,
std
::
nullptr_t
>::
type
)
{
stream
<<
std
::
get
<
0
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
1
>
());
}
template
<
std
::
size_t
N
>
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
Int
<
N
>
)
{
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
std
::
get
<
N
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
N
+
1
>
());
}
};
// Prints a print_container_helper to the specified stream.
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>
&
helper
)
{
helper
(
stream
);
return
stream
;
}
// Basic is_container template; specialize to derive from std::true_type for all desired container types
template
<
typename
T
>
struct
is_container
:
public
std
::
integral_constant
<
bool
,
detail
::
has_const_iterator
<
T
>::
value
&&
detail
::
has_begin_end
<
T
>::
beg_value
&&
detail
::
has_begin_end
<
T
>::
end_value
>
{
};
template
<
typename
T
,
std
::
size_t
N
>
struct
is_container
<
T
[
N
]
>
:
std
::
true_type
{
};
template
<
std
::
size_t
N
>
struct
is_container
<
char
[
N
]
>
:
std
::
false_type
{
};
template
<
typename
T
>
struct
is_container
<
std
::
valarray
<
T
>>
:
std
::
true_type
{
};
template
<
typename
T1
,
typename
T2
>
struct
is_container
<
std
::
pair
<
T1
,
T2
>>
:
std
::
true_type
{
};
template
<
typename
...
Args
>
struct
is_container
<
std
::
tuple
<
Args
...
>>
:
std
::
true_type
{
};
// Default delimiters
template
<
typename
T
>
struct
delimiters
<
T
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
char
>
delimiters
<
T
,
char
>::
values
=
{
"["
,
", "
,
"]"
};
template
<
typename
T
>
struct
delimiters
<
T
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
wchar_t
>
delimiters
<
T
,
wchar_t
>::
values
=
{
L"["
,
L", "
,
L"]"
};
// Delimiters for (multi)set and unordered_(multi)set
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
// Delimiters for pair and tuple
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
char
>
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
template
<
typename
...
Args
>
struct
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
char
>
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
...
Args
>
struct
delimiters
<
::
std
::
tuple
<
Args
...
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
tuple
<
Args
...
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t, and MyDelims needs to be defined for TChar.
// Usage: "cout << pretty_print::custom_delims<MyDelims>(x)".
struct
custom_delims_base
{
virtual
~
custom_delims_base
()
{
}
virtual
std
::
ostream
&
stream
(
::
std
::
ostream
&
)
=
0
;
virtual
std
::
wostream
&
stream
(
::
std
::
wostream
&
)
=
0
;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{
}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
template
<
typename
Delims
>
struct
custom_delims
{
template
<
typename
Container
>
custom_delims
(
const
Container
&
c
)
:
base
(
new
custom_delims_wrapper
<
Container
,
Delims
>
(
c
))
{
}
std
::
unique_ptr
<
custom_delims_base
>
base
;
};
template
<
typename
TChar
,
typename
TCharTraits
,
typename
Delims
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
s
,
const
custom_delims
<
Delims
>
&
p
)
{
return
p
.
base
->
stream
(
s
);
}
}
};
print_container_helper
(
const
T
&
container
)
:
container_
(
container
)
{}
inline
void
operator
()(
ostream_type
&
stream
)
const
{
if
(
delimiters_type
::
values
.
prefix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
prefix
;
printer
<
T
>::
print_body
(
container_
,
stream
);
if
(
delimiters_type
::
values
.
postfix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
postfix
;
}
private:
const
T
&
container_
;
};
// Specialization for pairs
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
T1
,
typename
T2
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
pair
<
T1
,
T2
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
static
void
print_body
(
const
std
::
pair
<
T1
,
T2
>
&
c
,
ostream_type
&
stream
)
{
stream
<<
c
.
first
;
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
c
.
second
;
}
};
// Specialization for tuples
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
...
Args
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
tuple
<
Args
...
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
using
element_type
=
std
::
tuple
<
Args
...
>
;
template
<
std
::
size_t
I
>
struct
Int
{};
static
void
print_body
(
const
element_type
&
c
,
ostream_type
&
stream
)
{
tuple_print
(
c
,
stream
,
Int
<
0
>
());
}
static
void
tuple_print
(
const
element_type
&
,
ostream_type
&
,
Int
<
sizeof
...(
Args
)
>
)
{}
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
typename
std
::
conditional
<
sizeof
...(
Args
)
!=
0
,
Int
<
0
>
,
std
::
nullptr_t
>::
type
)
{
stream
<<
std
::
get
<
0
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
1
>
());
}
template
<
std
::
size_t
N
>
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
Int
<
N
>
)
{
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
std
::
get
<
N
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
N
+
1
>
());
}
};
// Prints a print_container_helper to the specified stream.
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>
&
helper
)
{
helper
(
stream
);
return
stream
;
}
// Basic is_container template; specialize to derive from std::true_type for all
// desired container types
template
<
typename
T
>
struct
is_container
:
public
std
::
integral_constant
<
bool
,
detail
::
has_const_iterator
<
T
>::
value
&&
detail
::
has_begin_end
<
T
>::
beg_value
&&
detail
::
has_begin_end
<
T
>::
end_value
>
{};
template
<
typename
T
,
std
::
size_t
N
>
struct
is_container
<
T
[
N
]
>
:
std
::
true_type
{};
template
<
std
::
size_t
N
>
struct
is_container
<
char
[
N
]
>
:
std
::
false_type
{};
template
<
typename
T
>
struct
is_container
<
std
::
valarray
<
T
>>
:
std
::
true_type
{};
template
<
typename
T1
,
typename
T2
>
struct
is_container
<
std
::
pair
<
T1
,
T2
>>
:
std
::
true_type
{};
template
<
typename
...
Args
>
struct
is_container
<
std
::
tuple
<
Args
...
>>
:
std
::
true_type
{};
// Default delimiters
template
<
typename
T
>
struct
delimiters
<
T
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
char
>
delimiters
<
T
,
char
>::
values
=
{
"["
,
", "
,
"]"
};
template
<
typename
T
>
struct
delimiters
<
T
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
wchar_t
>
delimiters
<
T
,
wchar_t
>::
values
=
{
L"["
,
L", "
,
L"]"
};
// Delimiters for (multi)set and unordered_(multi)set
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
// Delimiters for pair and tuple
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
char
>
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
template
<
typename
...
Args
>
struct
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
char
>
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
...
Args
>
struct
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t,
// and MyDelims needs to be defined for TChar. Usage: "cout <<
// pretty_print::custom_delims<MyDelims>(x)".
struct
custom_delims_base
{
virtual
~
custom_delims_base
()
{}
virtual
std
::
ostream
&
stream
(
::
std
::
ostream
&
)
=
0
;
virtual
std
::
wostream
&
stream
(
::
std
::
wostream
&
)
=
0
;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
template
<
typename
Delims
>
struct
custom_delims
{
template
<
typename
Container
>
custom_delims
(
const
Container
&
c
)
:
base
(
new
custom_delims_wrapper
<
Container
,
Delims
>
(
c
))
{}
std
::
unique_ptr
<
custom_delims_base
>
base
;
};
template
<
typename
TChar
,
typename
TCharTraits
,
typename
Delims
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
s
,
const
custom_delims
<
Delims
>
&
p
)
{
return
p
.
base
->
stream
(
s
);
}
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template
<
typename
T
>
struct
array_wrapper_n
{
typedef
const
T
*
const_iterator
;
typedef
T
value_type
;
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
array_wrapper_n
(
const
T
*
const
a
,
size_t
n
)
:
_array
(
a
),
_n
(
n
)
{
}
inline
const_iterator
begin
()
const
{
return
_array
;
}
inline
const_iterator
end
()
const
{
return
_array
+
_n
;
}
template
<
typename
T
>
struct
array_wrapper_n
{
typedef
const
T
*
const_iterator
;
typedef
T
value_type
;
private:
const
T
*
const
_array
;
size_t
_n
;
};
array_wrapper_n
(
const
T
*
const
a
,
size_t
n
)
:
_array
(
a
),
_n
(
n
)
{}
inline
const_iterator
begin
()
const
{
return
_array
;
}
inline
const_iterator
end
()
const
{
return
_array
+
_n
;
}
private:
const
T
*
const
_array
;
size_t
_n
;
};
// A wrapper for hash-table based containers that offer local iterators to each bucket.
// Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket 5 of container m.)
// A wrapper for hash-table based containers that offer local iterators to each
// bucket. Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket
// 5 of container m.)
template
<
typename
T
>
struct
bucket_print_wrapper
{
typedef
typename
T
::
const_local_iterator
const_iterator
;
typedef
typename
T
::
size_type
size_type
;
template
<
typename
T
>
struct
bucket_print_wrapper
{
typedef
typename
T
::
const_local_iterator
const_iterator
;
typedef
typename
T
::
size_type
size_type
;
const_iterator
begin
()
const
{
return
m_map
.
cbegin
(
n
);
}
const_iterator
end
()
const
{
return
m_map
.
cend
(
n
);
}
const_iterator
begin
()
const
{
return
m_map
.
cbegin
(
n
);
}
bucket_print_wrapper
(
const
T
&
m
,
size_type
bucket
)
:
m_map
(
m
),
n
(
bucket
)
{
}
const_iterator
end
()
const
{
return
m_map
.
cend
(
n
);
}
private:
const
T
&
m_map
;
const
size_type
n
;
};
bucket_print_wrapper
(
const
T
&
m
,
size_type
bucket
)
:
m_map
(
m
),
n
(
bucket
)
{}
}
// namespace pretty_print
private:
const
T
&
m_map
;
const
size_type
n
;
};
}
// namespace pretty_print
// Global accessor functions for the convenience wrappers
template
<
typename
T
>
inline
pretty_print
::
array_wrapper_n
<
T
>
pretty_print_array
(
const
T
*
const
a
,
size_t
n
)
{
return
pretty_print
::
array_wrapper_n
<
T
>
(
a
,
n
);
template
<
typename
T
>
inline
pretty_print
::
array_wrapper_n
<
T
>
pretty_print_array
(
const
T
*
const
a
,
size_t
n
)
{
return
pretty_print
::
array_wrapper_n
<
T
>
(
a
,
n
);
}
template
<
typename
T
>
pretty_print
::
bucket_print_wrapper
<
T
>
bucket_print
(
const
T
&
m
,
typename
T
::
size_type
n
)
{
return
pretty_print
::
bucket_print_wrapper
<
T
>
(
m
,
n
);
template
<
typename
T
>
pretty_print
::
bucket_print_wrapper
<
T
>
bucket_print
(
const
T
&
m
,
typename
T
::
size_type
n
)
{
return
pretty_print
::
bucket_print_wrapper
<
T
>
(
m
,
n
);
}
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
namespace
std
{
// Prints a container to the stream using default delimiters
namespace
std
{
// Prints a container to the stream using default delimiters
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
>
inline
typename
enable_if
<
::
pretty_print
::
is_container
<
T
>::
value
,
basic_ostream
<
TChar
,
TCharTraits
>
&>::
type
operator
<<
(
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
T
&
container
)
{
return
stream
<<
::
pretty_print
::
print_container_helper
<
T
,
TChar
,
TCharTraits
>
(
container
);
}
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
>
inline
typename
enable_if
<::
pretty_print
::
is_container
<
T
>::
value
,
basic_ostream
<
TChar
,
TCharTraits
>
&>::
type
operator
<<
(
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
T
&
container
)
{
return
stream
<<
::
pretty_print
::
print_container_helper
<
T
,
TChar
,
TCharTraits
>
(
container
);
}
}
// namespace std
#endif // H_PRETTY_PRINT
#endif // H_PRETTY_PRINT
mmdet3d/ops/spconv/include/spconv/box_iou.h
View file @
f27d308f
...
...
@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BOX_IOU_H
#define BOX_IOU_H
#include <pybind11/pybind11.h>
// must include pybind11/eigen.h if using eigen matrix as arguments.
#include <pybind11/numpy.h>
#include <algorithm>
#include <boost/geometry.hpp>
#include <pybind11/numpy.h>
namespace
spconv
{
// #include "voxelnet/core/cc/pybind11_helper.h"
...
...
@@ -40,9 +40,10 @@ inline py::array_t<DType> zeros(std::vector<long int> shape) {
}
template
<
typename
DType
>
py
::
array_t
<
DType
>
rbbox_iou
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
py
::
array_t
<
DType
>
rbbox_iou
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
namespace
bg
=
boost
::
geometry
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
...
...
@@ -61,8 +62,7 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
}
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
...
...
@@ -99,9 +99,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
}
template
<
typename
DType
>
py
::
array_t
<
DType
>
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
py
::
array_t
<
DType
>
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
namespace
bg
=
boost
::
geometry
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
...
...
@@ -120,8 +121,7 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
}
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
if
(
standup_iou_r
(
n
,
k
)
<=
standup_thresh
)
continue
;
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
0
,
0
),
box_corners_r
(
n
,
0
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
1
,
0
),
box_corners_r
(
n
,
1
,
1
)));
bg
::
append
(
poly
,
point_t
(
box_corners_r
(
n
,
2
,
0
),
box_corners_r
(
n
,
2
,
1
)));
...
...
@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
return
overlaps
;
}
}
// namespace spconv
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/geometry.h
View file @
f27d308f
...
...
@@ -15,9 +15,10 @@
#ifndef SPCONV_GEOMETRY_H_
#define SPCONV_GEOMETRY_H_
#include <tensorview/tensorview.h>
#include <iostream>
#include <limits>
#include <tensorview/tensorview.h>
namespace
spconv
{
template
<
typename
Index
,
unsigned
NDim
>
...
...
@@ -70,8 +71,7 @@ TV_HOST_DEVICE Index getValidOutPos(const Index *input_pos,
}
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
if
(
valid
)
++
pointCounter
;
if
(
valid
)
++
pointCounter
;
counter
[
NDim
-
1
]
+=
1
;
#pragma unroll
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
...
...
@@ -128,8 +128,7 @@ TV_HOST_DEVICE Index getValidOutPosTranspose(
m
*=
kernelSize
[
j
];
}
out
[
pointCounter
*
(
NDim
+
1
)
+
NDim
]
=
offset
;
if
(
valid
)
++
pointCounter
;
if
(
valid
)
++
pointCounter
;
counter
[
NDim
-
1
]
+=
1
;
#pragma unroll
for
(
int
c
=
NDim
-
1
;
c
>=
0
;
--
c
)
{
...
...
@@ -167,7 +166,7 @@ Index getIndicePairsConv(tv::TensorView<const Index> indicesIn,
}
Index
numValidPoints
=
0
;
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
batchIdx
=
indicesIn
(
j
,
0
);
...
...
@@ -218,7 +217,7 @@ Index getIndicePairsDeConv(tv::TensorView<const Index> indicesIn,
}
Index
numValidPoints
=
0
;
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
batchIdx
=
indicesIn
(
j
,
0
);
...
...
@@ -252,7 +251,8 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
tv
::
TensorView
<
Index
>
indiceNum
,
const
Index
*
const
kernelSize
,
const
Index
*
const
stride
,
const
Index
*
const
padding
,
const
Index
*
dilation
,
const
Index
*
const
outSpatialShape
)
{
const
Index
*
dilation
,
const
Index
*
const
outSpatialShape
)
{
Index
numAct
=
0
;
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
batchIdx
=
0
;
...
...
@@ -269,7 +269,7 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
Index
numValidPoints
=
0
;
// Index validPoints[kernelVolume * (NDim + 1)];
std
::
vector
<
Index
>
validPoints_
(
kernelVolume
*
(
NDim
+
1
));
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
validPoints
=
validPoints_
.
data
();
Index
*
pointPtr
=
nullptr
;
Index
index
=
0
;
for
(
int
j
=
0
;
j
<
numActIn
;
++
j
)
{
...
...
@@ -296,6 +296,6 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
return
numActIn
;
}
}
// namespace spconv
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/indice.cu.h
View file @
f27d308f
...
...
@@ -14,9 +14,9 @@
#ifndef INDICE_CU_H_
#define INDICE_CU_H_
#include <tensorview/tensorview.h>
#include <tensorview/helper_kernel.cu.h>
#include <spconv/geometry.h>
#include <tensorview/helper_kernel.cu.h>
#include <tensorview/tensorview.h>
namespace
spconv
{
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
...
...
@@ -115,7 +115,6 @@ __global__ void assignGridAndIndiceOutKernel(
int
numAct
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
batchSize
)
{
Index
index
;
auto
indicesOutPtr
=
indicesOut
.
data
();
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numAct
))
{
...
...
@@ -128,13 +127,11 @@ __global__ void assignGridAndIndiceOutKernel(
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
assignIndicePairsKernel
(
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
int
numActIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
__global__
void
assignIndicePairsKernel
(
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
int
numActIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
Index
index
;
int
kernelVolume
=
indicePairs
.
dim
(
0
);
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
numActIn
))
{
...
...
@@ -148,10 +145,9 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut,
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
prepareSubMGridKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
__global__
void
prepareSubMGridKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
spatialVolume
=
1
;
#pragma unroll
...
...
@@ -216,10 +212,9 @@ __global__ void resetGridKernel(const Index *indicePairUnique,
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
resetGridSubMKernel
(
const
Index
*
indices
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
numAct
)
{
__global__
void
resetGridSubMKernel
(
const
Index
*
indices
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
numAct
)
{
int
outSpatialShapeReg
[
NDim
];
for
(
int
i
=
0
;
i
<
NDim
;
++
i
)
{
outSpatialShapeReg
[
i
]
=
outSpatialShape
[
i
];
...
...
@@ -238,6 +233,6 @@ resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut,
}
}
}
// namespace spconv
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/indice.h
View file @
f27d308f
...
...
@@ -16,64 +16,65 @@
#define SPARSE_CONV_INDICE_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
functor
{
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP1
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
struct
CreateConvIndicePairFunctorP1
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP2
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indice
sOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
struct
CreateConvIndicePairFunctorP2
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indice
Pairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
struct
CreateConvIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateSubMIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
struct
CreateSubMIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
);
};
}
// namespace functor
}
// namespace spconv
}
// namespace functor
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/maxpool.h
View file @
f27d308f
...
...
@@ -16,29 +16,24 @@
#define SPARSE_MAXPOOL_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
functor
{
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolForwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
struct
SparseMaxPoolForwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolBackwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
struct
SparseMaxPoolBackwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
}
// namespace functor
}
// namespace spconv
}
// namespace functor
}
// namespace spconv
#endif
mmdet3d/ops/spconv/include/spconv/mp_helper.h
View file @
f27d308f
...
...
@@ -4,7 +4,8 @@
#include <utility>
namespace
spconv
{
template
<
class
...
T
>
struct
mp_list
{};
template
<
class
...
T
>
struct
mp_list
{};
template
<
class
T
,
T
...
I
>
using
mp_list_c
=
mp_list
<
std
::
integral_constant
<
T
,
I
>
...
>
;
...
...
@@ -16,15 +17,17 @@ constexpr F mp_for_each_impl(mp_list<T...>, F &&f) {
return
std
::
initializer_list
<
int
>
{(
f
(
T
()),
0
)...},
std
::
forward
<
F
>
(
f
);
}
template
<
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<>
,
F
&&
f
)
{
template
<
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<>
,
F
&&
f
)
{
return
std
::
forward
<
F
>
(
f
);
}
}
// namespace detail
}
// namespace detail
namespace
detail
{
template
<
class
A
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
{
template
<
class
A
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
{
// An error "no type named 'type'" here means that the first argument to
// mp_rename is not a list
};
...
...
@@ -34,14 +37,15 @@ struct mp_rename_impl<A<T...>, B> {
using
type
=
B
<
T
...
>
;
};
}
// namespace detail
}
// namespace detail
template
<
class
A
,
template
<
class
...
>
class
B
>
using
mp_rename
=
typename
detail
::
mp_rename_impl
<
A
,
B
>::
type
;
template
<
class
L
,
class
F
>
constexpr
F
mp_for_each
(
F
&&
f
)
{
template
<
class
L
,
class
F
>
constexpr
F
mp_for_each
(
F
&&
f
)
{
return
detail
::
mp_for_each_impl
(
mp_rename
<
L
,
mp_list
>
(),
std
::
forward
<
F
>
(
f
));
}
}
// namespace spconv
}
// namespace spconv
#endif
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment