"src/vscode:/vscode.git/clone" did not exist on "f846d902cf6c5cbcd988eccb3409955415b988f0"
Unverified Commit c33d4ec1 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Support PAConv operation (#598)

* support knn query in QueryAndGroup

* add cuda implemented assign_scores op

* add unit test for paconv assign_score op

* refactor op

* add non-cuda & cuda version PAConv

* minor fix

* fix cuda-9.0 compatibility

* add weight init of paconv

* fix typos

* refactor paconv
parent cde515d5
...@@ -11,6 +11,7 @@ from .group_points import (GroupAll, QueryAndGroup, group_points, ...@@ -11,6 +11,7 @@ from .group_points import (GroupAll, QueryAndGroup, group_points,
from .interpolate import three_interpolate, three_nn from .interpolate import three_interpolate, three_nn
from .knn import knn from .knn import knn
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .paconv import PAConv, PAConvCUDA, assign_score_withk
from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG, from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG,
build_sa_module) build_sa_module)
from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch, from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch,
...@@ -30,6 +31,7 @@ __all__ = [ ...@@ -30,6 +31,7 @@ __all__ = [
'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn', 'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn',
'gather_points', 'grouping_operation', 'group_points', 'GroupAll', 'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule', 'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
'points_in_boxes_batch', 'get_compiler_version', 'points_in_boxes_batch', 'get_compiler_version', 'assign_score_withk',
'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module' 'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module',
'PAConv', 'PAConvCUDA'
] ]
...@@ -4,6 +4,7 @@ from torch.autograd import Function ...@@ -4,6 +4,7 @@ from torch.autograd import Function
from typing import Tuple from typing import Tuple
from ..ball_query import ball_query from ..ball_query import ball_query
from ..knn import knn
from . import group_points_ext from . import group_points_ext
...@@ -13,7 +14,8 @@ class QueryAndGroup(nn.Module): ...@@ -13,7 +14,8 @@ class QueryAndGroup(nn.Module):
Groups with a ball query of radius Groups with a ball query of radius
Args: Args:
max_radius (float): The maximum radius of the balls. max_radius (float | None): The maximum radius of the balls.
If None is given, we will use kNN sampling instead of ball query.
sample_num (int): Maximum number of features to gather in the ball. sample_num (int): Maximum number of features to gather in the ball.
min_radius (float): The minimum radius of the balls. min_radius (float): The minimum radius of the balls.
use_xyz (bool): Whether to use xyz. use_xyz (bool): Whether to use xyz.
...@@ -48,7 +50,12 @@ class QueryAndGroup(nn.Module): ...@@ -48,7 +50,12 @@ class QueryAndGroup(nn.Module):
self.uniform_sample = uniform_sample self.uniform_sample = uniform_sample
self.return_unique_cnt = return_unique_cnt self.return_unique_cnt = return_unique_cnt
if self.return_unique_cnt: if self.return_unique_cnt:
assert self.uniform_sample assert self.uniform_sample, \
'uniform_sample should be True when ' \
'returning the count of unique samples'
if self.max_radius is None:
assert not self.normalize_xyz, \
'can not normalize grouped xyz when max_radius is None'
def forward(self, points_xyz, center_xyz, features=None): def forward(self, points_xyz, center_xyz, features=None):
"""forward. """forward.
...@@ -61,8 +68,14 @@ class QueryAndGroup(nn.Module): ...@@ -61,8 +68,14 @@ class QueryAndGroup(nn.Module):
Return: Return:
Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
""" """
idx = ball_query(self.min_radius, self.max_radius, self.sample_num, # if self.max_radius is None, we will perform kNN instead of ball query
points_xyz, center_xyz) # idx is of shape [B, npoint, sample_num]
if self.max_radius is None:
idx = knn(self.sample_num, points_xyz, center_xyz, False)
idx = idx.transpose(1, 2).contiguous()
else:
idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
points_xyz, center_xyz)
if self.uniform_sample: if self.uniform_sample:
unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
......
from .assign_score import assign_score_withk
from .paconv import PAConv, PAConvCUDA
__all__ = ['assign_score_withk', 'PAConv', 'PAConvCUDA']
from torch.autograd import Function
from . import assign_score_withk_ext
class AssignScoreWithK(Function):
    r"""Perform weighted sum to generate output features according to scores.

    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
    scene_seg/lib/paconv_lib/src/gpu>`_.

    This is a memory-efficient CUDA implementation of assign_scores operation,
    which first transform all point feature with weight bank, then assemble
    neighbor features with `knn_idx` and perform weighted sum of `scores`.
    See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
    more detailed descriptions.

    Note:
        This implementation assumes using ``neighbor`` kernel input, which is
        (point_features - center_features, point_features).
        See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
        pointnet2/paconv.py#L128 for more details.
    """

    # mapping from aggregation method name to the int flag of the CUDA op
    AGG_FLAGS = {'sum': 0, 'avg': 1, 'max': 2}

    @staticmethod
    def forward(ctx,
                scores,
                point_features,
                center_features,
                knn_idx,
                aggregate='sum'):
        """Forward.

        Args:
            scores (torch.Tensor): (B, npoint, K, M), predicted scores to
                aggregate weight matrices in the weight bank.
                ``npoint`` is the number of sampled centers.
                ``K`` is the number of queried neighbors.
                ``M`` is the number of weight matrices in the weight bank.
            point_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed point features to be aggregated.
            center_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed center features to be aggregated.
            knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
                We assume the first idx in each row is the idx of the center.
            aggregate (str, optional): Aggregation method.
                Can be 'sum', 'avg' or 'max'. Defaults to 'sum'.

        Returns:
            torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
        """
        # fail early with a clear message instead of a bare KeyError
        assert aggregate in AssignScoreWithK.AGG_FLAGS, \
            f'unsupported aggregate method {aggregate}'
        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()
        output = point_features.new_zeros((B, out_dim, npoint, K))
        assign_score_withk_ext.assign_score_withk_forward_wrapper(
            B, N, npoint, M, K, out_dim,
            AssignScoreWithK.AGG_FLAGS[aggregate],
            point_features.contiguous(), center_features.contiguous(),
            scores.contiguous(), knn_idx.contiguous(), output)
        # `output` is not needed by backward, so do not save it
        # (the original implementation saved it and wasted memory)
        ctx.save_for_backward(point_features, center_features, scores,
                              knn_idx)
        ctx.agg = AssignScoreWithK.AGG_FLAGS[aggregate]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Backward.

        Args:
            grad_out (torch.Tensor): (B, out_dim, npoint, K)

        Returns:
            grad_scores (torch.Tensor): (B, npoint, K, M)
            grad_point_features (torch.Tensor): (B, N, M, out_dim)
            grad_center_features (torch.Tensor): (B, N, M, out_dim)
        """
        point_features, center_features, scores, knn_idx = ctx.saved_tensors
        agg = ctx.agg
        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()
        grad_point_features = point_features.new_zeros(point_features.shape)
        grad_center_features = center_features.new_zeros(center_features.shape)
        grad_scores = scores.new_zeros(scores.shape)
        assign_score_withk_ext.assign_score_withk_backward_wrapper(
            B, N, npoint, M, K, out_dim, agg, grad_out.contiguous(),
            point_features.contiguous(), center_features.contiguous(),
            scores.contiguous(), knn_idx.contiguous(), grad_point_features,
            grad_center_features, grad_scores)
        # one gradient slot per forward input: scores, point_features,
        # center_features get gradients; knn_idx and aggregate get None
        return grad_scores, grad_point_features, \
            grad_center_features, None, None


assign_score_withk = AssignScoreWithK.apply
import copy
import torch
from mmcv.cnn import (ConvModule, build_activation_layer, build_norm_layer,
constant_init, xavier_init)
from torch import nn as nn
from torch.nn import functional as F
from .assign_score import assign_score_withk as assign_score_cuda
from .utils import assign_kernel_withoutk, assign_score, calc_euclidian_dist
class ScoreNet(nn.Module):
    """ScoreNet that outputs coefficient scores to assemble weight kernels in
    the weight bank according to the relative position of point pairs.

    Args:
        mlp_channels (List[int]): Hidden unit sizes of SharedMLP layers.
        last_bn (bool, optional): Whether to use BN on the last output of mlps.
            Defaults to False.
        score_norm (str, optional): Normalization function of output scores.
            Can be 'softmax', 'sigmoid' or 'identity'. Defaults to 'softmax'.
        temp_factor (float, optional): Temperature factor to scale the output
            scores before softmax. Defaults to 1.0.
        norm_cfg (dict, optional): Type of normalization method.
            Defaults to dict(type='BN2d').
        bias (bool | str, optional): If specified as `auto`, it will be decided
            by the norm_cfg. Bias will be set as True if `norm_cfg` is None,
            otherwise False. Defaults to 'auto'.
    """

    def __init__(self,
                 mlp_channels,
                 last_bn=False,
                 score_norm='softmax',
                 temp_factor=1.0,
                 norm_cfg=dict(type='BN2d'),
                 bias='auto'):
        super(ScoreNet, self).__init__()
        assert score_norm in ['softmax', 'sigmoid', 'identity'], \
            f'unsupported score_norm function {score_norm}'
        self.score_norm = score_norm
        self.temp_factor = temp_factor
        self.mlps = nn.Sequential()
        num_layers = len(mlp_channels) - 1
        for idx in range(num_layers):
            layer_kwargs = dict(
                kernel_size=(1, 1),
                stride=(1, 1),
                conv_cfg=dict(type='Conv2d'),
                norm_cfg=norm_cfg,
                bias=bias)
            if idx == num_layers - 1:
                # the score-producing layer: no relu and possibly no bn
                layer_kwargs['norm_cfg'] = norm_cfg if last_bn else None
                layer_kwargs['act_cfg'] = None
            self.mlps.add_module(
                f'layer{idx}',
                ConvModule(mlp_channels[idx], mlp_channels[idx + 1],
                           **layer_kwargs))

    def init_weights(self):
        """Initialize weights of shared MLP layers."""
        # refer to https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/paconv.py#L105 # noqa
        for module in self.mlps.modules():
            if isinstance(module, nn.Conv2d):
                xavier_init(module)

    def forward(self, xyz_features):
        """Forward.

        Args:
            xyz_features (torch.Tensor): (B, C, N, K), features constructed
                from xyz coordinates of point pairs. May contain relative
                positions, Euclidian distance, etc.

        Returns:
            torch.Tensor: (B, N, K, M), predicted scores for `M` kernels.
        """
        raw_scores = self.mlps(xyz_features)  # (B, M, N, K)
        # normalize the scores as configured
        if self.score_norm == 'softmax':
            normed = F.softmax(raw_scores / self.temp_factor, dim=1)
        elif self.score_norm == 'sigmoid':
            normed = torch.sigmoid(raw_scores / self.temp_factor)
        else:  # 'identity' keeps the raw scores untouched
            normed = raw_scores
        # move the kernel dim to the last axis: (B, N, K, M)
        return normed.permute(0, 2, 3, 1)
class PAConv(nn.Module):
    """Non-CUDA version of PAConv.

    PAConv stores a trainable weight bank containing several weight kernels.
    Given input points and features, it computes coefficient scores to assemble
    those kernels to form conv kernels, and then runs convolution on the input.

    Args:
        in_channels (int): Input channels of point features.
        out_channels (int): Output channels of point features.
        num_kernels (int): Number of weight kernels in the weight bank.
        norm_cfg (dict, optional): Type of normalization method.
            Defaults to dict(type='BN2d', momentum=0.1).
        act_cfg (dict, optional): Type of activation method.
            Defaults to dict(type='ReLU', inplace=True).
        scorenet_input (str, optional): Type of input to ScoreNet.
            Can be 'identity', 'w_neighbor' or 'w_neighbor_dist'.
            Defaults to 'w_neighbor_dist'.
        weight_bank_init (str, optional): Init method of weight bank kernels.
            Can be 'kaiming' or 'xavier'. Defaults to 'kaiming'.
        kernel_input (str, optional): Input features to be multiplied with
            weight kernels. Can be 'identity' or 'w_neighbor'.
            Defaults to 'w_neighbor'.
        scorenet_cfg (dict, optional): Config of the ScoreNet module, which
            may contain the following keys and values:

            - mlp_channels (List[int]): Hidden units of MLPs.
            - score_norm (str): Normalization function of output scores.
                Can be 'softmax', 'sigmoid' or 'identity'.
            - temp_factor (float): Temperature factor to scale the output
                scores before softmax.
            - last_bn (bool): Whether to use BN on the last output of mlps.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_kernels,
                 norm_cfg=dict(type='BN2d', momentum=0.1),
                 act_cfg=dict(type='ReLU', inplace=True),
                 scorenet_input='w_neighbor_dist',
                 weight_bank_init='kaiming',
                 kernel_input='w_neighbor',
                 scorenet_cfg=dict(
                     mlp_channels=[8, 16, 16],
                     score_norm='softmax',
                     temp_factor=1.0,
                     last_bn=False)):
        super(PAConv, self).__init__()
        # determine weight kernel size according to used features
        if kernel_input == 'identity':
            # only use grouped_features
            self.kernel_mul = 1
        elif kernel_input == 'w_neighbor':
            # concat of (grouped_features - center_features, grouped_features)
            self.kernel_mul = 2
        else:
            raise NotImplementedError(
                f'unsupported kernel_input {kernel_input}')
        self.kernel_input = kernel_input
        # determine mlp channels in ScoreNet according to used xyz features
        if scorenet_input == 'identity':
            # only use relative position (grouped_xyz - center_xyz)
            self.scorenet_in_channels = 3
        elif scorenet_input == 'w_neighbor':
            # (grouped_xyz - center_xyz, grouped_xyz)
            self.scorenet_in_channels = 6
        elif scorenet_input == 'w_neighbor_dist':
            # (center_xyz, grouped_xyz - center_xyz, Euclidian distance)
            self.scorenet_in_channels = 7
        else:
            raise NotImplementedError(
                f'unsupported scorenet_input {scorenet_input}')
        self.scorenet_input = scorenet_input
        # construct weight kernels in weight bank
        # self.weight_bank is of shape [C, num_kernels * out_c]
        # where C can be in_c or (2 * in_c)
        if weight_bank_init == 'kaiming':
            weight_init = nn.init.kaiming_normal_
        elif weight_bank_init == 'xavier':
            weight_init = nn.init.xavier_normal_
        else:
            raise NotImplementedError(
                f'unsupported weight bank init method {weight_bank_init}')
        self.m = num_kernels
        # init as [m, kernel_mul * in_c, out_c] so each kernel sees proper
        # fan-in/fan-out for the chosen init, then flatten to 2D for matmul
        weight_bank = weight_init(
            torch.empty(self.m, in_channels * self.kernel_mul, out_channels))
        weight_bank = weight_bank.permute(1, 0, 2).reshape(
            in_channels * self.kernel_mul, self.m * out_channels).contiguous()
        self.weight_bank = nn.Parameter(weight_bank, requires_grad=True)
        # construct ScoreNet
        # deepcopy so the caller's (possibly shared) cfg dict is not mutated
        # by the insert/append below
        scorenet_cfg_ = copy.deepcopy(scorenet_cfg)
        scorenet_cfg_['mlp_channels'].insert(0, self.scorenet_in_channels)
        scorenet_cfg_['mlp_channels'].append(self.m)
        self.scorenet = ScoreNet(**scorenet_cfg_)
        self.bn = build_norm_layer(norm_cfg, out_channels)[1] if \
            norm_cfg is not None else None
        self.activate = build_activation_layer(act_cfg) if \
            act_cfg is not None else None
        self.init_weights()

    def init_weights(self):
        """Initialize weights of shared MLP layers."""
        self.scorenet.init_weights()
        if self.bn is not None:
            constant_init(self.bn, val=1)

    def _prepare_scorenet_input(self, points_xyz):
        """Prepare input point pairs features for self.ScoreNet.

        Args:
            points_xyz (torch.Tensor): (B, 3, npoint, K)
                Coordinates of the grouped points.

        Returns:
            torch.Tensor: (B, C, npoint, K)
                The generated features per point pair.
        """
        B, _, npoint, K = points_xyz.size()
        # NOTE(review): assumes the first grouped point of each neighborhood
        # is the center itself -- confirm against the grouping op used
        center_xyz = points_xyz[..., :1].repeat(1, 1, 1, K)
        xyz_diff = points_xyz - center_xyz  # [B, 3, npoint, K]
        if self.scorenet_input == 'identity':
            xyz_features = xyz_diff
        elif self.scorenet_input == 'w_neighbor':
            xyz_features = torch.cat((xyz_diff, points_xyz), dim=1)
        else:  # w_neighbor_dist
            # per-pair Euclidian distance added as one extra scalar channel
            euclidian_dist = calc_euclidian_dist(
                center_xyz.permute(0, 2, 3, 1).reshape(B * npoint * K, 3),
                points_xyz.permute(0, 2, 3, 1).reshape(B * npoint * K, 3)).\
                reshape(B, 1, npoint, K)
            xyz_features = torch.cat((center_xyz, xyz_diff, euclidian_dist),
                                     dim=1)
        return xyz_features

    def forward(self, points_xyz, features):
        """Forward.

        Args:
            points_xyz (torch.Tensor): (B, 3, npoint, K)
                Coordinates of the grouped points.
            features (torch.Tensor): (B, in_c, npoint, K)
                Features of the queried points.

        Returns:
            torch.Tensor: (B, out_c, npoint, K), features after PAConv.
        """
        B, _, npoint, K = features.size()
        if self.kernel_input == 'w_neighbor':
            # NOTE(review): like _prepare_scorenet_input, this assumes the
            # first grouped feature belongs to the center point
            center_features = features[..., :1].repeat(1, 1, 1, K)
            features_diff = features - center_features
            # to (B, 2 * in_c, npoint, K)
            features = torch.cat((features_diff, features), dim=1)
        # prepare features for between each point and its grouping center
        xyz_features = self._prepare_scorenet_input(points_xyz)
        # scores to assemble weight kernels
        scores = self.scorenet(xyz_features)  # [B, npoint, K, m]
        # first compute out features over all kernels
        # features is [B, C, npoint, K], weight_bank is [C, m * out_c]
        new_features = torch.matmul(
            features.permute(0, 2, 3, 1), self.weight_bank).\
            view(B, npoint, K, self.m, -1)  # [B, npoint, K, m, out_c]
        # then aggregate using scores
        new_features = assign_score(scores, new_features)
        # to [B, out_c, npoint, K]
        new_features = new_features.permute(0, 3, 1, 2).contiguous()
        if self.bn is not None:
            new_features = self.bn(new_features)
        if self.activate is not None:
            new_features = self.activate(new_features)
        return new_features
class PAConvCUDA(PAConv):
    """CUDA version of PAConv that implements a cuda op to efficiently perform
    kernel assembling.

    Different from vanilla PAConv, the input features of this function is not
    grouped by centers. Instead, they will be queried on-the-fly by the
    additional input `points_idx`. This avoids the large intermediate matrix.
    See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D
    for more detailed descriptions.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_kernels,
                 norm_cfg=dict(type='BN2d', momentum=0.1),
                 act_cfg=dict(type='ReLU', inplace=True),
                 scorenet_input='w_neighbor_dist',
                 weight_bank_init='kaiming',
                 kernel_input='w_neighbor',
                 scorenet_cfg=dict(
                     mlp_channels=[8, 16, 16],
                     score_norm='softmax',
                     temp_factor=1.0,
                     last_bn=False)):
        super(PAConvCUDA, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            num_kernels=num_kernels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            scorenet_input=scorenet_input,
            weight_bank_init=weight_bank_init,
            kernel_input=kernel_input,
            scorenet_cfg=scorenet_cfg)
        # the cuda op pre-computes the two feature halves separately, which
        # only makes sense for the concatenated 'w_neighbor' kernel input
        assert self.kernel_input == 'w_neighbor', \
            'CUDA implemented PAConv only supports w_neighbor kernel_input'

    def forward(self, points_xyz, features, points_idx):
        """Forward.

        Args:
            points_xyz (torch.Tensor): (B, 3, npoint, K)
                Coordinates of the grouped points.
            features (torch.Tensor): (B, in_c, N)
                Features of all points in the current point cloud.
                Different from `features` in non-CUDA version PAConv, here the
                features are not grouped by each center to form a K dim.
            points_idx (torch.Tensor): (B, npoint, K)
                Index of the grouped points.

        Returns:
            torch.Tensor: (B, out_c, npoint, K), features after PAConv.
        """
        # geometric features describing every (center, neighbor) pair
        pair_features = self._prepare_scorenet_input(points_xyz)
        # coefficient scores used to assemble the weight kernels
        scores = self.scorenet(pair_features)  # [B, npoint, K, m]
        # transform every point by every kernel once, instead of per group
        # features is [B, in_c, N], weight_bank is [C, m * out_dim]
        point_feat, center_feat = assign_kernel_withoutk(
            features, self.weight_bank, self.m)
        # gather neighbors via points_idx and aggregate with the cuda op
        out = assign_score_cuda(
            scores, point_feat, center_feat, points_idx,
            'sum').contiguous()  # [B, out_c, npoint, K]
        if self.bn is not None:
            out = self.bn(out)
        if self.activate is not None:
            out = self.activate(out)
        return out
// Modified from https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/paconv_lib/src/gpu
#include <torch/torch.h>
#include <torch/extension.h>

// Forward pass of the memory-efficient score-assignment op (implemented in
// the companion .cu file).
// B: batch size, N0: number of points, N1: number of sampled centers,
// M: number of weight kernels, K: neighbors per center, O: output channels,
// aggregate: aggregation mode flag (matches the Python-side mapping).
void assign_score_withk_forward_wrapper(
  int B, int N0, int N1, int M,
  int K, int O, int aggregate,
  const at::Tensor& points,
  const at::Tensor& centers,
  const at::Tensor& scores,
  const at::Tensor& knn_idx,
  at::Tensor& output
  );

// Backward pass: fills grad_points, grad_centers and grad_scores from
// grad_out; parameter meanings are the same as in the forward wrapper.
void assign_score_withk_backward_wrapper(
  int B, int N0, int N1, int M,
  int K, int O, int aggregate,
  const at::Tensor& grad_out,
  const at::Tensor& points,
  const at::Tensor& centers,
  const at::Tensor& scores,
  const at::Tensor& knn_idx,
  at::Tensor& grad_points,
  at::Tensor& grad_centers,
  at::Tensor& grad_scores
  );

// Expose the two wrappers to Python as the `assign_score_withk_ext` module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("assign_score_withk_forward_wrapper",
        &assign_score_withk_forward_wrapper,
        "Assign score kernel forward (GPU), save memory version");
  m.def("assign_score_withk_backward_wrapper",
        &assign_score_withk_backward_wrapper,
        "Assign score kernel backward (GPU), save memory version");
}
// Modified from https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/paconv_lib/src/gpu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <cmath>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/types.h>
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#define CHECK_CONTIGUOUS(x) \
do { \
AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
} while (0)
#define CUDA_CHECK_ERRORS() \
do { \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
__FILE__); \
exit(-1); \
} \
} while (0)
// input: points(B,N0,M,O), centers(B,N0,M,O), scores(B,N1,K,M), knn_idx(B,N1,K)
// output: fout(B,O,N)
// algo: fout(b,i,k,j) = s(b,i,k,m)*p(b,c(i),k,m,j) = s(b,i,k,m)*p(b,i(k),m,j)
//       i(k) = idx(b,i,k)
//      sum: fout(b,i,j) = fout(b,i,j) + s(b,i,k,m)*p(b,i,k,m,j)
//      avg: fout(b,i,j) = sum(fout(b,i,k,j)) / k
//      max: fout(b,i,j) = max(fout(b,i,k,j), sum(s(b,i,k,m)*p(b,i,k,m,j)))
// NOTE(review): the `aggregate` flag is never read in this kernel -- only the
// 'sum' behavior above is implemented, despite the avg/max formulas. Confirm
// whether avg/max support is intentionally deferred.
__global__ void assign_score_withk_forward_kernel(const int B, const int N0, const int N1,
                                                  const int M, const int K, const int O, const int aggregate,
                                                  const float* points,
                                                  const float* centers,
                                                  const float* scores,
                                                  const long* knn_idx,
                                                  float* output) {
    // one thread per (batch, output-channel, center, neighbor) element
    // ----- parallel loop for B, N1, K and O ---------
    long i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= B*N1*K*O) return;
    // ------- loop for M ----------
    for (int m = 0; m < M; m++) {
        // decode the flat thread index into (b, o, n, k)
        int b = (int)(i / (O * N1 * K));
        int o = (int)(i % (O * N1 * K) / (N1 * K));
        int n = (int)(i % (N1 * K) / K);
        int k = (int)(i % K);
        int cn = (int) knn_idx[b*K*N1 + n*K + 0]; //The first neighbor is the center point
        int kn = (int) knn_idx[b*K*N1 + n*K + k];
        if (kn >= N0 || kn < 0) { // if index overflows, it is out of the neighborhood range
            continue;
        }
        assert (b < B);
        assert (kn < N0);
        assert (cn < N0);
        assert (o < O);
        assert (n < N1);
        // atomically accumulate the score-weighted (neighbor - center)
        // feature contribution of kernel m into the output element
        atomicAdd(output + b*N1*O*K + o*N1*K + n*K + k,
                  points[b*N0*M*O + kn*M*O + m*O + o] * scores[b*N1*K*M + n*K*M + k*M + m]
                  - centers[b*N0*M*O + cn*M*O + m*O + o] * scores[b*N1*K*M + n*K*M + k*M + m]);
    }
}
// Gradients of the forward op w.r.t. `points` and `centers`.
// Each thread handles one (batch, kernel, output-channel) slice and loops
// over all centers and neighbors, scattering into the per-point grads.
__global__ void assign_score_withk_backward_points_kernel(const int B, const int N0, const int N, const int M,
                                                          const int K, const int O, const int aggregate,
                                                          const float* grad_out,
                                                          const float* scores,
                                                          const long* knn_idx,
                                                          float* grad_points,
                                                          float* grad_centers) {
    // ----- parallel loop for B, M, O ---------
    long i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= B*M*O) return;
    int b = (int)(i / (M * O));
    int m = (int)(i % (M * O) / O);
    int o = (int)(i % O);
    // ----- loop for N,K ---------
    for (int n = 0; n < N; n++) {
        for (int k = 0; k < K; k++) {
            int kn = knn_idx[b*N*K + n*K + k];
            // first neighbor index is treated as the center point
            int cn = knn_idx[b*N*K + n*K + 0];
            if (kn >= N0 || kn < 0) { // if index overflows, it is out of the neighborhood range
                continue;
            }
            // d(out)/d(points[kn]) = +score, d(out)/d(centers[cn]) = -score;
            // atomicAdd because several (n, k) pairs may map to one point
            atomicAdd(grad_points + b*N0*M*O + kn*M*O + m*O + o,
                      scores[b*N*K*M + n*K*M + k*M + m] * grad_out[b*O*N*K + o*N*K + n*K + k]);
            atomicAdd(grad_centers + b*N0*M*O + cn*M*O + m*O + o,
                      - scores[b*N*K*M + n*K*M + k*M + m] * grad_out[b*O*N*K + o*N*K + n*K + k]);
        }
    }
}
// Gradient of the forward op w.r.t. `scores`.
// One thread per (batch, center, neighbor, kernel) score element; the inner
// loop reduces over the O output channels.
__global__ void assign_score_withk_backward_scores_kernel(const int B, const int N0, const int N, const int M,
                                                          const int K, const int O, const int aggregate,
                                                          const float* grad_out,
                                                          const float* points,
                                                          const float* centers,
                                                          const long* knn_idx,
                                                          float* grad_scores) {
    // ----- parallel loop for B, N, K, M ---------
    long i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= B*N*K*M) return;
    int b = (int)(i / (N * M * K));
    int n = (int)(i % (N * M * K) / M / K);
    int k = (int)(i % (M * K) / M);
    int m = (int)(i % M);
    // first neighbor index is treated as the center point
    int cn = knn_idx[b*N*K + n*K + 0];
    int kn = knn_idx[b*N*K + n*K + k];
    if (kn >= N0 || kn < 0) { // if index overflows, it is out of the neighborhood range
        return;
    }
    // -------------- loop for O ------------------------
    // d(out)/d(score) = (point_feature - center_feature), summed over O
    for(int o = 0; o < O; o++) {
        atomicAdd(grad_scores + b*N*K*M + n*K*M + k*M + m,
                  (points[b*N0*M*O + kn*M*O + m*O + o]
                   - centers[b*N0*M*O + cn*M*O + m*O + o])* grad_out[b*O*N*K + o*N*K + n*K + k]);
    }
}
// Launch the forward kernel on PyTorch's current CUDA stream.
// B: batch, N0: points, N1: centers, M: kernels, K: neighbors per center,
// O: output channels, aggregate: aggregation mode flag.
void assign_score_withk_forward_wrapper(int B, int N0, int N1, int M, int K, int O, int aggregate,
                                        const at::Tensor& points,
                                        const at::Tensor& centers,
                                        const at::Tensor& scores,
                                        const at::Tensor& knn_idx,
                                        at::Tensor& output) {
    CHECK_CONTIGUOUS(points);
    CHECK_CONTIGUOUS(centers);
    CHECK_CONTIGUOUS(scores);
    CHECK_CONTIGUOUS(knn_idx);
    CHECK_CONTIGUOUS(output);

    const float* points_data = points.data_ptr<float>();
    const float* centers_data = centers.data_ptr<float>();
    const float* scores_data = scores.data_ptr<float>();
    const long* knn_idx_data = knn_idx.data_ptr<long>();
    float* output_data = output.data_ptr<float>();

    // Launch on the current torch stream (the original launched on the
    // default stream, which is not ordered with the caller's other ops).
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    dim3 blocks(DIVUP(B*O*N1*K, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    assign_score_withk_forward_kernel<<<blocks, threads, 0, stream>>>(
        B, N0, N1, M, K, O, aggregate, points_data, centers_data, scores_data, knn_idx_data, output_data);
    CUDA_CHECK_ERRORS();
}
// Launch the two backward kernels on PyTorch's current CUDA stream.
// Fills grad_points, grad_centers and grad_scores from grad_out.
void assign_score_withk_backward_wrapper(int B, int N0, int N1, int M, int K, int O, int aggregate,
                                         const at::Tensor& grad_out,
                                         const at::Tensor& points,
                                         const at::Tensor& centers,
                                         const at::Tensor& scores,
                                         const at::Tensor& knn_idx,
                                         at::Tensor& grad_points,
                                         at::Tensor& grad_centers,
                                         at::Tensor& grad_scores) {
    CHECK_CONTIGUOUS(grad_out);
    CHECK_CONTIGUOUS(scores);
    CHECK_CONTIGUOUS(points);
    CHECK_CONTIGUOUS(centers);
    CHECK_CONTIGUOUS(knn_idx);
    CHECK_CONTIGUOUS(grad_scores);
    CHECK_CONTIGUOUS(grad_points);
    CHECK_CONTIGUOUS(grad_centers);

    const float* grad_out_data = grad_out.data_ptr<float>();
    const float* points_data = points.data_ptr<float>();
    const float* centers_data = centers.data_ptr<float>();
    const float* scores_data = scores.data_ptr<float>();
    const long* knn_idx_data = knn_idx.data_ptr<long>();
    float* grad_points_data = grad_points.data_ptr<float>();
    float* grad_centers_data = grad_centers.data_ptr<float>();
    float* grad_scores_data = grad_scores.data_ptr<float>();

    // The original fetched the current stream but launched both kernels on
    // the default stream; pass the stream so the launches are ordered with
    // the rest of the caller's work.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    dim3 blocks1(DIVUP(B*M*O, THREADS_PER_BLOCK));
    dim3 threads1(THREADS_PER_BLOCK);
    dim3 blocks2(DIVUP(B*N1*K*M, THREADS_PER_BLOCK));
    dim3 threads2(THREADS_PER_BLOCK);

    assign_score_withk_backward_points_kernel<<<blocks1, threads1, 0, stream>>>(
        B, N0, N1, M, K, O, aggregate, grad_out_data, scores_data, knn_idx_data, grad_points_data, grad_centers_data);
    assign_score_withk_backward_scores_kernel<<<blocks2, threads2, 0, stream>>>(
        B, N0, N1, M, K, O, aggregate, grad_out_data, points_data, centers_data, knn_idx_data, grad_scores_data);
    CUDA_CHECK_ERRORS();
}
import torch
def calc_euclidian_dist(xyz1, xyz2):
    """Calculate the Euclidean distance between two sets of points.

    Args:
        xyz1 (torch.Tensor): (N, 3), the first set of points.
        xyz2 (torch.Tensor): (N, 3), the second set of points.

    Returns:
        torch.Tensor: (N, ), the distance between each point pair.
    """
    assert xyz1.shape[0] == xyz2.shape[0], 'number of points are not the same'
    assert xyz1.shape[1] == xyz2.shape[1] == 3, \
        'points coordinates dimension is not 3'
    diff = xyz1 - xyz2
    return diff.norm(dim=-1)
def assign_score(scores, point_features):
    """Perform weighted sum to aggregate output features according to scores.

    This function is used in non-CUDA version of PAConv. Compared to the cuda
    op assigh_score_withk, this pytorch implementation pre-computes output
    features for the neighbors of all centers before aggregating them, which
    consumes more GPU memory.

    Args:
        scores (torch.Tensor): (B, npoint, K, M), predicted scores to
            aggregate weight matrices in the weight bank.
            `npoint` is the number of sampled centers.
            `K` is the number of queried neighbors.
            `M` is the number of weight matrices in the weight bank.
        point_features (torch.Tensor): (B, npoint, K, M, out_dim)
            Pre-computed point features to be aggregated.

    Returns:
        torch.Tensor: (B, npoint, K, out_dim), the aggregated features.
    """
    # (B, npoint, K, 1, M) @ (B, npoint, K, M, out_dim)
    #   -> (B, npoint, K, 1, out_dim)
    aggregated = torch.matmul(scores.unsqueeze(-2), point_features)
    # drop the singleton dim to get (B, npoint, K, out_dim)
    return aggregated.squeeze(-2)
def assign_kernel_withoutk(features, kernels, M):
    """Pre-compute features with weight matrices in weight bank. This function
    is used before cuda op assign_score_withk in CUDA version PAConv.

    Args:
        features (torch.Tensor): (B, in_dim, N), input features of all points.
            `N` is the number of points in current point cloud.
        kernels (torch.Tensor): (2 * in_dim, M * out_dim), weight matrices in
            the weight bank, transformed from (M, 2 * in_dim, out_dim).
            `2 * in_dim` is because the input features are concatenation of
            (point_features - center_features, point_features).
        M (int): Number of weight matrices in the weight bank.

    Returns:
        Tuple[torch.Tensor]: both of shape (B, N, M, out_dim)
            point_features: Pre-computed features for points.
            center_features: Pre-computed features for centers.
    """
    B, in_dim, N = features.size()
    feats = features.permute(0, 2, 1)  # [B, N, in_dim]
    # the first half of the weight bank multiplies the feature difference,
    # the second half multiplies the raw neighbor features
    half_diff = torch.matmul(feats, kernels[:in_dim]).view(B, N, M, -1)
    half_raw = torch.matmul(feats, kernels[in_dim:]).view(B, N, M, -1)
    # TODO: why this hard-coded if condition?
    # when the network input is only xyz without additional features,
    # xyz is used as features, so features.size(1) == 3 is odd;
    # compensate center_features because otherwise
    # `point_features - center_features` would be all zeros?
    if in_dim % 2 != 0:
        half_coord = torch.matmul(
            feats[:, :, :3],  # [B, N, 3]
            kernels[in_dim:in_dim + 3]).view(B, N, M, -1)  # [B, N, M, out_dim]
    else:
        half_coord = torch.zeros_like(half_raw)
    point_features = half_diff + half_raw
    center_features = half_diff + half_coord
    return point_features, center_features
...@@ -227,6 +227,11 @@ if __name__ == '__main__': ...@@ -227,6 +227,11 @@ if __name__ == '__main__':
module='mmdet3d.ops.knn', module='mmdet3d.ops.knn',
sources=['src/knn.cpp'], sources=['src/knn.cpp'],
sources_cuda=['src/knn_cuda.cu']), sources_cuda=['src/knn_cuda.cu']),
make_cuda_ext(
name='assign_score_withk_ext',
module='mmdet3d.ops.paconv',
sources=['src/assign_score_withk.cpp'],
sources_cuda=['src/assign_score_withk_cuda.cu']),
make_cuda_ext( make_cuda_ext(
name='group_points_ext', name='group_points_ext',
module='mmdet3d.ops.group_points', module='mmdet3d.ops.group_points',
......
import pytest
import torch
from mmdet3d.ops import PAConv, PAConvCUDA, assign_score_withk
def test_paconv_assign_scores():
    """Test forward and backward of the CUDA ``assign_score_withk`` op.

    The op only has a CUDA implementation, so the test is skipped on
    CPU-only machines. All ``expected_*`` tensors below are pre-computed
    reference values for these fixed inputs.

    Input shapes (readable from the literals):
        scores:  (B=2, npoint=2, K=4, M=2) -- per-neighbor weights over
                 the M weight matrices of the kernel bank.
        points:  (B=2, N=8, M=2, out_dim=4) -- pre-computed point features.
        centers: (B=2, N=8, M=2, out_dim=4) -- pre-computed center features.
        knn_idx: (B=2, npoint=2, K=4) -- neighbor indices into the N points.
    """
    if not torch.cuda.is_available():
        pytest.skip()
    scores = torch.tensor([[[[0.06947571, 0.6065746], [0.28462553, 0.8378516],
                             [0.7595994, 0.97220325], [0.519155, 0.766185]],
                            [[0.15348864, 0.6051019], [0.21510637, 0.31916398],
                             [0.00236845, 0.5842595], [0.6783676, 0.5216348]]],
                           [[[0.23089725, 0.5568468], [0.7405102, 0.06438422],
                             [0.6887394, 0.22089851], [0.0502342, 0.79228795]],
                            [[0.44883424, 0.15427643],
                             [0.13817799, 0.34856772], [0.7989621, 0.33788306],
                             [0.15699774, 0.7693662]]]]).float().cuda()
    # enable autograd on all three inputs so the backward pass below
    # produces gradients to compare against the references
    scores.requires_grad_()
    points = torch.tensor([[[[0.06001121, 0.92963666, 0.5753327, 0.7251477],
                             [0.53563064, 0.23129565, 0.92366195, 0.44261628]],
                            [[0.5770022, 0.56625944, 0.23560429, 0.11178821],
                             [0.7735967, 0.95678777, 0.25468266, 0.02895975]],
                            [[0.0589869, 0.09017515, 0.5977862, 0.02797985],
                             [0.603862, 0.35991007, 0.85761684, 0.3096559]],
                            [[0.22359002, 0.13983732, 0.5544243, 0.68863827],
                             [0.85646236, 0.75651926, 0.8638947, 0.83600986]],
                            [[0.45424145, 0.27458847, 0.6456112, 0.47162914],
                             [0.15773582, 0.47645122, 0.79964715, 0.3323908]],
                            [[0.8351399, 0.84696376, 0.9431732, 0.29418713],
                             [0.77168906, 0.6996871, 0.19354361, 0.03392768]],
                            [[0.30976456, 0.7074133, 0.581795, 0.976677],
                             [0.69656056, 0.07199162, 0.4708506, 0.29117996]],
                            [[0.5829035, 0.30201727, 0.76556486, 0.0935446],
                             [0.88030535, 0.16129416, 0.9242525, 0.49545723]]],
                           [[[0.50899494, 0.06482804, 0.44939405, 0.37704808],
                             [0.47028124, 0.11969638, 0.62823206, 0.28560323]],
                            [[0.40690207, 0.689753, 0.51636654, 0.23040164],
                             [0.06935787, 0.00488842, 0.22462702, 0.09182382]],
                            [[0.26611632, 0.00184339, 0.7730655, 0.5228131],
                             [0.87776035, 0.77895886, 0.2787183, 0.16620636]],
                            [[0.502574, 0.04039001, 0.5368497, 0.98379374],
                             [0.40973026, 0.3238272, 0.9733018, 0.13988364]],
                            [[0.04586202, 0.20983845, 0.20662665, 0.22270602],
                             [0.60387236, 0.5155574, 0.51237285, 0.6528438]],
                            [[0.45735973, 0.86821306, 0.61054605, 0.8370336],
                             [0.45193362, 0.3734138, 0.7825672, 0.5699416]],
                            [[0.44591594, 0.12447512, 0.09282011, 0.7055254],
                             [0.25223452, 0.46696228, 0.7051136, 0.892151]],
                            [[0.49615085, 0.47321403, 0.93138885, 0.7652197],
                             [0.38766378, 0.30332977, 0.23131835,
                              0.02863514]]]]).float().cuda()
    points.requires_grad_()
    centers = torch.tensor([[[[0.83878064, 0.96658987, 0.8033424, 0.9598312],
                              [0.45035273, 0.8768925, 0.977736, 0.54547966]],
                             [[0.01041394, 0.597893, 0.36212963, 0.4410367],
                              [0.94879234, 0.8372817, 0.21237361, 0.67945415]],
                             [[0.5096087, 0.26401454, 0.60034937, 0.5417416],
                              [0.87591463, 0.546456, 0.4096033, 0.16373193]],
                             [[0.79547447, 0.1482386, 0.12840575, 0.45384115],
                              [0.5640288, 0.944541, 0.5745328, 0.73229736]],
                             [[0.93011934, 0.7406011, 0.62621707, 0.8677915],
                              [0.91563636, 0.3595413, 0.6678378, 0.6085383]],
                             [[0.22431666, 0.65617776, 0.7483924, 0.6263364],
                              [0.30968404, 0.78204364, 0.14899081,
                               0.09628749]],
                             [[0.73675203, 0.72104895, 0.4648038, 0.6101647],
                              [0.7817645, 0.16572917, 0.3311919, 0.43407398]],
                             [[0.8193154, 0.09559608, 0.05978829, 0.90262103],
                              [0.4256065, 0.8165596, 0.8206446, 0.6604721]]],
                            [[[0.7159653, 0.18600845, 0.21433902, 0.3159626],
                              [0.3921569, 0.33221376, 0.5061177, 0.7961841]],
                             [[0.95338356, 0.04785997, 0.67185795, 0.6538394],
                              [0.4729132, 0.33404195, 0.17750603, 0.8445621]],
                             [[0.6755793, 0.16193843, 0.75943846, 0.92123103],
                              [0.2781859, 0.03114432, 0.710638, 0.52729136]],
                             [[0.8376105, 0.10858494, 0.13208169, 0.365772],
                              [0.5930795, 0.27390373, 0.14036089, 0.170403]],
                             [[0.3479789, 0.89855295, 0.04844379, 0.9871029],
                              [0.29781651, 0.0244137, 0.9179047, 0.8081611]],
                             [[0.12460887, 0.44991326, 0.19382608, 0.35037738],
                              [0.2773472, 0.4362057, 0.36757517, 0.5993509]],
                             [[0.29630446, 0.90046406, 0.5417113, 0.13510644],
                              [0.09623539, 0.04226565, 0.32001644,
                               0.44358212]],
                             [[0.5274848, 0.82096446, 0.9415489, 0.7123748],
                              [0.7537517, 0.8086482, 0.85345286,
                               0.7472754]]]]).float().cuda()
    centers.requires_grad_()
    knn_idx = torch.tensor([[[6, 7, 4, 6], [2, 4, 2, 4]],
                            [[7, 1, 3, 2], [6, 0, 2, 6]]]).long().cuda()
    # aggregate the K per-neighbor contributions by summation
    aggregate = 'sum'
    # reference forward result, shape (2, 4, 2, 4) -- layout per the op's
    # output contract (TODO confirm dim ordering against the op docs)
    expected_output = torch.tensor(
        [[[[-0.08134781, 0.03877336, -0.8212776, -0.2869547],
           [-0.23378491, -0.24112664, -0.1600166, -0.4121864]],
          [[-0.05780616, -0.12298299, -0.0370461, -0.07889931],
           [-0.13956165, -0.02006848, -0.10940295, -0.0293439]],
          [[0.09284145, 0.58250105, 0.5927749, 0.16774094],
           [0.27070042, 0.13422406, 0.2617501, 0.23416464]],
          [[-0.06121218, -0.09561322, -0.20408826, 0.08079343],
           [0.00944228, 0.03874819, 0.08404065, 0.04041629]]],
         [[[-0.2110898, -0.13335688, -0.09315082, 0.08512095],
           [0.09121774, 0.15976946, 0.23994486, 0.14350912]],
          [[-0.36167958, -0.14891288, -0.64470863, -0.0646704],
           [-0.28276974, -0.08847666, -0.46904767, 0.20491874]],
          [[-0.34877953, -0.35533834, -0.25225785, -0.4638189],
           [-0.1420663, 0.09467781, 0.17088932, 0.22580585]],
          [[-0.3879708, -0.3991068, 0.05276498, -0.46989647],
           [0.32522714, -0.02163534, 0.21604237, 0.4346682]]]]).float()
    # test forward
    output = assign_score_withk(scores, points, centers, knn_idx, aggregate)
    assert torch.allclose(output.detach().cpu(), expected_output, atol=1e-6)
    # test backward: a sum loss makes each gradient the plain sum of the
    # op's partial derivatives, easy to precompute
    loss = output.sum()
    loss.backward()
    expected_scores_grad = torch.tensor([[[[0.04288036, -0.18217683],
                                           [-0.78873926, 0.7485497],
                                           [-0.6866992, 0.05346543],
                                           [0.04288036, -0.18217683]],
                                          [[-1.1407862, 0.13533896],
                                           [-0.06964391, -0.22948086],
                                           [-1.1407862, 0.13533896],
                                           [-0.06964391, -0.22948086]]],
                                         [[[-0.3363995, -2.212181],
                                           [-1.1589496, -2.7724311],
                                           [-0.9387654, -1.3163853],
                                           [-1.4385346, -1.0614843]],
                                          [[-0.5048497, 1.4143617],
                                           [-0.47332114, 0.6017133],
                                           [-0.30974793, 1.1995442],
                                           [-0.5048497, 1.4143617]]]]).float()
    # points never referenced by knn_idx get zero gradient
    expected_points_grad = torch.tensor(
        [[[[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0.15585709, 0.15585709, 0.15585709, 0.15585709],
           [1.1893613, 1.1893613, 1.1893613, 1.1893613]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[1.6530733, 1.6530733, 1.6530733, 1.6530733],
           [1.8130021, 1.8130021, 1.8130021, 1.8130021]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0.58863074, 0.58863074, 0.58863074, 0.58863074],
           [1.3727596, 1.3727596, 1.3727596, 1.3727596]],
          [[0.28462553, 0.28462553, 0.28462553, 0.28462553],
           [0.8378516, 0.8378516, 0.8378516, 0.8378516]]],
         [[[0.13817799, 0.13817799, 0.13817799, 0.13817799],
           [0.34856772, 0.34856772, 0.34856772, 0.34856772]],
          [[0.7405102, 0.7405102, 0.7405102, 0.7405102],
           [0.06438422, 0.06438422, 0.06438422, 0.06438422]],
          [[0.8491963, 0.8491963, 0.8491963, 0.8491963],
           [1.1301711, 1.1301711, 1.1301711, 1.1301711]],
          [[0.6887394, 0.6887394, 0.6887394, 0.6887394],
           [0.22089851, 0.22089851, 0.22089851, 0.22089851]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0.605832, 0.605832, 0.605832, 0.605832],
           [0.92364264, 0.92364264, 0.92364264, 0.92364264]],
          [[0.23089725, 0.23089725, 0.23089725, 0.23089725],
           [0.5568468, 0.5568468, 0.5568468, 0.5568468]]]]).float()
    # the non-positive center gradients are consistent with centers entering
    # the op with a negative sign, i.e. a (point - center) difference
    expected_centers_grad = torch.tensor(
        [[[[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[-1.0493311, -1.0493311, -1.0493311, -1.0493311],
           [-2.0301602, -2.0301602, -2.0301602, -2.0301602]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[-1.6328557, -1.6328557, -1.6328557, -1.6328557],
           [-3.1828144, -3.1828144, -3.1828144, -3.1828144]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]]],
         [[[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[0., 0., 0., 0.], [0., 0., 0., 0.]],
          [[-1.5429721, -1.5429721, -1.5429721, -1.5429721],
           [-1.6100934, -1.6100934, -1.6100934, -1.6100934]],
          [[-1.7103812, -1.7103812, -1.7103812, -1.7103812],
           [-1.6344175, -1.6344175, -1.6344175, -1.6344175]]]]).float()
    assert torch.allclose(
        scores.grad.detach().cpu(), expected_scores_grad, atol=1e-6)
    assert torch.allclose(
        points.grad.detach().cpu(), expected_points_grad, atol=1e-6)
    assert torch.allclose(
        centers.grad.detach().cpu(), expected_centers_grad, atol=1e-6)
def test_paconv():
    """Smoke-test the CPU ``PAConv`` layer.

    Feeds random coordinates and features through the layer under
    ``no_grad`` and only checks the output shape.
    """
    batch_size, num_kernels = 2, 4
    in_channels, out_channels = 6, 12
    num_points, num_neighbors = 4, 3
    xyz = torch.randn(batch_size, 3, num_points, num_neighbors)
    feats = torch.randn(batch_size, in_channels, num_points, num_neighbors)
    layer = PAConv(in_channels, out_channels, num_kernels)
    with torch.no_grad():
        out = layer(xyz, feats)
    assert out.shape == torch.Size(
        [batch_size, out_channels, num_points, num_neighbors])
def test_paconv_cuda():
    """Smoke-test the CUDA ``PAConvCUDA`` layer.

    Skipped on CPU-only machines. Feeds random coordinates, features and
    neighbor indices through the layer under ``no_grad`` and only checks
    the output shape.
    """
    if not torch.cuda.is_available():
        pytest.skip()
    batch_size, num_kernels = 2, 4
    in_channels, out_channels = 6, 12
    num_all_points = 32
    num_centers, num_neighbors = 4, 3
    xyz = torch.randn(batch_size, 3, num_centers,
                      num_neighbors).float().cuda()
    feats = torch.randn(batch_size, in_channels, num_all_points).float().cuda()
    neighbor_idx = torch.randint(
        0, num_all_points,
        (batch_size, num_centers, num_neighbors)).long().cuda()
    layer = PAConvCUDA(in_channels, out_channels, num_kernels).cuda()
    with torch.no_grad():
        out = layer(xyz, feats, neighbor_idx)
    assert out.shape == torch.Size(
        [batch_size, out_channels, num_centers, num_neighbors])
...@@ -165,6 +165,33 @@ def test_pointnet_sa_module(): ...@@ -165,6 +165,33 @@ def test_pointnet_sa_module():
assert new_features.shape == torch.Size([1, 32, 16]) assert new_features.shape == torch.Size([1, 32, 16])
assert inds.shape == torch.Size([1, 16]) assert inds.shape == torch.Size([1, 16])
# can't set normalize_xyz when radius is None
with pytest.raises(AssertionError):
sa_cfg = dict(
type='PointSAModule',
num_point=16,
radius=None,
num_sample=8,
mlp_channels=[12, 32],
norm_cfg=dict(type='BN2d'),
use_xyz=True,
pool_mod='max',
normalize_xyz=True)
self = build_sa_module(sa_cfg)
# test kNN sampling when radius is None
sa_cfg['normalize_xyz'] = False
self = build_sa_module(sa_cfg).cuda()
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
xyz = torch.from_numpy(xyz[..., :3]).view(1, -1, 3).cuda()
features = xyz.repeat([1, 1, 4]).transpose(1, 2).contiguous().cuda()
new_xyz, new_features, inds = self(xyz, features)
assert new_xyz.shape == torch.Size([1, 16, 3])
assert new_features.shape == torch.Size([1, 32, 16])
assert inds.shape == torch.Size([1, 16])
def test_pointnet_fp_module(): def test_pointnet_fp_module():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment