# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import pickle
import unittest

import torch
from pytorch3d.ops.marching_cubes import marching_cubes, marching_cubes_naive

from .common_testing import get_tests_dir, TestCaseMixin


# NOTE(review): USE_SCIKIT is not referenced in the visible part of this file;
# presumably it gates an optional comparison against scikit-image further down.
# Confirm against the rest of the file before changing it.
USE_SCIKIT = False
# Test fixture directory. It is built with the "/" operator, so get_tests_dir()
# presumably returns a pathlib.Path. TODO confirm.
DATA_DIR = get_tests_dir() / "data"


def convert_to_local(verts, volume_dim):
    """Rescale volume-index coordinates into the [-1, 1] local range.

    Index 0 maps to -1 and index ``volume_dim - 1`` maps to 1, with a linear
    ramp in between. ``verts`` may be a scalar or a tensor (applied elementwise).

    Args:
        verts: Coordinates expressed in voxel-index units.
        volume_dim: Number of samples along the axis (the grid size).

    Returns:
        ``verts`` mapped linearly into [-1, 1].
    """
    span = volume_dim - 1
    return verts * 2 / span - 1


class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
    """Marching cubes on a single 2x2x2 cube, one test per corner configuration.

    ``test_caseN`` reproduces cube vertex configuration N (0-indexed) from
    https://en.wikipedia.org/wiki/Marching_cubes#/media/File:MarchingCubes.svg

    Each volume starts as all ones with selected corners set to zero. The
    corners are indexed in (B, W, H, D) order and the volume is then permuted
    to (B, D, H, W) before being passed to the marching cubes functions.
    """

    # (label, function) pairs: every configuration is checked against both the
    # naive implementation and the C++ implementation. Kept in a tuple so the
    # functions are not turned into bound methods.
    _IMPLS = (("naive", marching_cubes_naive), ("C++", marching_cubes))

    def _assert_cube_mesh(
        self, volume_data, expected_verts, expected_faces, volume_dim=2
    ):
        """Check every implementation against the expected mesh of batch item 0.

        Each implementation runs once with ``return_local_coords=False``, where
        vertices are compared to ``expected_verts`` as given. It then runs with
        ``return_local_coords=True``, where vertices are compared to
        ``convert_to_local(expected_verts, volume_dim)``. Faces must match in
        both modes.

        Args:
            volume_data: Volume passed unchanged to each implementation.
            expected_verts: Expected vertices of the first batch element, in
                volume-index coordinates.
            expected_faces: Expected faces (vertex indices) of the first batch
                element.
            volume_dim: Grid size used to derive the expected local vertices.
        """
        expected_local_verts = convert_to_local(expected_verts, volume_dim)
        for use_local, want_verts in (
            (False, expected_verts),
            (True, expected_local_verts),
        ):
            for impl_name, impl in self._IMPLS:
                # subTest reports which implementation / mode failed.
                with self.subTest(impl=impl_name, return_local_coords=use_local):
                    verts, faces = impl(volume_data, return_local_coords=use_local)
                    self.assertClose(verts[0], want_verts)
                    self.assertClose(faces[0], expected_faces)

    def test_empty_volume(self):  # case 0
        # All corners share the same value, so the expected mesh is empty.
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        expected_verts = torch.tensor([[]])
        expected_faces = torch.tensor([[]], dtype=torch.int64)
        for impl_name, impl in self._IMPLS:
            with self.subTest(impl=impl_name):
                verts, faces = impl(volume_data, return_local_coords=False)
                # The whole batched output is compared, not just element 0.
                self.assertClose(verts, expected_verts)
                self.assertClose(faces, expected_faces)

    def test_case1(self):  # case 1
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [0.5, 0.0, 0.0],
                [0.0, 0.5, 0.0],
                [0.0, 0.0, 0.5],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case2(self):  # case 2
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0:2, 0, 0] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [1.0, 0.0, 0.5],
                [0.0, 0.5, 0.0],
                [0.0, 0.0, 0.5],
                [1.0, 0.5, 0.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case3(self):  # case 3
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data[0, 1, 1, 0] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [1.0, 0.5, 0.0],
                [1.0, 1.0, 0.5],
                [0.5, 1.0, 0.0],
                [0.5, 0.0, 0.0],
                [0.0, 0.5, 0.0],
                [0.0, 0.0, 0.5],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case4(self):  # case 4
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 1, 0, 0] = 0
        volume_data[0, 1, 0, 1] = 0
        volume_data[0, 0, 0, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [0.0, 0.0, 0.5],
                [1.0, 0.5, 0.0],
                [0.5, 0.0, 0.0],
                [0.0, 0.5, 1.0],
                [1.0, 0.5, 1.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 4, 1]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case5(self):  # case 5
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0:2, 0, 0:2] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [1.0, 0.5, 0.0],
                [0.0, 0.5, 0.0],
                [1.0, 0.5, 1.0],
                [0.0, 0.5, 1.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case6(self):  # case 6
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 1, 0, 0] = 0
        volume_data[0, 1, 0, 1] = 0
        volume_data[0, 0, 0, 1] = 0
        volume_data[0, 0, 1, 0] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [0.5, 1.0, 0.0],
                [0.0, 1.0, 0.5],
                [0.0, 0.5, 0.0],
                [1.0, 0.5, 0.0],
                [0.5, 0.0, 0.0],
                [0.0, 0.5, 1.0],
                [1.0, 0.5, 1.0],
                [0.0, 0.0, 0.5],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [3, 5, 6], [5, 4, 7]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case7(self):  # case 7
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data[0, 1, 0, 1] = 0
        volume_data[0, 1, 1, 0] = 0
        volume_data[0, 0, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [0.5, 1.0, 1.0],
                [0.0, 0.5, 1.0],
                [0.0, 1.0, 0.5],
                [1.0, 0.0, 0.5],
                [0.5, 0.0, 1.0],
                [1.0, 0.5, 1.0],
                [0.5, 0.0, 0.0],
                [0.0, 0.5, 0.0],
                [0.0, 0.0, 0.5],
                [0.5, 1.0, 0.0],
                [1.0, 0.5, 0.0],
                [1.0, 1.0, 0.5],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case8(self):  # case 8
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data[0, 0, 0, 1] = 0
        volume_data[0, 1, 0, 1] = 0
        volume_data[0, 0, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [1.0, 0.5, 1.0],
                [0.0, 1.0, 0.5],
                [0.5, 1.0, 1.0],
                [1.0, 0.0, 0.5],
                [0.0, 0.5, 0.0],
                [0.5, 0.0, 0.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0], [3, 4, 1], [3, 5, 4]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case9(self):  # case 9
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 1, 0, 0] = 0
        volume_data[0, 0, 0, 1] = 0
        volume_data[0, 1, 0, 1] = 0
        volume_data[0, 0, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [0.5, 0.0, 0.0],
                [0.0, 0.0, 0.5],
                [0.0, 1.0, 0.5],
                [1.0, 0.5, 1.0],
                [1.0, 0.5, 0.0],
                [0.5, 1.0, 1.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [5, 3, 2]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case10(self):  # case 10
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data[0, 1, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [0.5, 0.0, 0.0],
                [0.0, 0.5, 0.0],
                [0.0, 0.0, 0.5],
                [1.0, 1.0, 0.5],
                [1.0, 0.5, 1.0],
                [0.5, 1.0, 1.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case11(self):  # case 11
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data[0, 1, 0, 0] = 0
        volume_data[0, 1, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [1.0, 0.0, 0.5],
                [0.0, 0.5, 0.0],
                [0.0, 0.0, 0.5],
                [1.0, 0.5, 0.0],
                [1.0, 1.0, 0.5],
                [1.0, 0.5, 1.0],
                [0.5, 1.0, 1.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [4, 5, 6]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case12(self):  # case 12
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 1, 0, 0] = 0
        volume_data[0, 0, 1, 0] = 0
        volume_data[0, 1, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [1.0, 0.0, 0.5],
                [1.0, 0.5, 0.0],
                [0.5, 0.0, 0.0],
                [1.0, 1.0, 0.5],
                [1.0, 0.5, 1.0],
                [0.5, 1.0, 1.0],
                [0.0, 0.5, 0.0],
                [0.5, 1.0, 0.0],
                [0.0, 1.0, 0.5],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case13(self):  # case 13
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data[0, 0, 1, 0] = 0
        volume_data[0, 1, 0, 1] = 0
        volume_data[0, 1, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [1.0, 0.0, 0.5],
                [0.5, 0.0, 1.0],
                [1.0, 1.0, 0.5],
                [0.5, 1.0, 1.0],
                [0.0, 0.0, 0.5],
                [0.5, 0.0, 0.0],
                [0.5, 1.0, 0.0],
                [0.0, 1.0, 0.5],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3], [4, 5, 6], [4, 6, 7]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)

    def test_case14(self):  # case 14
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
        volume_data[0, 0, 0, 0] = 0
        volume_data[0, 0, 0, 1] = 0
        volume_data[0, 1, 0, 1] = 0
        volume_data[0, 1, 1, 1] = 0
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        expected_verts = torch.tensor(
            [
                [0.5, 0.0, 0.0],
                [0.0, 0.5, 0.0],
                [0.0, 0.5, 1.0],
                [1.0, 1.0, 0.5],
                [1.0, 0.0, 0.5],
                [0.5, 1.0, 1.0],
            ]
        )
        expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [3, 2, 5]])
        self._assert_cube_mesh(volume_data, expected_verts, expected_faces)


class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
    def test_single_point(self):
        """A single filled voxel in a 3x3x3 grid.

        The expected mesh has six vertices, each 0.5 from the voxel centre
        (1, 1, 1) along one axis, joined by eight triangles. Both
        implementations are checked in volume-index coordinates and in local
        coordinates; the local vertices must also lie within [-1, 1].
        """
        volume_data = torch.zeros(1, 3, 3, 3)  # (B, W, H, D)
        volume_data[0, 1, 1, 1] = 1
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)

        expected_verts = torch.tensor(
            [
                [1.0000, 0.5000, 1.0000],
                [1.0000, 1.0000, 0.5000],
                [0.5000, 1.0000, 1.0000],
                [1.5000, 1.0000, 1.0000],
                [1.0000, 1.5000, 1.0000],
                [1.0000, 1.0000, 1.5000],
            ]
        )
        expected_faces = torch.tensor(
            [
                [0, 1, 2],
                [1, 0, 3],
                [1, 4, 2],
                [1, 3, 4],
                [0, 2, 5],
                [3, 0, 5],
                [2, 4, 5],
                [3, 5, 4],
            ]
        )
        # The grid has 3 samples per axis.
        expected_local_verts = convert_to_local(expected_verts, 3)

        impls = (("naive", marching_cubes_naive), ("C++", marching_cubes))
        for impl_name, impl in impls:
            with self.subTest(impl=impl_name, return_local_coords=False):
                verts, faces = impl(volume_data, return_local_coords=False)
                self.assertClose(verts[0], expected_verts)
                self.assertClose(faces[0], expected_faces)
            with self.subTest(impl=impl_name, return_local_coords=True):
                verts, faces = impl(volume_data, return_local_coords=True)
                self.assertClose(verts[0], expected_local_verts)
                self.assertClose(faces[0], expected_faces)
                # Local coordinates must stay inside the [-1, 1] cube.
                self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())

    def test_cube(self):
        """A 2x2x2 block of filled voxels in a 5x5x5 grid, thresholded at 0.9.

        The expected vertices sit at 0.9 or 2.1 along the crossing axis. That
        is 0.1 outside the filled voxel indices 1 and 2, consistent with
        interpolating between an empty (0) and a filled (1) sample up to 0.9.
        Both implementations are checked in volume-index coordinates and in
        local coordinates; the local vertices must also lie within [-1, 1].
        """
        volume_data = torch.zeros(1, 5, 5, 5)  # (B, W, H, D)
        # Fill every voxel with indices in {1, 2} along each spatial axis.
        volume_data[0, 1:3, 1:3, 1:3] = 1
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)

        expected_verts = torch.tensor(
            [
                [1.0000, 0.9000, 1.0000],
                [1.0000, 1.0000, 0.9000],
                [0.9000, 1.0000, 1.0000],
                [2.0000, 0.9000, 1.0000],
                [2.0000, 1.0000, 0.9000],
                [2.1000, 1.0000, 1.0000],
                [1.0000, 2.0000, 0.9000],
                [0.9000, 2.0000, 1.0000],
                [2.0000, 2.0000, 0.9000],
                [2.1000, 2.0000, 1.0000],
                [1.0000, 2.1000, 1.0000],
                [2.0000, 2.1000, 1.0000],
                [1.0000, 0.9000, 2.0000],
                [0.9000, 1.0000, 2.0000],
                [2.0000, 0.9000, 2.0000],
                [2.1000, 1.0000, 2.0000],
                [0.9000, 2.0000, 2.0000],
                [2.1000, 2.0000, 2.0000],
                [1.0000, 2.1000, 2.0000],
                [2.0000, 2.1000, 2.0000],
                [1.0000, 1.0000, 2.1000],
                [2.0000, 1.0000, 2.1000],
                [1.0000, 2.0000, 2.1000],
                [2.0000, 2.0000, 2.1000],
            ]
        )

        expected_faces = torch.tensor(
            [
                [0, 1, 2],
                [0, 3, 4],
                [1, 0, 4],
                [4, 3, 5],
                [1, 6, 7],
                [2, 1, 7],
                [4, 8, 1],
                [1, 8, 6],
                [8, 4, 5],
                [9, 8, 5],
                [6, 10, 7],
                [6, 8, 11],
                [10, 6, 11],
                [8, 9, 11],
                [12, 0, 2],
                [13, 12, 2],
                [3, 0, 14],
                [14, 0, 12],
                [15, 5, 3],
                [14, 15, 3],
                [2, 7, 13],
                [7, 16, 13],
                [5, 15, 9],
                [9, 15, 17],
                [10, 18, 16],
                [7, 10, 16],
                [11, 19, 10],
                [19, 18, 10],
                [9, 17, 19],
                [11, 9, 19],
                [12, 13, 20],
                [14, 12, 20],
                [21, 14, 20],
                [15, 14, 21],
                [13, 16, 22],
                [20, 13, 22],
                [21, 20, 23],
                [20, 22, 23],
                [17, 15, 21],
                [23, 17, 21],
                [16, 18, 22],
                [23, 22, 18],
                [19, 23, 18],
                [17, 23, 19],
            ]
        )
        # The grid has 5 samples per axis.
        expected_local_verts = convert_to_local(expected_verts, 5)

        impls = (("naive", marching_cubes_naive), ("C++", marching_cubes))
        for impl_name, impl in impls:
            with self.subTest(impl=impl_name, return_local_coords=False):
                verts, faces = impl(volume_data, 0.9, return_local_coords=False)
                self.assertClose(verts[0], expected_verts)
                self.assertClose(faces[0], expected_faces)
            with self.subTest(impl=impl_name, return_local_coords=True):
                verts, faces = impl(volume_data, 0.9, return_local_coords=True)
                self.assertClose(verts[0], expected_local_verts)
                self.assertClose(faces[0], expected_faces)
                # Check all values are in the range [-1, 1]
                self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())

    def test_cube_no_duplicate_verts(self):
        """Naive marching cubes on a solid 2x2x2 block of ones at isolevel 1."""
        volume_data = torch.zeros(1, 5, 5, 5)  # (B, W, H, D)
        # Set the eight voxels whose indices are all in {1, 2} to one.
        volume_data[0, 1:3, 1:3, 1:3] = 1
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
        verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=False)

        # NOTE(review): despite the test name, the expected output below lists
        # 24 vertices for the 8 distinct corners of the block (positions such
        # as [2, 1, 1] repeat) -- confirm whether the name or the data reflects
        # the intended behaviour of marching_cubes_naive.
        expected_verts = torch.tensor(
            [
                [2.0, 1.0, 1.0],
                [2.0, 2.0, 1.0],
                [1.0, 1.0, 1.0],
                [1.0, 2.0, 1.0],
                [2.0, 1.0, 1.0],
                [1.0, 1.0, 1.0],
                [2.0, 1.0, 2.0],
                [1.0, 1.0, 2.0],
                [1.0, 1.0, 1.0],
                [1.0, 2.0, 1.0],
                [1.0, 1.0, 2.0],
                [1.0, 2.0, 2.0],
                [2.0, 1.0, 1.0],
                [2.0, 1.0, 2.0],
                [2.0, 2.0, 1.0],
                [2.0, 2.0, 2.0],
                [2.0, 2.0, 1.0],
                [2.0, 2.0, 2.0],
                [1.0, 2.0, 1.0],
                [1.0, 2.0, 2.0],
                [2.0, 1.0, 2.0],
                [1.0, 1.0, 2.0],
                [2.0, 2.0, 2.0],
                [1.0, 2.0, 2.0],
            ]
        )

        expected_faces = torch.tensor(
            [
                [0, 1, 2],
                [2, 1, 3],
                [4, 5, 6],
                [6, 5, 7],
                [8, 9, 10],
                [9, 11, 10],
                [12, 13, 14],
                [14, 13, 15],
                [16, 17, 18],
                [17, 19, 18],
                [20, 21, 22],
                [21, 23, 22],
            ]
        )
        self.assertClose(verts[0], expected_verts)
        self.assertClose(faces[0], expected_faces)

        verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=True)
        expected_verts = convert_to_local(expected_verts, 5)
        self.assertClose(verts[0], expected_verts)
        self.assertClose(faces[0], expected_faces)

        # Check all values are in the range [-1, 1]
        self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())

    def test_sphere(self):
        """Both implementations reproduce the stored mesh of a radius-8 sphere."""
        # (B, W, H, D): volume[0, x, y, z] is the squared distance to (10, 10, 10).
        coords = torch.arange(20)
        X, Y, Z = torch.meshgrid(coords, coords, coords, indexing="ij")
        volume = ((X - 10) ** 2 + (Y - 10) ** 2 + (Z - 10) ** 2).float().unsqueeze(0)
        volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)

        data_filename = "test_marching_cubes_data/sphere_level64.pickle"
        with open(DATA_DIR / data_filename, "rb") as file:
            verts_and_faces = pickle.load(file)
        expected_verts = verts_and_faces["verts"]
        expected_faces = verts_and_faces["faces"]
        expected_local_verts = convert_to_local(expected_verts, 20)

        # The naive Python and the C++ implementations must agree with the
        # stored reference in both world and local coordinates.
        for algo in (marching_cubes_naive, marching_cubes):
            with self.subTest(algo=algo.__name__):
                verts, faces = algo(volume, 64, return_local_coords=False)
                self.assertClose(verts[0], expected_verts)
                self.assertClose(faces[0], expected_faces)

                verts, faces = algo(volume, 64, return_local_coords=True)
                self.assertClose(verts[0], expected_local_verts)
                self.assertClose(faces[0], expected_faces)

                # Check all values are in the range [-1, 1]
                self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())

    # Uses skimage.draw.ellipsoid; skipped (rather than silently passing)
    # when USE_SCIKIT is disabled.
    @unittest.skipUnless(USE_SCIKIT, "requires scikit-image (USE_SCIKIT)")
    def test_double_ellipsoid(self):
        """Both implementations reproduce the stored double-ellipsoid mesh."""
        import numpy as np
        from skimage.draw import ellipsoid

        # Two overlapping copies of the same ellipsoid level set, stacked on axis 0.
        ellip_base = ellipsoid(6, 10, 16, levelset=True)
        ellip_double = np.concatenate(
            (ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0
        )
        volume = torch.Tensor(ellip_double).unsqueeze(0)
        volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)

        data_filename = "test_marching_cubes_data/double_ellipsoid.pickle"
        with open(DATA_DIR / data_filename, "rb") as file:
            verts_and_faces = pickle.load(file)
        expected_verts = verts_and_faces["verts"]
        expected_faces = verts_and_faces["faces"]

        for algo in (marching_cubes_naive, marching_cubes):
            with self.subTest(algo=algo.__name__):
                verts, faces = algo(volume, isolevel=0.001)
                self.assertClose(verts[0], expected_verts)
                self.assertClose(faces[0], expected_faces)

    @unittest.skipUnless(USE_SCIKIT, "requires scikit-image (USE_SCIKIT)")
    def test_cube_surface_area(self):
        """Mesh area of a 2x2x2 block matches scikit-image and the C++ version."""
        # NOTE(review): marching_cubes_classic is not present in newer
        # scikit-image releases (superseded by marching_cubes(method="lorensen"));
        # confirm the installed version before enabling USE_SCIKIT.
        from skimage.measure import marching_cubes_classic, mesh_surface_area

        volume_data = torch.zeros(1, 5, 5, 5)
        # Set the eight voxels whose indices are all in {1, 2} to one.
        volume_data[0, 1:3, 1:3, 1:3] = 1
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)

        # World (voxel-index) coordinates, the same units scikit-image reports.
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
        verts_c, faces_c = marching_cubes(volume_data, return_local_coords=False)
        verts_sci, faces_sci = marching_cubes_classic(volume_data[0].numpy())

        surf = mesh_surface_area(verts[0].numpy(), faces[0].numpy())
        surf_c = mesh_surface_area(verts_c[0].numpy(), faces_c[0].numpy())
        surf_sci = mesh_surface_area(verts_sci, faces_sci)

        self.assertClose(surf, surf_sci)
        self.assertClose(surf, surf_c)

    @unittest.skipUnless(USE_SCIKIT, "requires scikit-image (USE_SCIKIT)")
    def test_sphere_surface_area(self):
        """Mesh area of a sphere matches scikit-image and the C++ version."""
        # NOTE(review): marching_cubes_classic is not present in newer
        # scikit-image releases (superseded by marching_cubes(method="lorensen"));
        # confirm the installed version before enabling USE_SCIKIT.
        from skimage.measure import marching_cubes_classic, mesh_surface_area

        # (B, W, H, D): volume[0, x, y, z] is the squared distance to (10, 10, 10).
        coords = torch.arange(20)
        X, Y, Z = torch.meshgrid(coords, coords, coords, indexing="ij")
        volume = ((X - 10) ** 2 + (Y - 10) ** 2 + (Z - 10) ** 2).float().unsqueeze(0)
        volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)

        # Request voxel-index coordinates so the areas are in the same units
        # as scikit-image's; local coordinates would rescale the mesh.
        verts, faces = marching_cubes_naive(
            volume, isolevel=64, return_local_coords=False
        )
        verts_c, faces_c = marching_cubes(
            volume, isolevel=64, return_local_coords=False
        )
        verts_sci, faces_sci = marching_cubes_classic(volume[0].numpy(), level=64)

        surf = mesh_surface_area(verts[0].numpy(), faces[0].numpy())
        surf_c = mesh_surface_area(verts_c[0].numpy(), faces_c[0].numpy())
        surf_sci = mesh_surface_area(verts_sci, faces_sci)

        self.assertClose(surf, surf_sci)
        self.assertClose(surf, surf_c)

    @unittest.skipUnless(USE_SCIKIT, "requires scikit-image (USE_SCIKIT)")
    def test_double_ellipsoid_surface_area(self):
        """Mesh area of a double ellipsoid matches scikit-image and C++."""
        import numpy as np
        from skimage.draw import ellipsoid

        # NOTE(review): marching_cubes_classic is not present in newer
        # scikit-image releases (superseded by marching_cubes(method="lorensen"));
        # confirm the installed version before enabling USE_SCIKIT.
        from skimage.measure import marching_cubes_classic, mesh_surface_area

        # Two overlapping copies of the same ellipsoid level set, stacked on axis 0.
        ellip_base = ellipsoid(6, 10, 16, levelset=True)
        ellip_double = np.concatenate(
            (ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0
        )
        volume = torch.Tensor(ellip_double).unsqueeze(0)
        volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)

        # Request voxel-index coordinates so the areas are in the same units
        # as scikit-image's; local coordinates would rescale each axis by its
        # own extent, and this volume is not cubic.
        verts, faces = marching_cubes_naive(
            volume, isolevel=0, return_local_coords=False
        )
        verts_c, faces_c = marching_cubes(volume, isolevel=0, return_local_coords=False)
        verts_sci, faces_sci = marching_cubes_classic(volume[0].numpy(), level=0)

        surf = mesh_surface_area(verts[0].numpy(), faces[0].numpy())
        surf_c = mesh_surface_area(verts_c[0].numpy(), faces_c[0].numpy())
        surf_sci = mesh_surface_area(verts_sci, faces_sci)

        self.assertClose(surf, surf_sci)
        self.assertClose(surf, surf_c)

    def test_ball_example(self):
        """Naive and C++ implementations agree on a partially sampled ball."""
        N = 15
        axis_tensor = torch.arange(0, N)
        X, Y, Z = torch.meshgrid(axis_tensor, axis_tensor, axis_tensor, indexing="ij")
        # Negative inside a radius-8 ball centred at (15, 15, 15). The centre
        # lies just outside the 0..14 grid, so only part of the ball is sampled.
        u = (X - 15) ** 2 + (Y - 15) ** 2 + (Z - 15) ** 2 - 8**2
        u = u[None].float()
        verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
        verts2, faces2 = marching_cubes(u, 0, return_local_coords=False)

        # The field changes sign inside the grid, so a surface must exist; this
        # stops the comparison below from passing on two empty meshes.
        self.assertGreater(faces[0].shape[0], 0)
        self.assertClose(verts[0], verts2[0])
        self.assertClose(faces[0], faces2[0])

    @staticmethod
    def marching_cubes_with_init(
        algo_type: str, batch_size: int, V: int, device: str = "cuda:0"
    ):
        """
        Build a zero-argument closure that runs one marching cubes
        implementation on a fixed random volume (used for benchmarking).

        Args:
            algo_type: "naive" selects marching_cubes_naive, "cextension"
                selects marching_cubes.
            batch_size: number of volumes in the batch.
            V: side length of each (V, V, V) volume.
            device: device on which the random volume is allocated. Defaults
                to "cuda:0", the previously hard-coded device.

        Returns:
            A callable that runs the selected implementation once and, when
            the volume lives on a CUDA device, waits for that device to finish.

        Raises:
            KeyError: if algo_type is not one of the supported names (raised
                when the closure is built rather than on its first call).
        """
        device = torch.device(device)
        volume_data = torch.rand(
            (batch_size, V, V, V), dtype=torch.float32, device=device
        )
        algo_table = {
            "naive": marching_cubes_naive,
            "cextension": marching_cubes,
        }
        # Resolve eagerly so a bad algo_type fails at setup, not mid-benchmark.
        algo = algo_table[algo_type]

        def convert():
            algo(volume_data, return_local_coords=False)
            # CUDA work is asynchronous; synchronise so timings include it.
            # Skipped on CPU, where torch.cuda.synchronize would fail.
            if volume_data.is_cuda:
                torch.cuda.synchronize(volume_data.device)

        return convert