Unverified Commit 4e0341f2 authored by VVsssssk's avatar VVsssssk Committed by GitHub
Browse files

[Features] Add stack ball query and stack group points ops (#2292)

* add stack sa model ops

* fix lint

* fix lint

* fix comments

* fix bug

* fix lint

* fix comments

* fix lint

* fix lint

* fix
parent a0cac22c
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple from typing import Optional, Tuple
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from ..utils import ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['ball_query_forward']) ext_module = ext_loader.load_ext(
'_ext', ['ball_query_forward', 'stack_ball_query_forward'])
class BallQuery(Function): class BallQuery(Function):
"""Find nearby points in spherical space.""" """Find nearby points in spherical space."""
@staticmethod @staticmethod
def forward(ctx, min_radius: float, max_radius: float, sample_num: int, def forward(
xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor: ctx,
min_radius: float,
max_radius: float,
sample_num: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor,
xyz_batch_cnt: Optional[torch.Tensor] = None,
center_xyz_batch_cnt: Optional[torch.Tensor] = None
) -> torch.Tensor:
""" """
Args: Args:
min_radius (float): minimum radius of the balls. min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls. max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls. sample_num (int): maximum number of features in the balls.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features. xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
or stacked input (N1 + N2 ..., 3).
center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
query. query, or stacked input (M1 + M2 ..., 3).
xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in
each batch, just like (N1, N2, ...). Defaults to None.
New in version 1.7.0.
center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
nums in each batch, just like (M1, M2, ...). Defaults to None.
New in version 1.7.0.
Returns: Returns:
torch.Tensor: (B, npoint, nsample) tensor with the indices of the torch.Tensor: (B, npoint, nsample) tensor with the indices of the
...@@ -31,21 +47,34 @@ class BallQuery(Function): ...@@ -31,21 +47,34 @@ class BallQuery(Function):
assert center_xyz.is_contiguous() assert center_xyz.is_contiguous()
assert xyz.is_contiguous() assert xyz.is_contiguous()
assert min_radius < max_radius assert min_radius < max_radius
if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
B, N, _ = xyz.size() assert xyz_batch_cnt.dtype == torch.int
npoint = center_xyz.size(1) assert center_xyz_batch_cnt.dtype == torch.int
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int) idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
dtype=torch.int32)
ext_module.ball_query_forward( ext_module.stack_ball_query_forward(
center_xyz, center_xyz,
xyz, center_xyz_batch_cnt,
idx, xyz,
b=B, xyz_batch_cnt,
n=N, idx,
m=npoint, max_radius=max_radius,
min_radius=min_radius, nsample=sample_num,
max_radius=max_radius, )
nsample=sample_num) else:
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if torch.__version__ != 'parrots': if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx) ctx.mark_non_differentiable(idx)
return idx return idx
......
...@@ -60,21 +60,19 @@ __global__ void correlation_forward_cuda_kernel( ...@@ -60,21 +60,19 @@ __global__ void correlation_forward_cuda_kernel(
for (int i = 0; i < kH; ++i) { for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH; int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated; int i2 = i1 + ph_dilated;
if if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
WITHIN_BOUNDS(i1, i2, iH, iH) { for (int j = 0; j < kW; ++j) {
for (int j = 0; j < kW; ++j) { int j1 = start_j + j * dilationW;
int j1 = start_j + j * dilationW; int j2 = j1 + pw_dilated;
int j2 = j1 + pw_dilated; if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
if for (int c = thread; c < C; c += WARP_SIZE) {
WITHIN_BOUNDS(j1, j2, iW, iW) { scalar_t v1 = rInput1[n][i1][j1][c];
for (int c = thread; c < C; c += WARP_SIZE) { scalar_t v2 = rInput2[n][i2][j2][c];
scalar_t v1 = rInput1[n][i1][j1][c]; prod_sum += v1 * v2;
scalar_t v2 = rInput2[n][i2][j2][c]; }
prod_sum += v1 * v2;
}
}
} }
} }
}
} }
// accumulate // accumulate
for (int offset = 16; offset > 0; offset /= 2) for (int offset = 16; offset > 0; offset /= 2)
......
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#ifndef STACK_BALL_QUERY_CUDA_KERNEL_CUH
#define STACK_BALL_QUERY_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
// Stacked ball query: for each query centre, record up to `nsample` indices
// of points from the same batch element whose squared distance to the centre
// is strictly smaller than radius^2.
//
// :param B: number of batch elements
// :param M: total number of query centres (M1 + M2 ...)
// :param radius: ball radius
// :param nsample: number of index slots per centre
// :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
// :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
// output:
//      idx: (M, nsample). Indices are relative to the first point of the
//      centre's batch element. On the first hit every slot is filled with
//      that index, later hits overwrite slots 1, 2, ... in order. When no
//      point matches, only slot 0 is written (with -1).
//
// NOTE(review): assumes new_xyz_batch_cnt sums to at least M; otherwise
// bs_idx reaches B and xyz_batch_cnt is read out of range - confirm callers.
template <typename T>
__global__ void stack_ball_query_forward_cuda_kernel(
    int B, int M, float radius, int nsample, const T *new_xyz,
    const int *new_xyz_batch_cnt, const T *xyz, const int *xyz_batch_cnt,
    int *idx) {
  // Loop invariant, so computed once per thread.
  const float radius2 = radius * radius;

  CUDA_1D_KERNEL_LOOP(pt_idx, M) {
    // Find the batch element that owns centre pt_idx.
    int bs_idx = 0;
    for (int pt_cnt = 0; bs_idx < B; bs_idx++) {
      pt_cnt += new_xyz_batch_cnt[bs_idx];
      if (pt_idx < pt_cnt) break;
    }

    // Offset of that batch element's first point inside xyz.
    int xyz_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];

    const T *new_xyz_p = new_xyz + pt_idx * 3;
    // Derive both cursors from the base pointers on every iteration. They
    // used to live outside the loop and were advanced with +=, so a thread
    // that processed more than one centre started from an already shifted
    // pointer and read/wrote the wrong rows.
    const T *cur_xyz = xyz + xyz_batch_start_idx * 3;
    int *cur_idx = idx + pt_idx * nsample;

    T new_x = new_xyz_p[0];
    T new_y = new_xyz_p[1];
    T new_z = new_xyz_p[2];

    int n = xyz_batch_cnt[bs_idx];
    int cnt = 0;
    for (int k = 0; k < n; ++k) {
      T x = cur_xyz[k * 3 + 0];
      T y = cur_xyz[k * 3 + 1];
      T z = cur_xyz[k * 3 + 2];
      T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
             (new_z - z) * (new_z - z);
      if (d2 < radius2) {
        if (cnt == 0) {
          // First hit: pre-fill every slot so unused slots repeat it.
          for (int l = 0; l < nsample; ++l) {
            cur_idx[l] = k;
          }
        }
        cur_idx[cnt] = k;
        ++cnt;
        if (cnt >= nsample) break;
      }
    }
    // Mark an empty ball.
    if (cnt == 0) cur_idx[0] = -1;
  }
}
#endif // STACK_BALL_QUERY_CUDA_KERNEL_CUH
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#ifndef STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#define STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include <stdio.h>
// Stacked grouping (forward): gathers feature channels by per-batch indices.
//
// :param b: number of batch elements
// :param c: number of feature channels
// :param m: total number of index rows (M1 + M2 ...)
// :param nsample: number of indices per row
// :param features: (N1 + N2 ..., C) tensor of features to group
// :param features_batch_cnt: (batch_size) [N1, N2, ...] number of feature
//      rows in each batch element
// :param idx: (M1 + M2 ..., nsample) tensor containing the indices of
//      features to group with, relative to the start of the row's batch
// :param idx_batch_cnt: (batch_size) [M1, M2, ...] number of index rows in
//      each batch element
// :return:
//      out: (M1 + M2 ..., C, nsample) tensor. Entries whose index falls
//      outside the row's batch element are not written.
//      NOTE(review): this relies on the caller passing a zero-filled `out`
//      for skipped entries to read as 0 - confirm.
template <typename T>
__global__ void stack_group_points_forward_cuda_kernel(
    int b, int c, int m, int nsample, const T *features,
    const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt,
    T *out) {
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    // Decompose the flat index into (row, channel, sample).
    int sample_idx = index % nsample;
    int c_idx = (index / nsample) % c;
    int pt_idx = (index / nsample / c);
    if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;

    // Find the batch element that owns row pt_idx.
    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
    for (int k = 1; k < b; k++) {
      if (pt_idx < pt_cnt) break;
      pt_cnt += idx_batch_cnt[k];
      bs_idx = k;
    }

    // Offset of that batch element's first feature row, and its row count.
    int features_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++)
      features_batch_start_idx += features_batch_cnt[k];
    const int batch_feature_num = features_batch_cnt[bs_idx];

    const int local_idx = idx[pt_idx * nsample + sample_idx];
    // The index is relative to the batch element, so it must be checked
    // against that element's own row count. The previous test compared the
    // batch-local offset with the absolute end index, which for every batch
    // after the first accepted indices pointing into later batches or past
    // the end of `features`. Negative indices are rejected as well.
    if (local_idx >= 0 && local_idx < batch_feature_num) {
      const T *cur_features = features + features_batch_start_idx * c;
      int out_idx = pt_idx * c * nsample + c_idx * nsample + sample_idx;
      out[out_idx] = cur_features[local_idx * c + c_idx];
    }
  }
}
// Stacked grouping (backward): scatters output gradients back onto the
// feature rows selected in the forward pass.
//
// :param b: number of batch elements
// :param c: number of feature channels
// :param m: total number of index rows (M1 + M2 ...)
// :param n: not read by this kernel
// :param nsample: number of indices per row
// :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the
//      output from forward
// :param idx: (M1 + M2 ..., nsample) tensor containing the indices of
//      features to group with, relative to the start of the row's batch
// :param idx_batch_cnt: (batch_size) [M1, M2, ...] number of index rows in
//      each batch element
// :param features_batch_cnt: (batch_size) [N1, N2, ...] number of feature
//      rows in each batch element
// :return:
//      grad_features: (N1 + N2 ..., C) gradient of the features, accumulated
//      with atomicAdd.
//      NOTE(review): presumably zero-initialised by the caller - confirm.
template <typename T>
__global__ void stack_group_points_backward_cuda_kernel(
    int b, int c, int m, int n, int nsample, const T *grad_out, const int *idx,
    const int *idx_batch_cnt, const int *features_batch_cnt, T *grad_features) {
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    // Decompose the flat index into (row, channel, sample).
    int sample_idx = index % nsample;
    int c_idx = (index / nsample) % c;
    int pt_idx = (index / nsample / c);
    if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;

    // Find the batch element that owns row pt_idx.
    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
    for (int k = 1; k < b; k++) {
      if (pt_idx < pt_cnt) break;
      pt_cnt += idx_batch_cnt[k];
      bs_idx = k;
    }

    // Offset of that batch element's first feature row.
    int features_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++)
      features_batch_start_idx += features_batch_cnt[k];

    const int local_idx = idx[pt_idx * nsample + sample_idx];
    // Skip indices that do not address a row of this batch element. The
    // forward kernel does not gather them, so they must not receive a
    // gradient; without this check the atomicAdd could land in another
    // batch's rows or outside grad_features altogether.
    if (local_idx >= 0 && local_idx < features_batch_cnt[bs_idx]) {
      const T *cur_grad_out =
          grad_out + pt_idx * c * nsample + c_idx * nsample + sample_idx;
      T *cur_grad_features =
          grad_features + (features_batch_start_idx + local_idx) * c + c_idx;
      atomicAdd(cur_grad_features, cur_grad_out[0]);
    }
  }
}
#endif  // STACK_GROUP_POINTS_CUDA_KERNEL_CUH
...@@ -15,5 +15,6 @@ using at::Tensor; ...@@ -15,5 +15,6 @@ using at::Tensor;
using phalf = at::Half; using phalf = at::Half;
#define __PHALF(x) (x) #define __PHALF(x) (x)
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#endif // PYTORCH_CUDA_HELPER #endif // PYTORCH_CUDA_HELPER
...@@ -18,3 +18,21 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, ...@@ -18,3 +18,21 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample, ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample,
new_xyz_tensor, xyz_tensor, idx_tensor); new_xyz_tensor, xyz_tensor, idx_tensor);
} }
// Dispatch point for the stacked ball query: passes every argument, in the
// same order, to DISPATCH_DEVICE_IMPL under the key
// stack_ball_query_forward_impl.
// NOTE(review): DISPATCH_DEVICE_IMPL is defined outside this view; presumably
// it selects the implementation registered for the tensors' device - confirm.
void stack_ball_query_forward_impl(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz, const Tensor xyz_batch_cnt,
Tensor idx) {
DISPATCH_DEVICE_IMPL(stack_ball_query_forward_impl, max_radius, nsample,
new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
}
// Entry point for the stacked ball query. It only reorders its arguments
// (tensors first here, scalars first in the _impl signature) and calls
// stack_ball_query_forward_impl; idx_tensor is the tensor handed on as the
// output argument `idx`.
void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
Tensor xyz_tensor, Tensor xyz_batch_cnt,
Tensor idx_tensor, float max_radius,
int nsample) {
stack_ball_query_forward_impl(max_radius, nsample, new_xyz_tensor,
new_xyz_batch_cnt, xyz_tensor, xyz_batch_cnt,
idx_tensor);
}
...@@ -67,6 +67,30 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius, ...@@ -67,6 +67,30 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius,
Tensor idx); Tensor idx);
REGISTER_DEVICE_IMPL(ball_query_forward_impl, CUDA, ball_query_forward_cuda); REGISTER_DEVICE_IMPL(ball_query_forward_impl, CUDA, ball_query_forward_cuda);
void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz,
const Tensor xyz_batch_cnt,
Tensor idx);
// CUDA implementation of the stacked ball query: hands every argument,
// unchanged and in the same order, to StackBallQueryForwardCUDAKernelLauncher.
void stack_ball_query_forward_cuda(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz, const Tensor xyz_batch_cnt,
                                   Tensor idx) {
  StackBallQueryForwardCUDAKernelLauncher(
      max_radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
}
void stack_ball_query_forward_impl(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz, const Tensor xyz_batch_cnt,
Tensor idx);
REGISTER_DEVICE_IMPL(stack_ball_query_forward_impl, CUDA,
stack_ball_query_forward_cuda);
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2, void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode, Tensor ious, const int mode,
const bool aligned, const int offset); const bool aligned, const int offset);
...@@ -571,6 +595,56 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA, ...@@ -571,6 +595,56 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA, REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA,
group_points_backward_cuda); group_points_backward_cuda);
void StackGroupPointsForwardCUDAKernelLauncher(
int b, int c, int m, int nsample, const Tensor features_tensor,
const Tensor features_batch_cnt_tensor, const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor, Tensor out_tensor);
void StackGroupPointsBackwardCUDAKernelLauncher(
int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor,
const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor);
// CUDA implementation of the stacked grouping forward pass: hands every
// argument, unchanged and in the same order, to
// StackGroupPointsForwardCUDAKernelLauncher.
void stack_group_points_forward_cuda(int b, int c, int m, int nsample,
                                     const Tensor features_tensor,
                                     const Tensor features_batch_cnt_tensor,
                                     const Tensor idx_tensor,
                                     const Tensor idx_batch_cnt_tensor,
                                     Tensor out_tensor) {
  StackGroupPointsForwardCUDAKernelLauncher(
      b, c, m, nsample, features_tensor, features_batch_cnt_tensor, idx_tensor,
      idx_batch_cnt_tensor, out_tensor);
}
// CUDA implementation of the stacked grouping backward pass: hands every
// argument, unchanged and in the same order, to
// StackGroupPointsBackwardCUDAKernelLauncher.
void stack_group_points_backward_cuda(int b, int c, int m, int n, int nsample,
                                      const Tensor grad_out_tensor,
                                      const Tensor idx_tensor,
                                      const Tensor idx_batch_cnt_tensor,
                                      const Tensor features_batch_cnt_tensor,
                                      Tensor grad_features_tensor) {
  StackGroupPointsBackwardCUDAKernelLauncher(
      b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
      features_batch_cnt_tensor, grad_features_tensor);
}
void stack_group_points_forward_impl(int b, int c, int m, int nsample,
const Tensor features_tensor,
const Tensor features_batch_cnt_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
Tensor out_tensor);
void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample,
const Tensor grad_out_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor);
REGISTER_DEVICE_IMPL(stack_group_points_forward_impl, CUDA,
stack_group_points_forward_cuda);
REGISTER_DEVICE_IMPL(stack_group_points_backward_impl, CUDA,
stack_group_points_backward_cuda);
void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a, const Tensor boxes_a,
const int num_b, const int num_b,
......
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_cuda_helper.hpp"
#include "stack_ball_query_cuda_kernel.cuh"
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
// Launches stack_ball_query_forward_cuda_kernel on the current CUDA stream
// of new_xyz's device.
//
// max_radius         ball radius forwarded to the kernel
// nsample            number of index slots per query centre
// new_xyz            query centres; size(0) is used as the centre count M and
//                    its scalar type selects the kernel instantiation
// new_xyz_batch_cnt  per-batch centre counts, read as int
// xyz                point coordinates, read with new_xyz's scalar type
// xyz_batch_cnt      per-batch point counts, read as int; size(0) is used as
//                    the batch size B
// idx                output indices, written as int
void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample,
                                             const Tensor new_xyz,
                                             const Tensor new_xyz_batch_cnt,
                                             const Tensor xyz,
                                             const Tensor xyz_batch_cnt,
                                             Tensor idx) {
  int B = xyz_batch_cnt.size(0);
  int M = new_xyz.size(0);
  // No query centres means nothing to compute, and a grid of zero blocks is
  // not a valid launch configuration.
  if (M == 0) return;

  at::cuda::CUDAGuard device_guard(new_xyz.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // One thread per query centre.
  dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      new_xyz.scalar_type(), "stack_ball_query_forward_cuda_kernel", [&] {
        stack_ball_query_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                B, M, max_radius, nsample, new_xyz.data_ptr<scalar_t>(),
                new_xyz_batch_cnt.data_ptr<int>(), xyz.data_ptr<scalar_t>(),
                xyz_batch_cnt.data_ptr<int>(), idx.data_ptr<int>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_cuda_helper.hpp"
#include "stack_group_points_cuda_kernel.cuh"
// Launches stack_group_points_forward_cuda_kernel on the current CUDA stream
// of features_tensor's device, with one thread per output element
// (m * c * nsample in total).
//
// b, c, m, nsample           sizes forwarded to the kernel unchanged
// features_tensor            features to group; its scalar type selects the
//                            kernel instantiation
// features_batch_cnt_tensor  per-batch feature counts, read as int
// idx_tensor                 grouping indices, read as int
// idx_batch_cnt_tensor       per-batch index-row counts, read as int
// out_tensor                 output, written with features_tensor's scalar
//                            type
void StackGroupPointsForwardCUDAKernelLauncher(
    int b, int c, int m, int nsample, const Tensor features_tensor,
    const Tensor features_batch_cnt_tensor, const Tensor idx_tensor,
    const Tensor idx_batch_cnt_tensor, Tensor out_tensor) {
  // No output elements means nothing to compute, and a grid of zero blocks
  // is not a valid launch configuration.
  if (m * c * nsample == 0) return;

  at::cuda::CUDAGuard device_guard(features_tensor.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features_tensor.scalar_type(), "stack_group_points_forward_cuda_kernel",
      [&] {
        stack_group_points_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, m, nsample, features_tensor.data_ptr<scalar_t>(),
                features_batch_cnt_tensor.data_ptr<int>(),
                idx_tensor.data_ptr<int>(),
                idx_batch_cnt_tensor.data_ptr<int>(),
                out_tensor.data_ptr<scalar_t>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
// Launches stack_group_points_backward_cuda_kernel on the current CUDA
// stream of grad_features_tensor's device, with one thread per element of
// the output gradient (m * c * nsample in total).
//
// b, c, m, n, nsample        sizes forwarded to the kernel unchanged
// grad_out_tensor            gradient of the forward output, read with
//                            grad_features_tensor's scalar type
// idx_tensor                 grouping indices, read as int
// idx_batch_cnt_tensor       per-batch index-row counts, read as int
// features_batch_cnt_tensor  per-batch feature counts, read as int
// grad_features_tensor       gradient of the features, written by the kernel;
//                            its scalar type selects the kernel instantiation
void StackGroupPointsBackwardCUDAKernelLauncher(
    int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor,
    const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor,
    const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor) {
  // No gradient elements means nothing to scatter, and a grid of zero blocks
  // is not a valid launch configuration.
  if (m * c * nsample == 0) return;

  at::cuda::CUDAGuard device_guard(grad_features_tensor.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(DIVUP(m * c * nsample, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_features_tensor.scalar_type(),
      "stack_group_points_backward_cuda_kernel", [&] {
        stack_group_points_backward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, m, n, nsample, grad_out_tensor.data_ptr<scalar_t>(),
                idx_tensor.data_ptr<int>(),
                idx_batch_cnt_tensor.data_ptr<int>(),
                features_batch_cnt_tensor.data_ptr<int>(),
                grad_features_tensor.data_ptr<scalar_t>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
...@@ -32,3 +32,45 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, ...@@ -32,3 +32,45 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
group_points_backward_impl(b, c, n, npoints, nsample, grad_out_tensor, group_points_backward_impl(b, c, n, npoints, nsample, grad_out_tensor,
idx_tensor, grad_points_tensor); idx_tensor, grad_points_tensor);
} }
// Dispatch point for the stacked grouping backward pass: passes every
// argument, in the same order, to DISPATCH_DEVICE_IMPL under the key
// stack_group_points_backward_impl.
// NOTE(review): DISPATCH_DEVICE_IMPL is defined outside this view; presumably
// it selects the implementation registered for the tensors' device - confirm.
void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample,
const Tensor grad_out_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
const Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor) {
DISPATCH_DEVICE_IMPL(stack_group_points_backward_impl, b, c, m, n, nsample,
grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
features_batch_cnt_tensor, grad_features_tensor);
}
// Entry point for the stacked grouping backward pass. It only reorders its
// arguments (tensors first here, sizes first in the _impl signature) and
// calls stack_group_points_backward_impl.
void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
Tensor idx_batch_cnt_tensor,
Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor, int b, int c,
int m, int n, int nsample) {
stack_group_points_backward_impl(
b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
features_batch_cnt_tensor, grad_features_tensor);
}
// Dispatch point for the stacked grouping forward pass: passes every
// argument, in the same order, to DISPATCH_DEVICE_IMPL under the key
// stack_group_points_forward_impl.
// NOTE(review): DISPATCH_DEVICE_IMPL is defined outside this view; presumably
// it selects the implementation registered for the tensors' device - confirm.
void stack_group_points_forward_impl(int b, int c, int m, int nsample,
const Tensor features_tensor,
const Tensor features_batch_cnt_tensor,
const Tensor idx_tensor,
const Tensor idx_batch_cnt_tensor,
Tensor out_tensor) {
DISPATCH_DEVICE_IMPL(stack_group_points_forward_impl, b, c, m, nsample,
features_tensor, features_batch_cnt_tensor, idx_tensor,
idx_batch_cnt_tensor, out_tensor);
}
// Entry point for the stacked grouping forward pass. It only reorders its
// arguments (tensors first here, sizes first in the _impl signature) and
// calls stack_group_points_forward_impl, which performs the device dispatch.
// Going through the _impl function, rather than repeating the dispatch macro
// here, matches stack_group_points_backward and keeps the dispatch in one
// place.
void stack_group_points_forward(Tensor features_tensor,
                                Tensor features_batch_cnt_tensor,
                                Tensor idx_tensor, Tensor idx_batch_cnt_tensor,
                                Tensor out_tensor, int b, int c, int m,
                                int nsample) {
  stack_group_points_forward_impl(b, c, m, nsample, features_tensor,
                                  features_batch_cnt_tensor, idx_tensor,
                                  idx_batch_cnt_tensor, out_tensor);
}
...@@ -75,6 +75,18 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor, ...@@ -75,6 +75,18 @@ void group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
Tensor grad_points_tensor, int b, int c, int n, Tensor grad_points_tensor, int b, int c, int n,
int npoints, int nsample); int npoints, int nsample);
void stack_group_points_forward(Tensor features_tensor,
Tensor features_batch_cnt_tensor,
Tensor idx_tensor, Tensor idx_batch_cnt_tensor,
Tensor out_tensor, int b, int c, int m,
int nsample);
void stack_group_points_backward(Tensor grad_out_tensor, Tensor idx_tensor,
Tensor idx_batch_cnt_tensor,
Tensor features_batch_cnt_tensor,
Tensor grad_features_tensor, int b, int c,
int m, int n, int nsample);
void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature, void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
Tensor pooled_features, Tensor pooled_empty_flag); Tensor pooled_features, Tensor pooled_empty_flag);
...@@ -240,6 +252,10 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor, ...@@ -240,6 +252,10 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
Tensor idx_tensor, int b, int n, int m, Tensor idx_tensor, int b, int n, int m,
float min_radius, float max_radius, int nsample); float min_radius, float max_radius, int nsample);
void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
Tensor xyz_tensor, Tensor xyz_batch_cnt,
Tensor idx_tensor, float max_radius, int nsample);
void prroi_pool_forward(Tensor input, Tensor rois, Tensor output, void prroi_pool_forward(Tensor input, Tensor rois, Tensor output,
int pooled_height, int pooled_width, int pooled_height, int pooled_width,
float spatial_scale); float spatial_scale);
...@@ -557,6 +573,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -557,6 +573,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"group_points_backward", py::arg("grad_out_tensor"), "group_points_backward", py::arg("grad_out_tensor"),
py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"), py::arg("idx_tensor"), py::arg("grad_points_tensor"), py::arg("b"),
py::arg("c"), py::arg("n"), py::arg("npoints"), py::arg("nsample")); py::arg("c"), py::arg("n"), py::arg("npoints"), py::arg("nsample"));
m.def("stack_group_points_forward", &stack_group_points_forward,
"stack_group_points_forward", py::arg("features_tensor"),
py::arg("features_batch_cnt_tensor"), py::arg("idx_tensor"),
py::arg("idx_batch_cnt_tensor"), py::arg("out_tensor"), py::arg("b"),
py::arg("c"), py::arg("m"), py::arg("nsample"));
m.def("stack_group_points_backward", &stack_group_points_backward,
"stack_group_points_backward", py::arg("grad_out_tensor"),
py::arg("idx_tensor"), py::arg("idx_batch_cnt_tensor"),
py::arg("features_batch_cnt_tensor"), py::arg("grad_features_tensor"),
py::arg("b"), py::arg("c"), py::arg("m"), py::arg("n"),
py::arg("nsample"));
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"), m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
py::arg("new_xyz_tensor"), py::arg("idx_tensor"), py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
...@@ -726,6 +753,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -726,6 +753,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"), py::arg("new_xyz_tensor"), py::arg("xyz_tensor"), py::arg("idx_tensor"),
py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"), py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"),
py::arg("max_radius"), py::arg("nsample")); py::arg("max_radius"), py::arg("nsample"));
m.def("stack_ball_query_forward", &stack_ball_query_forward,
"stack_ball_query_forward", py::arg("new_xyz_tensor"),
py::arg("new_xyz_batch_cnt"), py::arg("xyz_tensor"),
py::arg("xyz_batch_cnt"), py::arg("idx_tensor"), py::arg("max_radius"),
py::arg("nsample"));
m.def("roi_align_rotated_forward", &roi_align_rotated_forward, m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
"roi_align_rotated forward", py::arg("input"), py::arg("rois"), "roi_align_rotated forward", py::arg("input"), py::arg("rois"),
py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
......
...@@ -9,8 +9,10 @@ from ..utils import ext_loader ...@@ -9,8 +9,10 @@ from ..utils import ext_loader
from .ball_query import ball_query from .ball_query import ball_query
from .knn import knn from .knn import knn
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext('_ext', [
'_ext', ['group_points_forward', 'group_points_backward']) 'group_points_forward', 'group_points_backward',
'stack_group_points_forward', 'stack_group_points_backward'
])
class QueryAndGroup(nn.Module): class QueryAndGroup(nn.Module):
...@@ -183,39 +185,71 @@ class GroupingOperation(Function): ...@@ -183,39 +185,71 @@ class GroupingOperation(Function):
"""Group feature with given index.""" """Group feature with given index."""
@staticmethod @staticmethod
def forward(ctx, features: torch.Tensor, def forward(
indices: torch.Tensor) -> torch.Tensor: ctx,
features: torch.Tensor,
indices: torch.Tensor,
features_batch_cnt: Optional[torch.Tensor] = None,
indices_batch_cnt: Optional[torch.Tensor] = None) -> torch.Tensor:
""" """
Args: Args:
features (Tensor): (B, C, N) tensor of features to group. features (Tensor): Tensor of features to group, input shape is
indices (Tensor): (B, npoint, nsample) the indices of (B, C, N) or stacked inputs (N1 + N2 ..., C).
features to group with. indices (Tensor): The indices of features to group with, input
shape is (B, npoint, nsample) or stacked inputs
(M1 + M2 ..., nsample).
features_batch_cnt (Tensor, optional): Input features nums in
each batch, just like (N1, N2, ...). Defaults to None.
New in version 1.7.0.
indices_batch_cnt (Tensor, optional): Input indices nums in
each batch, just like (M1, M2, ...). Defaults to None.
New in version 1.7.0.
Returns: Returns:
Tensor: (B, C, npoint, nsample) Grouped features. Tensor: Grouped features, the shape is (B, C, npoint, nsample)
or (M1 + M2 ..., C, nsample).
""" """
features = features.contiguous() features = features.contiguous()
indices = indices.contiguous() indices = indices.contiguous()
if features_batch_cnt is not None and indices_batch_cnt is not None:
B, nfeatures, nsample = indices.size() assert features_batch_cnt.dtype == torch.int
_, C, N = features.size() assert indices_batch_cnt.dtype == torch.int
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) M, nsample = indices.size()
N, C = features.size()
ext_module.group_points_forward( B = indices_batch_cnt.shape[0]
features, output = features.new_zeros((M, C, nsample))
indices, ext_module.stack_group_points_forward(
output, features,
b=B, features_batch_cnt,
c=C, indices,
n=N, indices_batch_cnt,
npoints=nfeatures, output,
nsample=nsample) b=B,
m=M,
ctx.for_backwards = (indices, N) c=C,
nsample=nsample)
ctx.for_backwards = (B, N, indices, features_batch_cnt,
indices_batch_cnt)
else:
B, nfeatures, nsample = indices.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
ext_module.group_points_forward(
features,
indices,
output,
b=B,
c=C,
n=N,
npoints=nfeatures,
nsample=nsample)
ctx.for_backwards = (indices, N)
return output return output
@staticmethod @staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]: def backward(ctx, grad_out: torch.Tensor) -> Tuple:
""" """
Args: Args:
grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
...@@ -224,22 +258,42 @@ class GroupingOperation(Function): ...@@ -224,22 +258,42 @@ class GroupingOperation(Function):
Returns: Returns:
Tensor: (B, C, N) gradient of the features. Tensor: (B, C, N) gradient of the features.
""" """
idx, N = ctx.for_backwards if len(ctx.for_backwards) != 5:
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_features = torch.cuda.FloatTensor(B, C, N).zero_() B, C, npoint, nsample = grad_out.size()
grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
grad_out_data = grad_out.data.contiguous()
ext_module.group_points_backward( grad_out_data = grad_out.data.contiguous()
grad_out_data, ext_module.group_points_backward(
idx, grad_out_data,
grad_features.data, idx,
b=B, grad_features.data,
c=C, b=B,
n=N, c=C,
npoints=npoint, n=N,
nsample=nsample) npoints=npoint,
return grad_features, None nsample=nsample)
return grad_features, None
else:
B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards
M, C, nsample = grad_out.size()
grad_features = torch.cuda.FloatTensor(N, C).zero_()
grad_out_data = grad_out.data.contiguous()
ext_module.stack_group_points_backward(
grad_out_data,
idx,
idx_batch_cnt,
features_batch_cnt,
grad_features.data,
b=B,
c=C,
m=M,
n=N,
nsample=nsample)
return grad_features, None, None, None
grouping_operation = GroupingOperation.apply grouping_operation = GroupingOperation.apply
...@@ -53,3 +53,50 @@ def test_ball_query(): ...@@ -53,3 +53,50 @@ def test_ball_query():
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0], [7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda() [0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx) assert torch.all(idx == expected_idx)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_ball_query():
    """Check the stacked ball_query result for float, double and half inputs.

    Two batch elements are stacked: 5 query centres and 10 points each. The
    expected indices are relative to the start of each batch element.
    """
    centers = torch.tensor([
        [-0.0740, 1.3147, -1.3625],
        [-2.2769, 2.7817, -0.2334],
        [-0.4003, 2.4666, -0.5116],
        [-0.0740, 1.3147, -1.3625],
        [-0.0740, 1.3147, -1.3625],
        [-2.0289, 2.4952, -0.1708],
        [-2.0668, 6.0278, -0.4875],
        [0.4066, 1.4211, -0.2947],
        [-2.0289, 2.4952, -0.1708],
        [-2.0289, 2.4952, -0.1708],
    ]).cuda()
    centers_per_batch = torch.tensor([5, 5], dtype=torch.int32).cuda()

    points = torch.tensor([
        [-0.0740, 1.3147, -1.3625],
        [0.5555, 1.0399, -1.3634],
        [-0.4003, 2.4666, -0.5116],
        [-0.5251, 2.4379, -0.8466],
        [-0.9691, 1.1418, -1.3733],
        [-0.2232, 0.9561, -1.3626],
        [-2.2769, 2.7817, -0.2334],
        [-0.2822, 1.3192, -1.3645],
        [0.1533, 1.5024, -1.0432],
        [0.4917, 1.1529, -1.3496],
        [-2.0289, 2.4952, -0.1708],
        [-0.7188, 0.9956, -0.5096],
        [-2.0668, 6.0278, -0.4875],
        [-1.9304, 3.3092, 0.6610],
        [0.0949, 1.4332, 0.3140],
        [-1.2879, 2.0008, -0.7791],
        [-0.7252, 0.9611, -0.6371],
        [0.4066, 1.4211, -0.2947],
        [0.3220, 1.4447, 0.3548],
        [-0.9744, 2.3856, -1.2000],
    ]).cuda()
    points_per_batch = torch.tensor([10, 10], dtype=torch.int32).cuda()

    expected = torch.tensor([
        [0, 0, 0, 0, 0],
        [6, 6, 6, 6, 6],
        [2, 2, 2, 2, 2],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [2, 2, 2, 2, 2],
        [7, 7, 7, 7, 7],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ]).cuda()

    # Default (float) inputs.
    found = ball_query(0, 0.2, 5, points, centers, points_per_batch,
                       centers_per_batch)
    assert torch.all(found == expected)

    # Same query in double, then in half; the casts are chained in that
    # order, and the expected indices are cast along with the inputs.
    for cast in (torch.Tensor.double, torch.Tensor.half):
        points = cast(points)
        centers = cast(centers)
        expected = cast(expected)
        found = ball_query(0, 0.2, 5, points, centers, points_per_batch,
                           centers_per_batch)
        assert torch.all(found == expected)
...@@ -12,7 +12,7 @@ def test_grouping_points(): ...@@ -12,7 +12,7 @@ def test_grouping_points():
[0, 0, 0]], [0, 0, 0]],
[[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0], [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
[0, 0, 0]]]).int().cuda() [0, 0, 0]]]).int().cuda()
festures = torch.tensor([[[ features = torch.tensor([[[
0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274, 0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274,
0.9268, 0.8414 0.9268, 0.8414
], ],
...@@ -37,7 +37,7 @@ def test_grouping_points(): ...@@ -37,7 +37,7 @@ def test_grouping_points():
-1.4049, 0.4990, -0.7037, -0.9924, 0.0386 -1.4049, 0.4990, -0.7037, -0.9924, 0.0386
]]]).cuda() ]]]).cuda()
output = grouping_operation(festures, idx) output = grouping_operation(features, idx)
expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798], expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
[-1.3311, -1.3311, -1.3311], [-1.3311, -1.3311, -1.3311],
[0.9268, 0.9268, 0.9268], [0.9268, 0.9268, 0.9268],
...@@ -75,3 +75,161 @@ def test_grouping_points(): ...@@ -75,3 +75,161 @@ def test_grouping_points():
[-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646],
[-0.6646, -0.6646, -0.6646]]]]).cuda() [-0.6646, -0.6646, -0.6646]]]]).cuda()
assert torch.allclose(output, expected_output) assert torch.allclose(output, expected_output)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_grouping_points():
    """Check ``grouping_operation`` when called with batch-count tensors.

    The inputs are six feature rows of ten values each (batch counts
    ``[3, 3]``) and twelve index rows of three samples each (batch counts
    ``[6, 6]``). The reference output holds one block of ten rows per
    index row; every row of a block is a single feature value repeated
    across the three samples.

    In the reference data, index rows whose value is 3 or larger map to
    all-zero blocks, and the values 0 and 1 in the second half of ``idx``
    map to feature rows 3 and 4 rather than rows 0 and 1.
    """
    nsample = 3

    # Each index row repeats one value across all of its samples.
    index_per_row = (0, 3, 8, 1, 0, 2, 0, 6, 9, 0, 1, 0)
    idx = torch.tensor([[i] * nsample for i in index_per_row]).int().cuda()

    feature_rows = [
        [0.5798, -0.7981, -0.9280, -1.3311, 1.3687,
         0.9277, -0.4164, -1.8274, 0.9268, 0.8414],
        [5.4247, 1.5113, 2.3944, 1.4740, 5.0300,
         5.1030, 1.9360, 2.1939, 2.1581, 3.4666],
        [-1.6266, -1.0281, -1.0393, -1.6931, -1.3982,
         -0.5732, -1.0830, -1.7561, -1.6786, -1.6967],
        [-0.0380, -0.1880, -1.5724, 0.6905, -0.3190,
         0.7798, -0.3693, -0.9457, -0.2942, -1.8527],
        [1.1773, 1.5009, 2.6399, 5.9242, 1.0962,
         2.7346, 6.0865, 1.5555, 4.3303, 2.8229],
        [-0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
         -1.4049, 0.4990, -0.7037, -0.9924, 0.0386],
    ]
    features = torch.tensor(feature_rows).float().cuda()
    features_batch_cnt = torch.tensor([3, 3]).int().cuda()
    indices_batch_cnt = torch.tensor([6, 6]).int().cuda()

    output = grouping_operation(features, idx, features_batch_cnt,
                                indices_batch_cnt)

    # Hand-written reference: for every index row, the position in
    # ``feature_rows`` whose values are expected back, or ``None`` where
    # the reference block is all zeros.
    expected_source = (0, None, None, 1, 0, 2, 3, None, None, 3, 4, 3)
    zero_row = [0.0] * len(feature_rows[0])
    expected_output = torch.Tensor([
        [[value] * nsample
         for value in (zero_row if row is None else feature_rows[row])]
        for row in expected_source
    ]).cuda().float()

    assert torch.allclose(output, expected_output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment