signed_distance.py 3.03 KB
Newer Older
ashawkey's avatar
ashawkey committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
import trimesh
import argparse
import torch
import cubvh
import time


def create_dodecahedron(radius=1, center=np.array([0, 0, 0])):
    """Build a triangulated dodecahedron mesh.

    Every vertex of the hard-coded table is rescaled so that it lies at
    distance ``radius`` from the origin, then shifted by ``center``.

    Args:
        radius: distance from the (un-shifted) origin to every vertex.
        center: offset added to every vertex; broadcast against the
            (20, 3) vertex array.

    Returns:
        A ``trimesh.Trimesh`` constructed from 20 vertices and 36
        triangular faces.
    """
    # One (x, y, z) row per vertex. The coordinates are rounded, so they are
    # re-normalised below before the radius is applied.
    corner_table = np.array([
        [-0.57735, -0.57735, 0.57735],
        [0.934172, 0.356822, 0.0],
        [0.934172, -0.356822, 0.0],
        [-0.934172, 0.356822, 0.0],
        [-0.934172, -0.356822, 0.0],
        [0.0, 0.934172, 0.356822],
        [0.0, 0.934172, -0.356822],
        [0.356822, 0.0, -0.934172],
        [-0.356822, 0.0, -0.934172],
        [0.0, -0.934172, -0.356822],
        [0.0, -0.934172, 0.356822],
        [0.356822, 0.0, 0.934172],
        [-0.356822, 0.0, 0.934172],
        [0.57735, 0.57735, -0.57735],
        [0.57735, 0.57735, 0.57735],
        [-0.57735, 0.57735, -0.57735],
        [-0.57735, 0.57735, 0.57735],
        [0.57735, -0.57735, -0.57735],
        [0.57735, -0.57735, 0.57735],
        [-0.57735, -0.57735, -0.57735],
    ])

    # One row per triangle, written with 1-based vertex indices.
    one_based_triangles = np.array([
        [19, 3, 2],
        [12, 19, 2],
        [15, 12, 2],
        [8, 14, 2],
        [18, 8, 2],
        [3, 18, 2],
        [20, 5, 4],
        [9, 20, 4],
        [16, 9, 4],
        [13, 17, 4],
        [1, 13, 4],
        [5, 1, 4],
        [7, 16, 4],
        [6, 7, 4],
        [17, 6, 4],
        [6, 15, 2],
        [7, 6, 2],
        [14, 7, 2],
        [10, 18, 3],
        [11, 10, 3],
        [19, 11, 3],
        [11, 1, 5],
        [10, 11, 5],
        [20, 10, 5],
        [20, 9, 8],
        [10, 20, 8],
        [18, 10, 8],
        [9, 16, 7],
        [8, 9, 7],
        [14, 8, 7],
        [12, 15, 6],
        [13, 12, 6],
        [17, 13, 6],
        [13, 1, 11],
        [12, 13, 11],
        [19, 12, 11],
    ])
    # Convert to the 0-based indexing used for array lookups.
    triangles = one_based_triangles - 1

    # Per-vertex Euclidean length, kept as a column so it broadcasts row-wise.
    norms = np.linalg.norm(corner_table, axis=1, keepdims=True)
    unit_corners = corner_table / norms
    placed_corners = unit_corners * radius + center

    return trimesh.Trimesh(vertices=placed_corners, faces=triangles)


if __name__ == '__main__':
    # Benchmark cuBVH signed-distance queries against trimesh on the same
    # random points, then check that both produce the same distances.
    parser = argparse.ArgumentParser()
    parser.add_argument('--N', default=1000, type=int,
                        help='number of random query points')
    parser.add_argument('--mesh', default='', type=str,
                        help='path to a mesh file; the built-in dodecahedron '
                             'is used when empty')

    opt = parser.parse_args()

    if not opt.mesh:
        mesh = create_dodecahedron()
    else:
        mesh = trimesh.load(opt.mesh, force='mesh', skip_material=True)

    # query nearest triangles for a set of points
    points = torch.randn(opt.N, 3, device='cuda', dtype=torch.float32)
    # CUDA launches are asynchronous: wait for the point generation to finish
    # so that it is not attributed to the BVH build timing below.
    torch.cuda.synchronize()

    # Ours
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which follows the (adjustable) wall clock.
    _t0 = time.perf_counter()
    BVH = cubvh.cuBVH(mesh.vertices, mesh.faces)
    torch.cuda.synchronize()
    _t1 = time.perf_counter()
    # Only the distances are compared; the other two returned values are unused.
    distances, _, _ = BVH.signed_distance(points, return_uvw=False, mode='raystab')
    torch.cuda.synchronize()
    _t2 = time.perf_counter()
    print(f'[TIME] Ours total {_t2 - _t0:.6f}s = build {_t1 - _t0:.6f}s + query {_t2 - _t1:.6f}s')

    # GT results by trimesh
    # The trimesh result is negated before comparison with the cuBVH output
    # (presumably the two libraries use opposite sign conventions).
    _t0 = time.perf_counter()
    gt_distances = -trimesh.proximity.signed_distance(mesh, points.cpu().numpy())
    _t1 = time.perf_counter()
    print(f'[TIME] Trimesh total {_t1 - _t0:.6f}s')

    # verify correctness
    # Compare in float32 on both sides; raises AssertionError on mismatch.
    distances = distances.cpu().numpy().astype(np.float32)
    np.testing.assert_allclose(
        distances,
        gt_distances.astype(np.float32),
        atol=1e-5
    )