Unverified Commit e99bfce6 authored by dingchang's avatar dingchang Committed by GitHub
Browse files

[Feature] Add assign_score_withk op for paconv from mmdet3d (#1357)



* add ops (assign_score_withk) in mmdet3d

* add ops (assign_score_withk) in mmdet3d

* refactor code

* fix typo

* refactor code

* fix typo

* polish docstring
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent be5841e4
......@@ -2,6 +2,7 @@
We implement common CUDA ops used in detection, segmentation, etc.
- AssignScoreWithK
- BallQuery
- BBoxOverlaps
- CARAFE
......
......@@ -2,6 +2,7 @@
MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- AssignScoreWithK
- BallQuery
- BBoxOverlaps
- CARAFE
......
# Copyright (c) OpenMMLab. All rights reserved.
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .border_align import BorderAlign, border_align
......@@ -58,11 +59,11 @@ __all__ = [
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'contour_expand', 'three_nn', 'three_interpolate',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'gather_points', 'furthest_point_sample',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'nms_rotated', 'knn', 'ball_query', 'upfirdn2d',
'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'RoIAlignRotated',
'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn',
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
]
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
class AssignScoreWithK(Function):
r"""Perform weighted sum to generate output features according to scores.
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
scene_seg/lib/paconv_lib/src/gpu>`_.
This is a memory-efficient CUDA implementation of assign_scores operation,
which first transform all point features with weight bank, then assemble
neighbor features with ``knn_idx`` and perform weighted sum of ``scores``.
See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
more detailed descriptions.
Note:
This implementation assumes using ``neighbor`` kernel input, which is
(point_features - center_features, point_features).
See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
pointnet2/paconv.py#L128 for more details.
"""
@staticmethod
def forward(ctx,
scores,
point_features,
center_features,
knn_idx,
aggregate='sum'):
"""
Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to
aggregate weight matrices in the weight bank.
``npoint`` is the number of sampled centers.
``K`` is the number of queried neighbors.
``M`` is the number of weight matrices in the weight bank.
point_features (torch.Tensor): (B, N, M, out_dim)
Pre-computed point features to be aggregated.
center_features (torch.Tensor): (B, N, M, out_dim)
Pre-computed center features to be aggregated.
knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
We assume the first idx in each row is the idx of the center.
aggregate (str, optional): Aggregation method.
Can be 'sum', 'avg' or 'max'. Defaults: 'sum'.
Returns:
torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
"""
agg = {'sum': 0, 'avg': 1, 'max': 2}
B, N, M, out_dim = point_features.size()
_, npoint, K, _ = scores.size()
output = point_features.new_zeros((B, out_dim, npoint, K))
ext_module.assign_score_withk_forward(B, N, npoint, M, K, out_dim,
agg[aggregate],
point_features.contiguous(),
center_features.contiguous(),
scores.contiguous(),
knn_idx.contiguous(), output)
ctx.save_for_backward(output, point_features, center_features, scores,
knn_idx)
ctx.agg = agg[aggregate]
return output
@staticmethod
def backward(ctx, grad_out):
"""
Args:
grad_out (torch.Tensor): (B, out_dim, npoint, K)
Returns:
grad_scores (torch.Tensor): (B, npoint, K, M)
grad_point_features (torch.Tensor): (B, N, M, out_dim)
grad_center_features (torch.Tensor): (B, N, M, out_dim)
"""
_, point_features, center_features, scores, knn_idx = ctx.saved_tensors
agg = ctx.agg
B, N, M, out_dim = point_features.size()
_, npoint, K, _ = scores.size()
grad_point_features = point_features.new_zeros(point_features.shape)
grad_center_features = center_features.new_zeros(center_features.shape)
grad_scores = scores.new_zeros(scores.shape)
ext_module.assign_score_withk_backward(B, N, npoint, M, K, out_dim,
agg, grad_out.contiguous(),
point_features.contiguous(),
center_features.contiguous(),
scores.contiguous(),
knn_idx.contiguous(),
grad_point_features,
grad_center_features,
grad_scores)
return grad_scores, grad_point_features, \
grad_center_features, None, None
assign_score_withk = AssignScoreWithK.apply
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH
#define ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
// input: points(B,N0,M,O), centers(B,N0,M,O), scores(B,N1,K,M), knn_idx(B,N1,K)
// output: fout(B,O,N)
// algo: fout(b,i,k,j) = s(b,i,k,m)*p(b,c(i),k,m,j) = s(b,i,k,m)*p(b,i(k),m,j)
// i(k) = idx(b,i,k)
// sum: fout(b,i,j) = fout(b,i,j) + s(b,i,k,m)*p(b,i,k,m,j)
// avg: fout(b,i,j) = sum(fout(b,i,k,j)) / k
// max: fout(b,i,j) = max(fout(b,i,k,j), sum(s(b,i,k,m)*p(b,i,k,m,j)))
template <typename T>
__global__ void assign_score_withk_forward_cuda_kernel(
const int B, const int N0, const int N1, const int M, const int K,
const int O, const int aggregate, const T* points, const T* centers,
const T* scores, const int64_t* knn_idx, T* output) {
// ----- parallel loop for B, N1, K and O ---------
long i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= B * N1 * K * O) return;
// ------- loop for M ----------
const int b = (int)(i / (O * N1 * K));
const int o = (int)(i % (O * N1 * K) / (N1 * K));
const int n = (int)(i % (N1 * K) / K);
const int k = (int)(i % K);
const int cn = (int)knn_idx[b * K * N1 + n * K +
0]; // The first neighbor is the center point
const int kn = (int)knn_idx[b * K * N1 + n * K + k];
if (kn >= N0 ||
kn < 0) { // if index overflows, it is out of the neighborhood range
return;
}
assert(b < B);
assert(kn < N0);
assert(cn < N0);
assert(o < O);
assert(n < N1);
const int out_idx = b * N1 * O * K + o * N1 * K + n * K + k;
T val = output[out_idx];
for (int m = 0; m < M; m++) {
val += points[b * N0 * M * O + kn * M * O + m * O + o] *
scores[b * N1 * K * M + n * K * M + k * M + m] -
centers[b * N0 * M * O + cn * M * O + m * O + o] *
scores[b * N1 * K * M + n * K * M + k * M + m];
}
output[out_idx] = val;
}
template <typename T>
__global__ void assign_score_withk_points_backward_cuda_kernel(
const int B, const int N0, const int N, const int M, const int K,
const int O, const int aggregate, const T* grad_out, const T* scores,
const int64_t* knn_idx, T* grad_points, T* grad_centers) {
// ----- parallel loop for B, M, O ---------
long i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= B * M * O) return;
int b = (int)(i / (M * O));
int m = (int)(i % (M * O) / O);
int o = (int)(i % O);
// ----- loop for N,K ---------
for (int n = 0; n < N; n++) {
for (int k = 0; k < K; k++) {
int kn = knn_idx[b * N * K + n * K + k];
int cn = knn_idx[b * N * K + n * K + 0];
if (kn >= N0 ||
kn < 0) { // if index overflows, it is out of the neighborhood range
continue;
}
atomicAdd(grad_points + b * N0 * M * O + kn * M * O + m * O + o,
scores[b * N * K * M + n * K * M + k * M + m] *
grad_out[b * O * N * K + o * N * K + n * K + k]);
atomicAdd(grad_centers + b * N0 * M * O + cn * M * O + m * O + o,
-scores[b * N * K * M + n * K * M + k * M + m] *
grad_out[b * O * N * K + o * N * K + n * K + k]);
}
}
}
template <typename T>
__global__ void assign_score_withk_scores_backward_cuda_kernel(
const int B, const int N0, const int N, const int M, const int K,
const int O, const int aggregate, const T* grad_out, const T* points,
const T* centers, const int64_t* knn_idx, T* grad_scores) {
// ----- parallel loop for B, N, K, M ---------
long i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= B * N * K * M) return;
const int b = (int)(i / (N * M * K));
const int n = (int)(i % (N * M * K) / M / K);
const int k = (int)(i % (M * K) / M);
const int m = (int)(i % M);
const int cn = knn_idx[b * N * K + n * K + 0];
const int kn = knn_idx[b * N * K + n * K + k];
if (kn >= N0 ||
kn < 0) { // if index overflows, it is out of the neighborhood range
return;
}
// -------------- loop for O ------------------------
const int out_idx = b * N * K * M + n * K * M + k * M + m;
T val = grad_scores[out_idx];
for (int o = 0; o < O; o++) {
val += (points[b * N0 * M * O + kn * M * O + m * O + o] -
centers[b * N0 * M * O + cn * M * O + m * O + o]) *
grad_out[b * O * N * K + o * N * K + n * K + k];
}
grad_scores[out_idx] = val;
}
#endif // ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/paconv_lib/src/gpu
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void AssignScoreWithKForwardCUDAKernelLauncher(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& points, const Tensor& centers, const Tensor& scores,
const Tensor& knn_idx, Tensor& output);
void assign_score_withk_forward_cuda(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor& points,
const Tensor& centers,
const Tensor& scores,
const Tensor& knn_idx, Tensor& output) {
AssignScoreWithKForwardCUDAKernelLauncher(
B, N0, N1, M, K, O, aggregate, points, centers, scores, knn_idx, output);
};
void AssignScoreWithKBackwardCUDAKernelLauncher(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
Tensor& grad_centers, Tensor& grad_scores);
void assign_score_withk_backward_cuda(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
Tensor& grad_centers, Tensor& grad_scores) {
AssignScoreWithKBackwardCUDAKernelLauncher(
B, N0, N1, M, K, O, aggregate, grad_out, points, centers, scores, knn_idx,
grad_points, grad_centers, grad_scores);
};
#endif
void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor& points,
const Tensor& centers, const Tensor& scores,
const Tensor& knn_idx, Tensor& output) {
if (points.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CONTIGUOUS(points);
CHECK_CONTIGUOUS(centers);
CHECK_CONTIGUOUS(scores);
CHECK_CONTIGUOUS(knn_idx);
CHECK_CONTIGUOUS(output);
assign_score_withk_forward_cuda(B, N0, N1, M, K, O, aggregate, points,
centers, scores, knn_idx, output);
#else
AT_ERROR("assign_score_withk is not compiled with GPU support");
#endif
} else {
AT_ERROR("assign_score_withk is not implemented on CPU");
}
}
void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor& grad_out,
const Tensor& points, const Tensor& centers,
const Tensor& scores, const Tensor& knn_idx,
Tensor& grad_points, Tensor& grad_centers,
Tensor& grad_scores) {
if (grad_points.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CONTIGUOUS(grad_out);
CHECK_CONTIGUOUS(scores);
CHECK_CONTIGUOUS(points);
CHECK_CONTIGUOUS(centers);
CHECK_CONTIGUOUS(knn_idx);
CHECK_CONTIGUOUS(grad_scores);
CHECK_CONTIGUOUS(grad_points);
CHECK_CONTIGUOUS(grad_centers);
assign_score_withk_backward_cuda(B, N0, N1, M, K, O, aggregate, grad_out,
points, centers, scores, knn_idx,
grad_points, grad_centers, grad_scores);
#else
AT_ERROR("assign_score_withk is not compiled with GPU support");
#endif
} else {
AT_ERROR("assign_score_withk is not implemented on CPU");
}
}
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/paconv_lib/src/gpu
#include <stdio.h>
#include <stdlib.h>
#include "assign_score_withk_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void AssignScoreWithKForwardCUDAKernelLauncher(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& points, const Tensor& centers, const Tensor& scores,
const Tensor& knn_idx, Tensor& output) {
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(B * O * N1 * K, THREADS_PER_BLOCK));
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "assign_score_withk_forward_cuda_kernel", [&] {
assign_score_withk_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
B, N0, N1, M, K, O, aggregate, points.data_ptr<scalar_t>(),
centers.data_ptr<scalar_t>(), scores.data_ptr<scalar_t>(),
knn_idx.data_ptr<int64_t>(), output.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
void AssignScoreWithKBackwardCUDAKernelLauncher(
int B, int N0, int N1, int M, int K, int O, int aggregate,
const Tensor& grad_out, const Tensor& points, const Tensor& centers,
const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points,
Tensor& grad_centers, Tensor& grad_scores) {
at::cuda::CUDAGuard device_guard(grad_out.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks1(DIVUP(B * M * O, THREADS_PER_BLOCK));
dim3 threads1(THREADS_PER_BLOCK);
dim3 blocks2(DIVUP(B * N1 * K * M, THREADS_PER_BLOCK));
dim3 threads2(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_out.scalar_type(), "assign_score_withk_points_backward_cuda_kernel",
[&] {
assign_score_withk_points_backward_cuda_kernel<scalar_t>
<<<blocks1, threads1, 0, stream>>>(
B, N0, N1, M, K, O, aggregate, grad_out.data_ptr<scalar_t>(),
scores.data_ptr<scalar_t>(), knn_idx.data_ptr<int64_t>(),
grad_points.data_ptr<scalar_t>(),
grad_centers.data_ptr<scalar_t>());
});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_out.scalar_type(), "assign_score_withk_scores_backward_cuda_kernel",
[&] {
assign_score_withk_scores_backward_cuda_kernel<scalar_t>
<<<blocks2, threads2, 0, stream>>>(
B, N0, N1, M, K, O, aggregate, grad_out.data_ptr<scalar_t>(),
points.data_ptr<scalar_t>(), centers.data_ptr<scalar_t>(),
knn_idx.data_ptr<int64_t>(), grad_scores.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -4,6 +4,18 @@
std::string get_compiler_version();
std::string get_compiling_cuda_version();
void assign_score_withk_forward(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor &points,
const Tensor &centers, const Tensor &scores,
const Tensor &knn_idx, Tensor &output);
void assign_score_withk_backward(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor &grad_out,
const Tensor &points, const Tensor &centers,
const Tensor &scores, const Tensor &knn_idx,
Tensor &grad_points, Tensor &grad_centers,
Tensor &grad_scores);
void carafe_naive_forward(Tensor features, Tensor masks, Tensor output,
int kernel_size, int group_size, int scale_factor);
......@@ -288,6 +300,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version");
m.def("assign_score_withk_forward", &assign_score_withk_forward,
"assign_score_withk_forward", py::arg("B"), py::arg("N0"),
py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"),
py::arg("aggregate"), py::arg("points"), py::arg("centers"),
py::arg("scores"), py::arg("knn_idx"), py::arg("output"));
m.def("assign_score_withk_backward", &assign_score_withk_backward,
"assign_score_withk_backward", py::arg("B"), py::arg("N0"),
py::arg("N1"), py::arg("M"), py::arg("K"), py::arg("O"),
py::arg("aggregate"), py::arg("grad_out"), py::arg("points"),
py::arg("centers"), py::arg("scores"), py::arg("knn_idx"),
py::arg("grad_points"), py::arg("grad_centers"),
py::arg("grad_scores"));
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
......
import pytest
import torch
from mmcv.ops import assign_score_withk
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_paconv_assign_scores():
scores = torch.tensor([[[[0.06947571, 0.6065746], [0.28462553, 0.8378516],
[0.7595994, 0.97220325], [0.519155, 0.766185]],
[[0.15348864, 0.6051019], [0.21510637, 0.31916398],
[0.00236845, 0.5842595], [0.6783676, 0.5216348]]],
[[[0.23089725, 0.5568468], [0.7405102, 0.06438422],
[0.6887394, 0.22089851], [0.0502342, 0.79228795]],
[[0.44883424, 0.15427643],
[0.13817799, 0.34856772], [0.7989621, 0.33788306],
[0.15699774, 0.7693662]]]]).float().cuda()
scores.requires_grad_()
points = torch.tensor([[[[0.06001121, 0.92963666, 0.5753327, 0.7251477],
[0.53563064, 0.23129565, 0.92366195, 0.44261628]],
[[0.5770022, 0.56625944, 0.23560429, 0.11178821],
[0.7735967, 0.95678777, 0.25468266, 0.02895975]],
[[0.0589869, 0.09017515, 0.5977862, 0.02797985],
[0.603862, 0.35991007, 0.85761684, 0.3096559]],
[[0.22359002, 0.13983732, 0.5544243, 0.68863827],
[0.85646236, 0.75651926, 0.8638947, 0.83600986]],
[[0.45424145, 0.27458847, 0.6456112, 0.47162914],
[0.15773582, 0.47645122, 0.79964715, 0.3323908]],
[[0.8351399, 0.84696376, 0.9431732, 0.29418713],
[0.77168906, 0.6996871, 0.19354361, 0.03392768]],
[[0.30976456, 0.7074133, 0.581795, 0.976677],
[0.69656056, 0.07199162, 0.4708506, 0.29117996]],
[[0.5829035, 0.30201727, 0.76556486, 0.0935446],
[0.88030535, 0.16129416, 0.9242525, 0.49545723]]],
[[[0.50899494, 0.06482804, 0.44939405, 0.37704808],
[0.47028124, 0.11969638, 0.62823206, 0.28560323]],
[[0.40690207, 0.689753, 0.51636654, 0.23040164],
[0.06935787, 0.00488842, 0.22462702, 0.09182382]],
[[0.26611632, 0.00184339, 0.7730655, 0.5228131],
[0.87776035, 0.77895886, 0.2787183, 0.16620636]],
[[0.502574, 0.04039001, 0.5368497, 0.98379374],
[0.40973026, 0.3238272, 0.9733018, 0.13988364]],
[[0.04586202, 0.20983845, 0.20662665, 0.22270602],
[0.60387236, 0.5155574, 0.51237285, 0.6528438]],
[[0.45735973, 0.86821306, 0.61054605, 0.8370336],
[0.45193362, 0.3734138, 0.7825672, 0.5699416]],
[[0.44591594, 0.12447512, 0.09282011, 0.7055254],
[0.25223452, 0.46696228, 0.7051136, 0.892151]],
[[0.49615085, 0.47321403, 0.93138885, 0.7652197],
[0.38766378, 0.30332977, 0.23131835,
0.02863514]]]]).float().cuda()
points.requires_grad_()
centers = torch.tensor([[[[0.83878064, 0.96658987, 0.8033424, 0.9598312],
[0.45035273, 0.8768925, 0.977736, 0.54547966]],
[[0.01041394, 0.597893, 0.36212963, 0.4410367],
[0.94879234, 0.8372817, 0.21237361, 0.67945415]],
[[0.5096087, 0.26401454, 0.60034937, 0.5417416],
[0.87591463, 0.546456, 0.4096033, 0.16373193]],
[[0.79547447, 0.1482386, 0.12840575, 0.45384115],
[0.5640288, 0.944541, 0.5745328, 0.73229736]],
[[0.93011934, 0.7406011, 0.62621707, 0.8677915],
[0.91563636, 0.3595413, 0.6678378, 0.6085383]],
[[0.22431666, 0.65617776, 0.7483924, 0.6263364],
[0.30968404, 0.78204364, 0.14899081,
0.09628749]],
[[0.73675203, 0.72104895, 0.4648038, 0.6101647],
[0.7817645, 0.16572917, 0.3311919, 0.43407398]],
[[0.8193154, 0.09559608, 0.05978829, 0.90262103],
[0.4256065, 0.8165596, 0.8206446, 0.6604721]]],
[[[0.7159653, 0.18600845, 0.21433902, 0.3159626],
[0.3921569, 0.33221376, 0.5061177, 0.7961841]],
[[0.95338356, 0.04785997, 0.67185795, 0.6538394],
[0.4729132, 0.33404195, 0.17750603, 0.8445621]],
[[0.6755793, 0.16193843, 0.75943846, 0.92123103],
[0.2781859, 0.03114432, 0.710638, 0.52729136]],
[[0.8376105, 0.10858494, 0.13208169, 0.365772],
[0.5930795, 0.27390373, 0.14036089, 0.170403]],
[[0.3479789, 0.89855295, 0.04844379, 0.9871029],
[0.29781651, 0.0244137, 0.9179047, 0.8081611]],
[[0.12460887, 0.44991326, 0.19382608, 0.35037738],
[0.2773472, 0.4362057, 0.36757517, 0.5993509]],
[[0.29630446, 0.90046406, 0.5417113, 0.13510644],
[0.09623539, 0.04226565, 0.32001644,
0.44358212]],
[[0.5274848, 0.82096446, 0.9415489, 0.7123748],
[0.7537517, 0.8086482, 0.85345286,
0.7472754]]]]).float().cuda()
centers.requires_grad_()
knn_idx = torch.tensor([[[6, 7, 4, 6], [2, 4, 2, 4]],
[[7, 1, 3, 2], [6, 0, 2, 6]]]).long().cuda()
aggregate = 'sum'
expected_output = torch.tensor(
[[[[-0.08134781, 0.03877336, -0.8212776, -0.2869547],
[-0.23378491, -0.24112664, -0.1600166, -0.4121864]],
[[-0.05780616, -0.12298299, -0.0370461, -0.07889931],
[-0.13956165, -0.02006848, -0.10940295, -0.0293439]],
[[0.09284145, 0.58250105, 0.5927749, 0.16774094],
[0.27070042, 0.13422406, 0.2617501, 0.23416464]],
[[-0.06121218, -0.09561322, -0.20408826, 0.08079343],
[0.00944228, 0.03874819, 0.08404065, 0.04041629]]],
[[[-0.2110898, -0.13335688, -0.09315082, 0.08512095],
[0.09121774, 0.15976946, 0.23994486, 0.14350912]],
[[-0.36167958, -0.14891288, -0.64470863, -0.0646704],
[-0.28276974, -0.08847666, -0.46904767, 0.20491874]],
[[-0.34877953, -0.35533834, -0.25225785, -0.4638189],
[-0.1420663, 0.09467781, 0.17088932, 0.22580585]],
[[-0.3879708, -0.3991068, 0.05276498, -0.46989647],
[0.32522714, -0.02163534, 0.21604237, 0.4346682]]]]).float()
# test forward
output = assign_score_withk(scores, points, centers, knn_idx, aggregate)
assert torch.allclose(output.detach().cpu(), expected_output, atol=1e-6)
# test backward
loss = output.sum()
loss.backward()
expected_scores_grad = torch.tensor([[[[0.04288036, -0.18217683],
[-0.78873926, 0.7485497],
[-0.6866992, 0.05346543],
[0.04288036, -0.18217683]],
[[-1.1407862, 0.13533896],
[-0.06964391, -0.22948086],
[-1.1407862, 0.13533896],
[-0.06964391, -0.22948086]]],
[[[-0.3363995, -2.212181],
[-1.1589496, -2.7724311],
[-0.9387654, -1.3163853],
[-1.4385346, -1.0614843]],
[[-0.5048497, 1.4143617],
[-0.47332114, 0.6017133],
[-0.30974793, 1.1995442],
[-0.5048497, 1.4143617]]]]).float()
expected_points_grad = torch.tensor(
[[[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0.15585709, 0.15585709, 0.15585709, 0.15585709],
[1.1893613, 1.1893613, 1.1893613, 1.1893613]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[1.6530733, 1.6530733, 1.6530733, 1.6530733],
[1.8130021, 1.8130021, 1.8130021, 1.8130021]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0.58863074, 0.58863074, 0.58863074, 0.58863074],
[1.3727596, 1.3727596, 1.3727596, 1.3727596]],
[[0.28462553, 0.28462553, 0.28462553, 0.28462553],
[0.8378516, 0.8378516, 0.8378516, 0.8378516]]],
[[[0.13817799, 0.13817799, 0.13817799, 0.13817799],
[0.34856772, 0.34856772, 0.34856772, 0.34856772]],
[[0.7405102, 0.7405102, 0.7405102, 0.7405102],
[0.06438422, 0.06438422, 0.06438422, 0.06438422]],
[[0.8491963, 0.8491963, 0.8491963, 0.8491963],
[1.1301711, 1.1301711, 1.1301711, 1.1301711]],
[[0.6887394, 0.6887394, 0.6887394, 0.6887394],
[0.22089851, 0.22089851, 0.22089851, 0.22089851]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0.605832, 0.605832, 0.605832, 0.605832],
[0.92364264, 0.92364264, 0.92364264, 0.92364264]],
[[0.23089725, 0.23089725, 0.23089725, 0.23089725],
[0.5568468, 0.5568468, 0.5568468, 0.5568468]]]]).float()
expected_centers_grad = torch.tensor(
[[[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[-1.0493311, -1.0493311, -1.0493311, -1.0493311],
[-2.0301602, -2.0301602, -2.0301602, -2.0301602]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[-1.6328557, -1.6328557, -1.6328557, -1.6328557],
[-3.1828144, -3.1828144, -3.1828144, -3.1828144]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]]],
[[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[0., 0., 0., 0.], [0., 0., 0., 0.]],
[[-1.5429721, -1.5429721, -1.5429721, -1.5429721],
[-1.6100934, -1.6100934, -1.6100934, -1.6100934]],
[[-1.7103812, -1.7103812, -1.7103812, -1.7103812],
[-1.6344175, -1.6344175, -1.6344175, -1.6344175]]]]).float()
assert torch.allclose(
scores.grad.detach().cpu(), expected_scores_grad, atol=1e-6)
assert torch.allclose(
points.grad.detach().cpu(), expected_points_grad, atol=1e-6)
assert torch.allclose(
centers.grad.detach().cpu(), expected_centers_grad, atol=1e-6)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment