Commit 0d8608b9 authored by Jiali Duan's avatar Jiali Duan Committed by Facebook GitHub Bot
Browse files

Marching Cubes C++ torch extension

Summary:
Torch C++ extension for Marching Cubes

- Add torch C++ extension for marching cubes. Observe a speed up of ~255x-324x speed up (over varying batch sizes and spatial resolutions)

- Add C++ impl in existing unit-tests.

(Note: this ignores all push blocking failures!)

Reviewed By: kjchalup

Differential Revision: D39590638

fbshipit-source-id: e44d2852a24c2c398e5ea9db20f0dfaa1817e457
parent 850efdf7
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "interp_face_attrs/interp_face_attrs.h" #include "interp_face_attrs/interp_face_attrs.h"
#include "iou_box3d/iou_box3d.h" #include "iou_box3d/iou_box3d.h"
#include "knn/knn.h" #include "knn/knn.h"
#include "marching_cubes/marching_cubes.h"
#include "mesh_normal_consistency/mesh_normal_consistency.h" #include "mesh_normal_consistency/mesh_normal_consistency.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_cuda.h" #include "point_mesh/point_mesh_cuda.h"
...@@ -94,6 +95,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -94,6 +95,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// 3D IoU // 3D IoU
m.def("iou_box3d", &IoUBox3D); m.def("iou_box3d", &IoUBox3D);
// Marching cubes
m.def("marching_cubes", &MarchingCubes);
// Pulsar. // Pulsar.
#ifdef PULSAR_LOGGING_ENABLED #ifdef PULSAR_LOGGING_ENABLED
c10::ShowLogInfoToStderr(); c10::ShowLogInfoToStderr();
......
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <tuple>
#include <vector>
#include "utils/pytorch3d_cutils.h"
// Run Marching Cubes algorithm over a batch of volume scalar fields
// with a pre-defined threshold and return a mesh composed of vertices
// and faces for the mesh.
//
// Args:
// vol: FloatTensor of shape (D, H, W) giving a volume
// scalar grids.
// isolevel: isosurface value to use as the threshoold to determine whether
// the points are within a volume.
//
// Returns:
// vertices: List of N FloatTensors of vertices
// faces: List of N LongTensors of faces
// CPU implementation
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
const at::Tensor& vol,
const float isolevel);
// Implementation which is exposed
inline std::tuple<at::Tensor, at::Tensor> MarchingCubes(
const at::Tensor& vol,
const float isolevel) {
return MarchingCubesCpu(vol.contiguous(), isolevel);
}
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <algorithm>
#include <array>
#include <cstring>
#include <unordered_map>
#include <vector>
#include "marching_cubes/marching_cubes_utils.h"
// Cpu implementation for Marching Cubes
// Args:
// vol: a Tensor of size (D, H, W) corresponding to a 3D scalar field
// isolevel: the isosurface value to use as the threshold to determine
// whether points are within a volume.
//
// Returns:
// vertices: a float tensor of shape (N, 3) for positions of the mesh
// faces: a long tensor of shape (N, 3) for indices of the face vertices
//
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
const at::Tensor& vol,
const float isolevel) {
// volume shapes
const int D = vol.size(0);
const int H = vol.size(1);
const int W = vol.size(2);
// Create tensor accessors
auto vol_a = vol.accessor<float, 3>();
// vpair_to_edge maps a pair of vertex ids to its corresponding edge id
std::unordered_map<std::pair<int, int>, int64_t> vpair_to_edge;
// edge_id_to_v maps from an edge id to a vertex position
std::unordered_map<int64_t, Vertex> edge_id_to_v;
// uniq_edge_id: used to remove redundant edge ids
std::unordered_map<int64_t, int64_t> uniq_edge_id;
std::vector<int64_t> faces; // store face indices
std::vector<Vertex> verts; // store vertex positions
// enumerate each cell in the 3d grid
for (int z = 0; z < D - 1; z++) {
for (int y = 0; y < H - 1; y++) {
for (int x = 0; x < W - 1; x++) {
Cube cube(x, y, z, vol_a, isolevel);
// Cube is entirely in/out of the surface
if (_FACE_TABLE[cube.cubeindex][0] == -1) {
continue;
}
// store all boundary vertices that intersect with the edges
std::array<Vertex, 12> interp_points;
// triangle vertex IDs and positions
std::vector<int64_t> tri;
std::vector<Vertex> ps;
// Interpolate the vertices where the surface intersects with the cube
for (int j = 0; _FACE_TABLE[cube.cubeindex][j] != -1; j++) {
const int e = _FACE_TABLE[cube.cubeindex][j];
interp_points[e] = cube.VertexInterp(isolevel, e, vol_a);
auto vpair = cube.GetVPairFromEdge(e, W, H);
if (!vpair_to_edge.count(vpair)) {
vpair_to_edge[vpair] = vpair_to_edge.size();
}
int64_t edge = vpair_to_edge[vpair];
tri.push_back(edge);
ps.push_back(interp_points[e]);
// Check if the triangle face is degenerate. A triangle face
// is degenerate if any of the two verices share the same 3D position
if ((j + 1) % 3 == 0 && ps[0] != ps[1] && ps[1] != ps[2] &&
ps[2] != ps[0]) {
for (int k = 0; k < 3; k++) {
int v = tri[k];
edge_id_to_v[tri.at(k)] = ps.at(k);
if (!uniq_edge_id.count(v)) {
uniq_edge_id[v] = verts.size();
verts.push_back(edge_id_to_v[v]);
}
faces.push_back(uniq_edge_id[v]);
}
tri.clear();
ps.clear();
}
} // endif
} // endfor x
} // endfor y
} // endfor z
// Collect returning tensor
const int n_vertices = verts.size();
const int64_t n_faces = (int64_t)faces.size() / 3;
auto vert_tensor = torch::zeros({n_vertices, 3}, torch::kFloat);
auto face_tensor = torch::zeros({n_faces, 3}, torch::kInt64);
auto vert_a = vert_tensor.accessor<float, 2>();
for (int i = 0; i < n_vertices; i++) {
vert_a[i][0] = verts.at(i).x;
vert_a[i][1] = verts.at(i).y;
vert_a[i][2] = verts.at(i).z;
}
auto face_a = face_tensor.accessor<int64_t, 2>();
for (int64_t i = 0; i < n_faces; i++) {
face_a[i][0] = faces.at(i * 3 + 0);
face_a[i][1] = faces.at(i * 3 + 1);
face_a[i][2] = faces.at(i * 3 + 2);
}
return std::make_tuple(vert_tensor, face_tensor);
}
This diff is collapsed.
...@@ -7,8 +7,10 @@ ...@@ -7,8 +7,10 @@
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from pytorch3d import _C
from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX
from pytorch3d.transforms import Translate from pytorch3d.transforms import Translate
from torch.autograd import Function
EPS = 0.00001 EPS = 0.00001
...@@ -225,3 +227,71 @@ def marching_cubes_naive( ...@@ -225,3 +227,71 @@ def marching_cubes_naive(
batched_verts.append([]) batched_verts.append([])
batched_faces.append([]) batched_faces.append([])
return batched_verts, batched_faces return batched_verts, batched_faces
########################################
# Marching Cubes Implementation in C++
########################################
class _marching_cubes(Function):
"""
Torch Function wrapper for marching_cubes C++ implementation
Backward is not supported.
"""
@staticmethod
def forward(ctx, vol, isolevel):
verts, faces = _C.marching_cubes(vol, isolevel)
return verts, faces
@staticmethod
def backward(ctx, grad_verts, grad_faces):
raise ValueError("marching_cubes backward is not supported")
def marching_cubes(
vol_batch: torch.Tensor,
isolevel: Optional[float] = None,
return_local_coords: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Run marching cubes over a volume scalar field with a designated isolevel.
Returns vertices and faces of the obtained mesh.
This operation is non-differentiable.
Args:
vol_batch: a Tensor of size (N, D, H, W) corresponding to
a batch of 3D scalar fields
isolevel: float used as threshold to determine if a point is inside/outside
the volume. If None, then the average of the maximum and minimum value
of the scalar field is used.
return_local_coords: bool. If True the output vertices will be in local coordinates in
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
[0, W-1] x [0, H-1] x [0, D-1]
Returns:
verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor
faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors
"""
batched_verts, batched_faces = [], []
D, H, W = vol_batch.shape[1:]
for i in range(len(vol_batch)):
vol = vol_batch[i]
thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel
# pyre-fixme[16]: `_marching_cubes` has no attribute `apply`.
verts, faces = _marching_cubes.apply(vol, thresh)
if len(faces) > 0 and len(verts) > 0:
# Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to
# local coordinates in the range [-1, 1]
if return_local_coords:
verts = (
Translate(x=+1.0, y=+1.0, z=+1.0, device=vol.device)
.scale((vol.new_tensor([W, H, D])[None] - 1) * 0.5)
.inverse()
).transform_points(verts[None])[0]
batched_verts.append(verts)
batched_faces.append(faces)
else:
batched_verts.append([])
batched_faces.append([])
return batched_verts, batched_faces
...@@ -4,19 +4,24 @@ ...@@ -4,19 +4,24 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import itertools
from fvcore.common.benchmark import benchmark from fvcore.common.benchmark import benchmark
from tests.test_marching_cubes import TestMarchingCubes from tests.test_marching_cubes import TestMarchingCubes
def bm_marching_cubes() -> None: def bm_marching_cubes() -> None:
kwargs_list = [ case_grid = {
{"batch_size": 1, "V": 5}, "algo_type": [
{"batch_size": 1, "V": 10}, "naive",
{"batch_size": 1, "V": 20}, "cextension",
{"batch_size": 1, "V": 40}, ],
{"batch_size": 5, "V": 5}, "batch_size": [1, 5, 20],
{"batch_size": 20, "V": 20}, "V": [5, 10, 20],
] }
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
benchmark( benchmark(
TestMarchingCubes.marching_cubes_with_init, TestMarchingCubes.marching_cubes_with_init,
"MARCHING_CUBES", "MARCHING_CUBES",
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment