icosahedral_mesh_test.py 4.79 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for icosahedral_mesh."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
from graphcast import icosahedral_mesh
import numpy as np


def _get_mesh_spec(splits: int):
  """Returns size of the final icosahedral mesh resulting from the splitting."""
  num_vertices = 12
  num_faces = 20
  for _ in range(splits):
    # Each previous face adds three new vertices, but each vertex is shared
    # by two faces.
    num_vertices += num_faces * 3 // 2
    num_faces *= 4
  return num_vertices, num_faces


class IcosahedralMeshTest(parameterized.TestCase):
  """Tests for the mesh construction helpers in `icosahedral_mesh`."""

  def test_icosahedron(self):
    """The base icosahedron has 12 vertices and 20 valid faces."""
    mesh = icosahedral_mesh.get_icosahedron()
    _assert_valid_mesh(
        mesh, num_expected_vertices=12, num_expected_faces=20)

  @parameterized.parameters(list(range(5)))
  def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
    """Each level of the hierarchy is valid and extends the previous level."""
    meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
        splits=splits)
    # Mesh `i` is checked below against the spec for `i` splits, so the
    # hierarchy must contain exactly one level per split count 0..splits.
    # Without this check, a truncated hierarchy would pass silently.
    self.assertLen(meshes, splits + 1)

    prev_vertices = None
    for mesh_i, mesh in enumerate(meshes):
      # Check that `mesh` is valid.
      num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
      _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)

      # Check that the first N vertices from this mesh match all of the
      # vertices from the previous mesh.
      if prev_vertices is not None:
        leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
        np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)

      # Remember this level's vertices for comparison with the next level.
      prev_vertices = mesh.vertices

  @parameterized.parameters(list(range(4)))
  def test_merge_meshes(self, splits):
    """Merging keeps the last mesh's vertices and concatenates all faces."""
    mesh_hierarchy = (
        icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
            splits=splits))
    mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)

    # Faces of every level, in hierarchy order.
    expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
    np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
    np.testing.assert_array_equal(mesh.faces, expected_faces)

  def test_faces_to_edges(self):
    """Each face (a, b, c) yields the directed edges a->b, b->c and c->a."""

    faces = np.array([[0, 1, 2],
                      [3, 4, 5]])

    # This also documents the order of the edges returned by the method:
    # the first edge of every face, then the second edge of every face, and
    # so on.
    expected_edges = np.array(
        [[0, 1],
         [3, 4],
         [1, 2],
         [4, 5],
         [2, 0],
         [5, 3]])
    expected_senders = expected_edges[:, 0]
    expected_receivers = expected_edges[:, 1]

    senders, receivers = icosahedral_mesh.faces_to_edges(faces)

    np.testing.assert_array_equal(senders, expected_senders)
    np.testing.assert_array_equal(receivers, expected_receivers)


def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
  """Asserts that `mesh` has the expected sizes and lies on the unit sphere.

  Checks that:
    * `mesh.vertices` has shape [num_expected_vertices, 3];
    * `mesh.faces` has shape [num_expected_faces, 3];
    * every vertex has Euclidean norm 1 (relative tolerance 1e-6);
    * every face is positively (outward) oriented.

  Args:
    mesh: object exposing `vertices` and `faces` arrays.
    num_expected_vertices: required number of vertex rows.
    num_expected_faces: required number of face rows.
  """
  verts, tris = mesh.vertices, mesh.faces

  # Shape checks come first so size mismatches are reported clearly.
  chex.assert_shape(verts, [num_expected_vertices, 3])
  chex.assert_shape(tris, [num_expected_faces, 3])

  # Every vertex is expected to sit on the unit sphere.
  radii = np.linalg.norm(verts, axis=-1)
  np.testing.assert_allclose(radii, 1., rtol=1e-6)

  _assert_positive_face_orientation(verts, tris)


def _assert_positive_face_orientation(vertices, faces):

  # Obtain a unit vector that points, in the direction of the face.
  face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
                              vertices[faces[:, 2]] - vertices[faces[:, 1]])
  face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)

  # And a unit vector pointing from the origin to the center of the face.
  face_centers = vertices[faces].mean(1)
  face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)

  # Positive orientation means those two vectors should be parallel
  # (dot product, 1), and not anti-parallel (dot product, -1).
  dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)

  # Check that the face normal is parallel to the vector that joins the center
  # of the face to the center of the sphere. Note we need a small tolerance
  # because some discretizations are not exactly uniform, so it will not be
  # exactly parallel.
  np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)


if __name__ == "__main__":
  # Run the tests in this module when the file is executed as a script.
  absltest.main()