Commit 487d4d66 authored by Georgia Gkioxari's avatar Georgia Gkioxari Committed by Facebook GitHub Bot
Browse files

point mesh distances

Summary:
Implementation of point to mesh distances. The current diff contains two types:
(a) Point to Edge
(b) Point to Face

```

Benchmark                                       Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
POINT_MESH_EDGE_4_100_300_5000_cuda:0                2745            3138            183
POINT_MESH_EDGE_4_100_300_10000_cuda:0               4408            4499            114
POINT_MESH_EDGE_4_100_3000_5000_cuda:0               4978            5070            101
POINT_MESH_EDGE_4_100_3000_10000_cuda:0              9076            9187             56
POINT_MESH_EDGE_4_1000_300_5000_cuda:0               1411            1487            355
POINT_MESH_EDGE_4_1000_300_10000_cuda:0              4829            5030            104
POINT_MESH_EDGE_4_1000_3000_5000_cuda:0              7539            7620             67
POINT_MESH_EDGE_4_1000_3000_10000_cuda:0            12088           12272             42
POINT_MESH_EDGE_8_100_300_5000_cuda:0                3106            3222            161
POINT_MESH_EDGE_8_100_300_10000_cuda:0               8561            8648             59
POINT_MESH_EDGE_8_100_3000_5000_cuda:0               6932            7021             73
POINT_MESH_EDGE_8_100_3000_10000_cuda:0             24032           24176             21
POINT_MESH_EDGE_8_1000_300_5000_cuda:0               5272            5399             95
POINT_MESH_EDGE_8_1000_300_10000_cuda:0             11348           11430             45
POINT_MESH_EDGE_8_1000_3000_5000_cuda:0             17478           17683             29
POINT_MESH_EDGE_8_1000_3000_10000_cuda:0            25961           26236             20
POINT_MESH_EDGE_16_100_300_5000_cuda:0               8244            8323             61
POINT_MESH_EDGE_16_100_300_10000_cuda:0             18018           18071             28
POINT_MESH_EDGE_16_100_3000_5000_cuda:0             19428           19544             26
POINT_MESH_EDGE_16_100_3000_10000_cuda:0            44967           45135             12
POINT_MESH_EDGE_16_1000_300_5000_cuda:0              7825            7937             64
POINT_MESH_EDGE_16_1000_300_10000_cuda:0            18504           18571             28
POINT_MESH_EDGE_16_1000_3000_5000_cuda:0            65805           66132              8
POINT_MESH_EDGE_16_1000_3000_10000_cuda:0           90885           91089              6
--------------------------------------------------------------------------------

Benchmark                                       Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
POINT_MESH_FACE_4_100_300_5000_cuda:0                1561            1685            321
POINT_MESH_FACE_4_100_300_10000_cuda:0               2818            2954            178
POINT_MESH_FACE_4_100_3000_5000_cuda:0              15893           16018             32
POINT_MESH_FACE_4_100_3000_10000_cuda:0             16350           16439             31
POINT_MESH_FACE_4_1000_300_5000_cuda:0               3179            3278            158
POINT_MESH_FACE_4_1000_300_10000_cuda:0              2353            2436            213
POINT_MESH_FACE_4_1000_3000_5000_cuda:0             16262           16336             31
POINT_MESH_FACE_4_1000_3000_10000_cuda:0             9334            9448             54
POINT_MESH_FACE_8_100_300_5000_cuda:0                4377            4493            115
POINT_MESH_FACE_8_100_300_10000_cuda:0               9728            9822             52
POINT_MESH_FACE_8_100_3000_5000_cuda:0              26428           26544             19
POINT_MESH_FACE_8_100_3000_10000_cuda:0             42238           43031             12
POINT_MESH_FACE_8_1000_300_5000_cuda:0               3891            3982            129
POINT_MESH_FACE_8_1000_300_10000_cuda:0              5363            5429             94
POINT_MESH_FACE_8_1000_3000_5000_cuda:0             20998           21084             24
POINT_MESH_FACE_8_1000_3000_10000_cuda:0            39711           39897             13
POINT_MESH_FACE_16_100_300_5000_cuda:0               5955            6001             84
POINT_MESH_FACE_16_100_300_10000_cuda:0             12082           12144             42
POINT_MESH_FACE_16_100_3000_5000_cuda:0             44996           45176             12
POINT_MESH_FACE_16_100_3000_10000_cuda:0            73042           73197              7
POINT_MESH_FACE_16_1000_300_5000_cuda:0              8292            8374             61
POINT_MESH_FACE_16_1000_300_10000_cuda:0            19442           19506             26
POINT_MESH_FACE_16_1000_3000_5000_cuda:0            36059           36194             14
POINT_MESH_FACE_16_1000_3000_10000_cuda:0           64644           64822              8
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20590462

fbshipit-source-id: 42a39837b514a546ac9471bfaff60eefe7fae829
parent 474c8b45
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
#include "utils/pytorch3d_cutils.h"
#include <vector>
......
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
#include "utils/pytorch3d_cutils.h"
#include <vector>
......
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
#include "utils/pytorch3d_cutils.h"
#include <vector>
......
......@@ -9,6 +9,8 @@
#include "knn/knn.h"
#include "nearest_neighbor_points/nearest_neighbor_points.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_edge.h"
#include "point_mesh/point_mesh_face.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
......@@ -39,4 +41,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("_rasterize_meshes_naive", &RasterizeMeshesNaive);
m.def("_rasterize_meshes_coarse", &RasterizeMeshesCoarse);
m.def("_rasterize_meshes_fine", &RasterizeMeshesFine);
// PointEdge distance functions
m.def("point_edge_dist_forward", &PointEdgeDistanceForward);
m.def("point_edge_dist_backward", &PointEdgeDistanceBackward);
m.def("edge_point_dist_forward", &EdgePointDistanceForward);
m.def("edge_point_dist_backward", &EdgePointDistanceBackward);
m.def("point_edge_array_dist_forward", &PointEdgeArrayDistanceForward);
m.def("point_edge_array_dist_backward", &PointEdgeArrayDistanceBackward);
// PointFace distance functions
m.def("point_face_dist_forward", &PointFaceDistanceForward);
m.def("point_face_dist_backward", &PointFaceDistanceBackward);
m.def("face_point_dist_forward", &FacePointDistanceForward);
m.def("face_point_dist_backward", &FacePointDistanceBackward);
m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);
}
......@@ -5,8 +5,8 @@
#include <iostream>
#include <tuple>
#include "dispatch.cuh"
#include "mink.cuh"
#include "utils/dispatch.cuh"
#include "utils/mink.cuh"
// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
......
......@@ -3,7 +3,7 @@
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "pytorch3d_cutils.h"
#include "utils/pytorch3d_cutils.h"
// Compute indices of K nearest neighbors in pointcloud p2 to points
// in pointcloud p1.
......
......@@ -2,43 +2,7 @@
#include <ATen/ATen.h>
#include <float.h>
template <typename scalar_t>
__device__ void WarpReduce(
volatile scalar_t* min_dists,
volatile int64_t* min_idxs,
const size_t tid) {
// s = 32
if (min_dists[tid] > min_dists[tid + 32]) {
min_idxs[tid] = min_idxs[tid + 32];
min_dists[tid] = min_dists[tid + 32];
}
// s = 16
if (min_dists[tid] > min_dists[tid + 16]) {
min_idxs[tid] = min_idxs[tid + 16];
min_dists[tid] = min_dists[tid + 16];
}
// s = 8
if (min_dists[tid] > min_dists[tid + 8]) {
min_idxs[tid] = min_idxs[tid + 8];
min_dists[tid] = min_dists[tid + 8];
}
// s = 4
if (min_dists[tid] > min_dists[tid + 4]) {
min_idxs[tid] = min_idxs[tid + 4];
min_dists[tid] = min_dists[tid + 4];
}
// s = 2
if (min_dists[tid] > min_dists[tid + 2]) {
min_idxs[tid] = min_idxs[tid + 2];
min_dists[tid] = min_dists[tid + 2];
}
// s = 1
if (min_dists[tid] > min_dists[tid + 1]) {
min_idxs[tid] = min_idxs[tid + 1];
min_dists[tid] = min_dists[tid + 1];
}
}
#include "utils/warp_reduce.cuh"
// CUDA kernel to compute nearest neighbors between two batches of pointclouds
// where each point is of dimension D.
......
......@@ -2,7 +2,7 @@
#pragma once
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
#include "utils/pytorch3d_cutils.h"
// Compute indices of nearest neighbors in pointcloud p2 to points
// in pointcloud p1.
......
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <algorithm>
#include <list>
#include <queue>
#include <tuple>
#include "utils/float_math.cuh"
#include "utils/geometry_utils.cuh"
#include "utils/warp_reduce.cuh"
// ****************************************************************************
// * PointEdgeDistance *
// ****************************************************************************
__global__ void PointEdgeForwardKernel(
const float* __restrict__ points, // (P, 3)
const int64_t* __restrict__ points_first_idx, // (B,)
const float* __restrict__ segms, // (S, 2, 3)
const int64_t* __restrict__ segms_first_idx, // (B,)
float* __restrict__ dist_points, // (P,)
int64_t* __restrict__ idx_points, // (P,)
const size_t B,
const size_t P,
const size_t S) {
float3* points_f3 = (float3*)points;
float3* segms_f3 = (float3*)segms;
// Single shared memory buffer which is split and cast to different types.
extern __shared__ char shared_buf[];
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
const size_t batch_idx = blockIdx.y; // index of batch element.
// start and end for points in batch
const int64_t startp = points_first_idx[batch_idx];
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
// start and end for segments in batch_idx
const int64_t starts = segms_first_idx[batch_idx];
const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
const size_t i = blockIdx.x; // index of point within batch element.
const size_t tid = threadIdx.x; // thread idx
// Each block will compute one element of the output idx_points[startp + i],
// dist_points[startp + i]. Within the block we will use threads to compute
// the distances between points[startp + i] and segms[j] for all j belonging
// in the same batch as i, i.e. j in [starts, ends]. Then use a block
// reduction to take an argmin of the distances.
// If i exceeds the number of points in batch_idx, then do nothing
if (i < (endp - startp)) {
// Retrieve (startp + i) point
const float3 p_f3 = points_f3[startp + i];
// Compute the distances between points[startp + i] and segms[j] for
// all j belonging in the same batch as i, i.e. j in [starts, ends].
// Here each thread will reduce over (ends-starts) / blockDim.x in serial,
// and store its result to shared memory
float min_dist = FLT_MAX;
size_t min_idx = 0;
for (size_t j = tid; j < (ends - starts); j += blockDim.x) {
const float3 v0 = segms_f3[(starts + j) * 2 + 0];
const float3 v1 = segms_f3[(starts + j) * 2 + 1];
float dist = PointLine3DistanceForward(p_f3, v0, v1);
min_dist = (j == tid) ? dist : min_dist;
min_idx = (dist <= min_dist) ? (starts + j) : min_idx;
min_dist = (dist <= min_dist) ? dist : min_dist;
}
min_dists[tid] = min_dist;
min_idxs[tid] = min_idx;
__syncthreads();
// Perform reduction in shared memory.
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
if (tid < s) {
if (min_dists[tid] > min_dists[tid + s]) {
min_dists[tid] = min_dists[tid + s];
min_idxs[tid] = min_idxs[tid + s];
}
}
__syncthreads();
}
// Unroll the last 6 iterations of the loop since they will happen
// synchronized within a single warp.
if (tid < 32)
WarpReduce<float>(min_dists, min_idxs, tid);
// Finally thread 0 writes the result to the output buffer.
if (tid == 0) {
idx_points[startp + i] = min_idxs[0];
dist_points[startp + i] = min_dists[0];
}
}
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_points) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(segms_first_idx.size(0) == B);
// clang-format off
torch::Tensor dists = torch::zeros({P,}, points.options());
torch::Tensor idxs = torch::zeros({P,}, points_first_idx.options());
// clang-format on
const int threads = 128;
const dim3 blocks(max_points, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
PointEdgeForwardKernel<<<blocks, threads, shared_size>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
segms.data_ptr<float>(),
segms_first_idx.data_ptr<int64_t>(),
dists.data_ptr<float>(),
idxs.data_ptr<int64_t>(),
B,
P,
S);
return std::make_tuple(dists, idxs);
}
__global__ void PointEdgeBackwardKernel(
const float* __restrict__ points, // (P, 3)
const float* __restrict__ segms, // (S, 2, 3)
const int64_t* __restrict__ idx_points, // (P,)
const float* __restrict__ grad_dists, // (P,)
float* __restrict__ grad_points, // (P, 3)
float* __restrict__ grad_segms, // (S, 2, 3)
const size_t P) {
float3* points_f3 = (float3*)points;
float3* segms_f3 = (float3*)segms;
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;
for (size_t p = tid; p < P; p += stride) {
const float3 p_f3 = points_f3[p];
const int64_t sidx = idx_points[p];
const float3 v0 = segms_f3[sidx * 2 + 0];
const float3 v1 = segms_f3[sidx * 2 + 1];
const float grad_dist = grad_dists[p];
const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
const float3 grad_point = thrust::get<0>(grads);
const float3 grad_v0 = thrust::get<1>(grads);
const float3 grad_v1 = thrust::get<2>(grads);
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 0, grad_v0.x);
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 1, grad_v0.y);
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 2, grad_v0.z);
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 0, grad_v1.x);
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 1, grad_v1.y);
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 2, grad_v1.z);
}
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(idx_points.size(0) == P);
AT_ASSERTM(grad_dists.size(0) == P);
// clang-format off
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
// clang-format on
const int blocks = 64;
const int threads = 512;
PointEdgeBackwardKernel<<<blocks, threads>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
idx_points.data_ptr<int64_t>(),
grad_dists.data_ptr<float>(),
grad_points.data_ptr<float>(),
grad_segms.data_ptr<float>(),
P);
return std::make_tuple(grad_points, grad_segms);
}
// ****************************************************************************
// * EdgePointDistance *
// ****************************************************************************
__global__ void EdgePointForwardKernel(
const float* __restrict__ points, // (P, 3)
const int64_t* __restrict__ points_first_idx, // (B,)
const float* __restrict__ segms, // (S, 2, 3)
const int64_t* __restrict__ segms_first_idx, // (B,)
float* __restrict__ dist_segms, // (S,)
int64_t* __restrict__ idx_segms, // (S,)
const size_t B,
const size_t P,
const size_t S) {
float3* points_f3 = (float3*)points;
float3* segms_f3 = (float3*)segms;
// Single shared memory buffer which is split and cast to different types.
extern __shared__ char shared_buf[];
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
const size_t batch_idx = blockIdx.y; // index of batch element.
// start and end for points in batch_idx
const int64_t startp = points_first_idx[batch_idx];
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
// start and end for segms in batch_idx
const int64_t starts = segms_first_idx[batch_idx];
const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
const size_t i = blockIdx.x; // index of point within batch element.
const size_t tid = threadIdx.x; // thread index
// Each block will compute one element of the output idx_segms[starts + i],
// dist_segms[starts + i]. Within the block we will use threads to compute
// the distances between segms[starts + i] and points[j] for all j belonging
// in the same batch as i, i.e. j in [startp, endp]. Then use a block
// reduction to take an argmin of the distances.
// If i exceeds the number of segms in batch_idx, then do nothing
if (i < (ends - starts)) {
const float3 v0 = segms_f3[(starts + i) * 2 + 0];
const float3 v1 = segms_f3[(starts + i) * 2 + 1];
// Compute the distances between segms[starts + i] and points[j] for
// all j belonging in the same batch as i, i.e. j in [startp, endp].
// Here each thread will reduce over (endp-startp) / blockDim.x in serial,
// and store its result to shared memory
float min_dist = FLT_MAX;
size_t min_idx = 0;
for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
// Retrieve (startp + i) point
const float3 p_f3 = points_f3[startp + j];
float dist = PointLine3DistanceForward(p_f3, v0, v1);
min_dist = (j == tid) ? dist : min_dist;
min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
min_dist = (dist <= min_dist) ? dist : min_dist;
}
min_dists[tid] = min_dist;
min_idxs[tid] = min_idx;
__syncthreads();
// Perform reduction in shared memory.
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
if (tid < s) {
if (min_dists[tid] > min_dists[tid + s]) {
min_dists[tid] = min_dists[tid + s];
min_idxs[tid] = min_idxs[tid + s];
}
}
__syncthreads();
}
// Unroll the last 6 iterations of the loop since they will happen
// synchronized within a single warp.
if (tid < 32)
WarpReduce<float>(min_dists, min_idxs, tid);
// Finally thread 0 writes the result to the output buffer.
if (tid == 0) {
idx_segms[starts + i] = min_idxs[0];
dist_segms[starts + i] = min_dists[0];
}
}
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_segms) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(segms_first_idx.size(0) == B);
// clang-format off
torch::Tensor dists = torch::zeros({S,}, segms.options());
torch::Tensor idxs = torch::zeros({S,}, segms_first_idx.options());
// clang-format on
const int threads = 128;
const dim3 blocks(max_segms, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
EdgePointForwardKernel<<<blocks, threads, shared_size>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
segms.data_ptr<float>(),
segms_first_idx.data_ptr<int64_t>(),
dists.data_ptr<float>(),
idxs.data_ptr<int64_t>(),
B,
P,
S);
return std::make_tuple(dists, idxs);
}
__global__ void EdgePointBackwardKernel(
const float* __restrict__ points, // (P, 3)
const float* __restrict__ segms, // (S, 2, 3)
const int64_t* __restrict__ idx_segms, // (S,)
const float* __restrict__ grad_dists, // (S,)
float* __restrict__ grad_points, // (P, 3)
float* __restrict__ grad_segms, // (S, 2, 3)
const size_t S) {
float3* points_f3 = (float3*)points;
float3* segms_f3 = (float3*)segms;
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;
for (size_t s = tid; s < S; s += stride) {
const float3 v0 = segms_f3[s * 2 + 0];
const float3 v1 = segms_f3[s * 2 + 1];
const int64_t pidx = idx_segms[s];
const float3 p_f3 = points_f3[pidx];
const float grad_dist = grad_dists[s];
const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
const float3 grad_point = thrust::get<0>(grads);
const float3 grad_v0 = thrust::get<1>(grads);
const float3 grad_v1 = thrust::get<2>(grads);
atomicAdd(grad_points + pidx * 3 + 0, grad_point.x);
atomicAdd(grad_points + pidx * 3 + 1, grad_point.y);
atomicAdd(grad_points + pidx * 3 + 2, grad_point.z);
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_v0.x);
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_v0.y);
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_v0.z);
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_v1.x);
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_v1.y);
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_v1.z);
}
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(idx_segms.size(0) == S);
AT_ASSERTM(grad_dists.size(0) == S);
// clang-format off
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
// clang-format on
const int blocks = 64;
const int threads = 512;
EdgePointBackwardKernel<<<blocks, threads>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
idx_segms.data_ptr<int64_t>(),
grad_dists.data_ptr<float>(),
grad_points.data_ptr<float>(),
grad_segms.data_ptr<float>(),
S);
return std::make_tuple(grad_points, grad_segms);
}
// ****************************************************************************
// * PointEdgeArrayDistance *
// ****************************************************************************
__global__ void PointEdgeArrayForwardKernel(
const float* __restrict__ points, // (P, 3)
const float* __restrict__ segms, // (S, 2, 3)
float* __restrict__ dists, // (P, S)
const size_t P,
const size_t S) {
float3* points_f3 = (float3*)points;
float3* segms_f3 = (float3*)segms;
// Parallelize over P * S computations
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.
const int p = t_i % P; // point index
float3 a = segms_f3[s * 2 + 0];
float3 b = segms_f3[s * 2 + 1];
float3 point = points_f3[p];
float dist = PointLine3DistanceForward(point, a, b);
dists[p * S + s] = dist;
}
}
torch::Tensor PointEdgeArrayDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
torch::Tensor dists = torch::zeros({P, S}, points.options());
const size_t blocks = 1024;
const size_t threads = 64;
PointEdgeArrayForwardKernel<<<blocks, threads>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
dists.data_ptr<float>(),
P,
S);
return dists;
}
__global__ void PointEdgeArrayBackwardKernel(
const float* __restrict__ points, // (P, 3)
const float* __restrict__ segms, // (S, 2, 3)
const float* __restrict__ grad_dists, // (P, S)
float* __restrict__ grad_points, // (P, 3)
float* __restrict__ grad_segms, // (S, 2, 3)
const size_t P,
const size_t S) {
float3* points_f3 = (float3*)points;
float3* segms_f3 = (float3*)segms;
// Parallelize over P * S computations
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.
const int p = t_i % P; // point index
const float3 a = segms_f3[s * 2 + 0];
const float3 b = segms_f3[s * 2 + 1];
const float3 point = points_f3[p];
const float grad_dist = grad_dists[p * S + s];
const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
const float3 grad_point = thrust::get<0>(grads);
const float3 grad_a = thrust::get<1>(grads);
const float3 grad_b = thrust::get<2>(grads);
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
}
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
const size_t blocks = 1024;
const size_t threads = 64;
PointEdgeArrayBackwardKernel<<<blocks, threads>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
grad_dists.data_ptr<float>(),
grad_points.data_ptr<float>(),
grad_segms.data_ptr<float>(),
P,
S);
return std::make_tuple(grad_points, grad_segms);
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
// ****************************************************************************
// * PointEdgeDistance *
// ****************************************************************************
// Computes the squared euclidean distance of each p in points to the closest
// mesh edge belonging to the corresponding example in the batch of size N.
//
// Args:
// points: FloatTensor of shape (P, 3)
// points_first_idx: LongTensor of shape (N,) indicating the first point
// index for each example in the batch
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
// segment is spanned by (segms[s, 0], segms[s, 1])
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
// index for each example in the batch
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
// the maximum number of points in the batch and is used to set
// the grid dimensions in the CUDA implementation.
//
// Returns:
// dists: FloatTensor of shape (P,), where dists[p] is the squared euclidean
// distance of points[p] to the closest edge in the same example in the
// batch.
// idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
// edge in the batch.
// So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
// where d(u, v0, v1) is the distance of u from the segment spanned by
// (v0, v1).
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_points);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_points) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointEdgeDistanceForwardCuda(
points, points_first_idx, segms, segms_first_idx, max_points);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// Backward pass for PointEdgeDistance.
//
// Args:
// points: FloatTensor of shape (P, 3)
// segms: FloatTensor of shape (S, 2, 3)
// idx_points: LongTensor of shape (P,) containing the indices
// of the closest edge in the example in the batch.
// This is computed by the forward pass.
// grad_dists: FloatTensor of shape (P,)
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
// grad_segms: FloatTensor of shape (S, 2, 3)
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// ****************************************************************************
// * EdgePointDistance *
// ****************************************************************************
// Computes the squared euclidean distance of each edge segment to the closest
// point belonging to the corresponding example in the batch of size N.
//
// Args:
// points: FloatTensor of shape (P, 3)
// points_first_idx: LongTensor of shape (N,) indicating the first point
// index for each example in the batch
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
// segment is spanned by (segms[s, 0], segms[s, 1])
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
// index for each example in the batch
// max_segms: Scalar equal to max(S_i) for i in [0, N - 1] containing
// the maximum number of edges in the batch and is used to set
// the block dimensions in the CUDA implementation.
//
// Returns:
// dists: FloatTensor of shape (S,), where dists[s] is the squared
// euclidean distance of s-th edge to the closest point in the
// corresponding example in the batch.
// idxs: LongTensor of shape (S,), where idxs[s] is the index of the closest
// point in the example in the batch.
// So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
// d(u, v0, v1) is the distance of u from the segment spanned by (v0, v1)
//
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_segms);
#endif
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_segms) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return EdgePointDistanceForwardCuda(
points, points_first_idx, segms, segms_first_idx, max_segms);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// Backward pass for EdgePointDistance.
//
// Args:
// points: FloatTensor of shape (P, 3)
// segms: FloatTensor of shape (S, 2, 3)
// idx_segms: LongTensor of shape (S,) containing the indices
// of the closest point in the example in the batch.
// This is computed by the forward pass
// grad_dists: FloatTensor of shape (S,)
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
// grad_segms: FloatTensor of shape (S, 2, 3)
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// ****************************************************************************
// * PointEdgeArrayDistance *
// ****************************************************************************
// Computes the squared euclidean distance of each p in points to each edge
// segment in segms.
//
// Args:
// points: FloatTensor of shape (P, 3)
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
// edge segment is spanned by (segms[s, 0], segms[s, 1])
//
// Returns:
// dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
// euclidean distance of points[p] to the segment spanned by
// (segms[s, 0], segms[s, 1])
//
// For pointcloud and meshes of batch size N, this function requires N
// computations. The memory occupied is O(NPS) which can become quite large.
// For example, a medium sized batch with N = 32 with P = 10000 and S = 5000
// will require for the forward pass 5.8G of memory to store dists.
#ifdef WITH_CUDA
torch::Tensor PointEdgeArrayDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms);
#endif
torch::Tensor PointEdgeArrayDistanceForward(
const torch::Tensor& points,
const torch::Tensor& segms) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointEdgeArrayDistanceForwardCuda(points, segms);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// Backward pass for PointEdgeArrayDistance.
//
// Args:
// points: FloatTensor of shape (P, 3)
// segms: FloatTensor of shape (S, 2, 3)
// grad_dists: FloatTensor of shape (P, S)
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
// grad_segms: FloatTensor of shape (S, 2, 3)
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
This diff is collapsed.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
// ****************************************************************************
// * PointFaceDistance *
// ****************************************************************************
// Computes the squared euclidean distance of each p in points to it closest
// triangular face belonging to the corresponding mesh example in the batch of
// size N.
//
// Args:
// points: FloatTensor of shape (P, 3)
// points_first_idx: LongTensor of shape (N,) indicating the first point
// index for each example in the batch
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// tris_first_idx: LongTensor of shape (N,) indicating the first face
// index for each example in the batch
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
// the maximum number of points in the batch and is used to set
// the block dimensions in the CUDA implementation.
//
// Returns:
// dists: FloatTensor of shape (P,), where dists[p] is the minimum
// squared euclidean distance of points[p] to the faces in the same
// example in the batch.
// idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
// face in the batch.
// So, dists[p] = d(points[p], tris[idxs[p], 0], tris[idxs[p], 1],
// tris[idxs[p], 2]) where d(u, v0, v1, v2) is the distance of u from the
// face spanned by (v0, v1, v2)
//
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_points);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_points) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointFaceDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_points);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// Backward pass for PointFaceDistance.
//
// Args:
// points: FloatTensor of shape (P, 3)
// tris: FloatTensor of shape (T, 3, 3)
// idx_points: LongTensor of shape (P,) containing the indices
// of the closest face in the example in the batch.
// This is computed by the forward pass
// grad_dists: FloatTensor of shape (P,)
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
// grad_tris: FloatTensor of shape (T, 3, 3)
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// ****************************************************************************
// * FacePointDistance *
// ****************************************************************************
// Computes the squared euclidean distance of each triangular face to its
// closest point belonging to the corresponding example in the batch of size N.
//
// Args:
// points: FloatTensor of shape (P, 3)
// points_first_idx: LongTensor of shape (N,) indicating the first point
// index for each example in the batch
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// tris_first_idx: LongTensor of shape (N,) indicating the first face
// index for each example in the batch
// max_tris: Scalar equal to max(T_i) for i in [0, N - 1] containing
// the maximum number of faces in the batch and is used to set
// the block dimensions in the CUDA implementation.
//
// Returns:
// dists: FloatTensor of shape (T,), where dists[t] is the minimum squared
// euclidean distance of t-th triangular face from the closest point in
// the batch.
// idxs: LongTensor of shape (T,), where idxs[t] is the index of the closest
// point in the batch.
// So, dists[t] = d(points[idxs[t]], tris[t, 0], tris[t, 1], tris[t, 2])
// where d(u, v0, v1, v2) is the distance of u from the triangular face
// spanned by (v0, v1, v2)
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_tros);
#endif
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_tris) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return FacePointDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_tris);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// Backward pass for FacePointDistance.
//
// Args:
// points: FloatTensor of shape (P, 3)
// tris: FloatTensor of shape (T, 3, 3)
// idx_tris: LongTensor of shape (T,) containing the indices
// of the closest point in the example in the batch.
// This is computed by the forward pass
// grad_dists: FloatTensor of shape (T,)
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
// grad_tris: FloatTensor of shape (T, 3, 3)
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// ****************************************************************************
// * PointFaceArrayDistance *
// ****************************************************************************
// Computes the squared euclidean distance of each p in points to each
// triangular face spanned by (v0, v1, v2) in tris.
//
// Args:
// points: FloatTensor of shape (P, 3)
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
//
// Returns:
// dists: FloatTensor of shape (P, T), where dists[p, t] is the squared
// euclidean distance of points[p] to the face spanned by (v0, v1, v2)
// where v0 = tris[t, 0], v1 = tris[t, 1] and v2 = tris[t, 2]
//
// For pointcloud and meshes of batch size N, this function requires N
// computations. The memory occupied is O(NPT) which can become quite large.
// For example, a medium sized batch with N = 32 with P = 10000 and T = 5000
// will require for the forward pass 5.8G of memory to store dists.
#ifdef WITH_CUDA
torch::Tensor PointFaceArrayDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris);
#endif
torch::Tensor PointFaceArrayDistanceForward(
const torch::Tensor& points,
const torch::Tensor& tris) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointFaceArrayDistanceForwardCuda(points, tris);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
// Backward pass for PointFaceArrayDistance.
//
// Args:
// points: FloatTensor of shape (P, 3)
// tris: FloatTensor of shape (T, 3, 3)
// grad_dists: FloatTensor of shape (P, T)
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
// grad_tris: FloatTensor of shape (T, 3, 3)
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
}
......@@ -6,10 +6,10 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "float_math.cuh"
#include "geometry_utils.cuh"
#include "rasterize_points/bitmask.cuh"
#include "rasterize_points/rasterization_utils.cuh"
#include "utils/float_math.cuh"
#include "utils/geometry_utils.cuh"
namespace {
// A structure for holding details about a pixel.
......
......@@ -5,9 +5,9 @@
#include <list>
#include <queue>
#include <tuple>
#include "geometry_utils.h"
#include "vec2.h"
#include "vec3.h"
#include "utils/geometry_utils.h"
#include "utils/vec2.h"
#include "utils/vec3.h"
float PixToNdc(int i, int S) {
// NDC x-offset + (i * pixel_width + half_pixel_width)
......
......@@ -3,6 +3,13 @@
#pragma once
#include <thrust/tuple.h>
// Set epsilon
#ifdef _MSC_VER
#define vEpsilon 1e-8f
#else
const auto vEpsilon = 1e-8;
#endif
// Common functions and operators for float2.
__device__ inline float2 operator-(const float2& a, const float2& b) {
......@@ -84,3 +91,49 @@ __device__ inline float dot(const float3& a, const float3& b) {
__device__ inline float sum(const float3& a) {
return a.x + a.y + a.z;
}
__device__ inline float3 cross(const float3& a, const float3& b) {
return make_float3(
a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
}
__device__ inline thrust::tuple<float3, float3>
cross_backward(const float3& a, const float3& b, const float3& grad_cross) {
const float grad_ax = -grad_cross.y * b.z + grad_cross.z * b.y;
const float grad_ay = grad_cross.x * b.z - grad_cross.z * b.x;
const float grad_az = -grad_cross.x * b.y + grad_cross.y * b.x;
const float3 grad_a = make_float3(grad_ax, grad_ay, grad_az);
const float grad_bx = grad_cross.y * a.z - grad_cross.z * a.y;
const float grad_by = -grad_cross.x * a.z + grad_cross.z * a.x;
const float grad_bz = grad_cross.x * a.y - grad_cross.y * a.x;
const float3 grad_b = make_float3(grad_bx, grad_by, grad_bz);
return thrust::make_tuple(grad_a, grad_b);
}
__device__ inline float norm(const float3& a) {
return sqrt(dot(a, a));
}
__device__ inline float3 normalize(const float3& a) {
return a / (norm(a) + vEpsilon);
}
__device__ inline float3 normalize_backward(
const float3& a,
const float3& grad_normz) {
const float a_norm = norm(a) + vEpsilon;
const float3 out = a / a_norm;
const float grad_ax = grad_normz.x * (1.0f - out.x * out.x) / a_norm +
grad_normz.y * (-out.x * out.y) / a_norm +
grad_normz.z * (-out.x * out.z) / a_norm;
const float grad_ay = grad_normz.x * (-out.x * out.y) / a_norm +
grad_normz.y * (1.0f - out.y * out.y) / a_norm +
grad_normz.z * (-out.y * out.z) / a_norm;
const float grad_az = grad_normz.x * (-out.x * out.z) / a_norm +
grad_normz.y * (-out.y * out.z) / a_norm +
grad_normz.z * (1.0f - out.z * out.z) / a_norm;
return make_float3(grad_ax, grad_ay, grad_az);
}
......@@ -8,11 +8,15 @@
// Set epsilon for preventing floating point errors and division by 0.
#ifdef _MSC_VER
#define kEpsilon 1e-30f
#define kEpsilon 1e-8f
#else
const auto kEpsilon = 1e-30;
const auto kEpsilon = 1e-8;
#endif
// ************************************************************* //
// vec2 utils //
// ************************************************************* //
// Determines whether a point p is on the right side of a 2D line segment
// given by the end points v0, v1.
//
......@@ -353,3 +357,295 @@ PointTriangleDistanceBackward(
return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}
// ************************************************************* //
// vec3 utils //
// ************************************************************* //
// Computes the barycentric coordinates of a point p relative
// to a triangle (v0, v1, v2), i.e. p = w0 * v0 + w1 * v1 + w2 * v2
// s.t. w0 + w1 + w2 = 1.0
//
// NOTE that this function assumes that p lives on the space spanned
// by (v0, v1, v2).
// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
// and throw an error if check fails
//
// Args:
// p: vec3 coordinates of a point
// v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns
// bary: (w0, w1, w2) barycentric coordinates
//
__device__ inline float3 BarycentricCoords3Forward(
const float3& p,
const float3& v0,
const float3& v1,
const float3& v2) {
float3 p0 = v1 - v0;
float3 p1 = v2 - v0;
float3 p2 = p - v0;
const float d00 = dot(p0, p0);
const float d01 = dot(p0, p1);
const float d11 = dot(p1, p1);
const float d20 = dot(p2, p0);
const float d21 = dot(p2, p1);
const float denom = d00 * d11 - d01 * d01 + kEpsilon;
const float w1 = (d11 * d20 - d01 * d21) / denom;
const float w2 = (d00 * d21 - d01 * d20) / denom;
const float w0 = 1.0f - w1 - w2;
return make_float3(w0, w1, w2);
}
// Checks whether the point p is inside the triangle (v0, v1, v2).
// A point is inside the triangle, if all barycentric coordinates
// wrt the triangle are >= 0 & <= 1.
//
// NOTE that this function assumes that p lives on the space spanned
// by (v0, v1, v2).
// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
// and throw an error if check fails
//
// Args:
// p: vec3 coordinates of a point
// v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns:
// inside: bool indicating wether p is inside triangle
//
__device__ inline bool IsInsideTriangle(
const float3& p,
const float3& v0,
const float3& v1,
const float3& v2) {
float3 bary = BarycentricCoords3Forward(p, v0, v1, v2);
bool x_in = 0.0f <= bary.x && bary.x <= 1.0f;
bool y_in = 0.0f <= bary.y && bary.y <= 1.0f;
bool z_in = 0.0f <= bary.z && bary.z <= 1.0f;
bool inside = x_in && y_in && z_in;
return inside;
}
// Computes the minimum squared Euclidean distance between the point p
// and the segment spanned by (v0, v1).
// To find this we parametrize p as: x(t) = v0 + t * (v1 - v0)
// and find t which minimizes (x(t) - p) ^ 2.
// Note that p does not need to live in the space spanned by (v0, v1)
//
// Args:
// p: vec3 coordinates of a point
// v0, v1: vec3 coordinates of start and end of segment
//
// Returns:
// dist: the minimum squared distance of p from segment (v0, v1)
//
__device__ inline float
PointLine3DistanceForward(const float3& p, const float3& v0, const float3& v1) {
const float3 v1v0 = v1 - v0;
const float3 pv0 = p - v0;
const float t_bot = dot(v1v0, v1v0);
const float t_top = dot(pv0, v1v0);
// if t_bot small, then v0 == v1, set tt to 0.
float tt = (t_bot < kEpsilon) ? 0.0f : (t_top / t_bot);
tt = __saturatef(tt); // clamps to [0, 1]
const float3 p_proj = v0 + tt * v1v0;
const float3 diff = p - p_proj;
const float dist = dot(diff, diff);
return dist;
}
// Backward function of the minimum squared Euclidean distance between the point
// p and the line segment (v0, v1).
//
// Args:
// p: vec3 coordinates of a point
// v0, v1: vec3 coordinates of start and end of segment
// grad_dist: Float of the gradient wrt dist
//
// Returns:
// tuple of gradients for the point and line segment (v0, v1):
// (float3 grad_p, float3 grad_v0, float3 grad_v1)
__device__ inline thrust::tuple<float3, float3, float3>
PointLine3DistanceBackward(
const float3& p,
const float3& v0,
const float3& v1,
const float& grad_dist) {
const float3 v1v0 = v1 - v0;
const float3 pv0 = p - v0;
const float t_bot = dot(v1v0, v1v0);
const float t_top = dot(v1v0, pv0);
float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v1 = make_float3(0.0f, 0.0f, 0.0f);
const float tt = t_top / t_bot;
if (t_bot < kEpsilon) {
// if t_bot small, then v0 == v1,
// and dist = 0.5 * dot(pv0, pv0) + 0.5 * dot(pv1, pv1)
grad_p = grad_dist * 2.0f * pv0;
grad_v0 = -0.5f * grad_p;
grad_v1 = grad_v0;
} else if (tt < 0.0f) {
grad_p = grad_dist * 2.0f * pv0;
grad_v0 = -1.0f * grad_p;
// no gradients wrt v1
} else if (tt > 1.0f) {
grad_p = grad_dist * 2.0f * (p - v1);
grad_v1 = -1.0f * grad_p;
// no gradients wrt v0
} else {
const float3 p_proj = v0 + tt * v1v0;
const float3 diff = p - p_proj;
const float3 grad_base = grad_dist * 2.0f * diff;
grad_p = grad_base - dot(grad_base, v1v0) * v1v0 / t_bot;
const float3 dtt_v0 = (-1.0f * v1v0 - pv0 + 2.0f * tt * v1v0) / t_bot;
grad_v0 = (-1.0f + tt) * grad_base - dot(grad_base, v1v0) * dtt_v0;
const float3 dtt_v1 = (pv0 - 2.0f * tt * v1v0) / t_bot;
grad_v1 = -dot(grad_base, v1v0) * dtt_v1 - tt * grad_base;
}
return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// Computes the squared distance of a point p relative to a triangle (v0, v1,
// v2). If the point's projection p0 on the plane spanned by (v0, v1, v2) is
// inside the triangle with vertices (v0, v1, v2), then the returned value is
// the squared distance of p to its projection p0. Otherwise, the returned value
// is the smallest squared distance of p from the line segments (v0, v1), (v0,
// v2) and (v1, v2).
//
// Args:
// p: vec3 coordinates of a point
// v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns:
// dist: Float of the squared distance
//
__device__ inline float PointTriangle3DistanceForward(
const float3& p,
const float3& v0,
const float3& v1,
const float3& v2) {
float3 normal = cross(v2 - v0, v1 - v0);
const float norm_normal = norm(normal);
normal = normalize(normal);
// p0 is the projection of p on the plane spanned by (v0, v1, v2)
// i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
const float t = dot(v0 - p, normal);
const float3 p0 = p + t * normal;
bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
float dist = 0.0f;
if ((is_inside) && (norm_normal > kEpsilon)) {
// if projection p0 is inside triangle spanned by (v0, v1, v2)
// then distance is equal to norm(p0 - p)^2
dist = t * t;
} else {
const float e01 = PointLine3DistanceForward(p, v0, v1);
const float e02 = PointLine3DistanceForward(p, v0, v2);
const float e12 = PointLine3DistanceForward(p, v1, v2);
dist = (e01 > e02) ? e02 : e01;
dist = (dist > e12) ? e12 : dist;
}
return dist;
}
// The backward pass for computing the squared distance of a point
// to the triangle (v0, v1, v2).
//
// Args:
// p: xyz coordinates of a point
// v0, v1, v2: xyz coordinates of the triangle vertices
// grad_dist: Float of the gradient wrt dist
//
// Returns:
// tuple of gradients for the point and triangle:
// (float3 grad_p, float3 grad_v0, float3 grad_v1, float3 grad_v2)
//
__device__ inline thrust::tuple<float3, float3, float3, float3>
PointTriangle3DistanceBackward(
const float3& p,
const float3& v0,
const float3& v1,
const float3& v2,
const float& grad_dist) {
const float3 v2v0 = v2 - v0;
const float3 v1v0 = v1 - v0;
const float3 v0p = v0 - p;
float3 raw_normal = cross(v2v0, v1v0);
const float norm_normal = norm(raw_normal);
float3 normal = normalize(raw_normal);
// p0 is the projection of p on the plane spanned by (v0, v1, v2)
// i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
const float t = dot(v0 - p, normal);
const float3 p0 = p + t * normal;
const float3 diff = t * normal;
bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v1 = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v2 = make_float3(0.0f, 0.0f, 0.0f);
if ((is_inside) && (norm_normal > kEpsilon)) {
// derivative of dist wrt p
grad_p = -2.0f * grad_dist * t * normal;
// derivative of dist wrt normal
const float3 grad_normal = 2.0f * grad_dist * t * (v0p + diff);
// derivative of dist wrt raw_normal
const float3 grad_raw_normal = normalize_backward(raw_normal, grad_normal);
// derivative of dist wrt v2v0 and v1v0
const auto grad_cross = cross_backward(v2v0, v1v0, grad_raw_normal);
const float3 grad_cross_v2v0 = thrust::get<0>(grad_cross);
const float3 grad_cross_v1v0 = thrust::get<1>(grad_cross);
grad_v0 =
grad_dist * 2.0f * t * normal - (grad_cross_v2v0 + grad_cross_v1v0);
grad_v1 = grad_cross_v1v0;
grad_v2 = grad_cross_v2v0;
} else {
const float e01 = PointLine3DistanceForward(p, v0, v1);
const float e02 = PointLine3DistanceForward(p, v0, v2);
const float e12 = PointLine3DistanceForward(p, v1, v2);
if ((e01 <= e02) && (e01 <= e12)) {
// e01 is smallest
const auto grads = PointLine3DistanceBackward(p, v0, v1, grad_dist);
grad_p = thrust::get<0>(grads);
grad_v0 = thrust::get<1>(grads);
grad_v1 = thrust::get<2>(grads);
} else if ((e02 <= e01) && (e02 <= e12)) {
// e02 is smallest
const auto grads = PointLine3DistanceBackward(p, v0, v2, grad_dist);
grad_p = thrust::get<0>(grads);
grad_v0 = thrust::get<1>(grads);
grad_v2 = thrust::get<2>(grads);
} else if ((e12 <= e01) && (e12 <= e02)) {
// e12 is smallest
const auto grads = PointLine3DistanceBackward(p, v1, v2, grad_dist);
grad_p = thrust::get<0>(grads);
grad_v1 = thrust::get<1>(grads);
grad_v2 = thrust::get<2>(grads);
}
}
return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}
......@@ -7,7 +7,7 @@
#include "vec3.h"
// Set epsilon for preventing floating point errors and division by 0.
const auto kEpsilon = 1e-30;
const auto kEpsilon = 1e-8;
// Determines whether a point p is on the right side of a 2D line segment
// given by the end points v0, v1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment