test_util.py 3.11 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared test utilities."""

from absl.testing import parameterized

from jax import config
from jax_cfd.base import grids
import numpy as np

# Executed at import time, so every test module importing this file triggers
# it. NOTE(review): presumably this lets JAX's flags be supplied on the absl
# command line of the test binary -- confirm against the jax.config docs.
config.parse_flags_with_absl()


class TestCase(parameterized.TestCase):
  """TestCase whose array assertions also accept GridArray / GridVariable."""

  def _check_and_remove_alignment_and_grid(self, *arrays):
    """Validate grid metadata across `arrays` and return only their data.

    If any argument is a `grids.GridArray`, all of them must be, and they must
    pass `grids.consistent_offset` and `grids.consistent_grid`.
    If any argument is a `grids.GridVariable`, all of them must be, and they
    must additionally pass `grids.unique_boundary_conditions`.
    Arguments that are neither are handed back untouched.

    Args:
      *arrays: one or more arrays, GridArrays or GridVariables; they must all
        be of the same kind.

    Returns:
      A tuple holding the bare data of every input: `.data` for a GridArray,
      `.array.data` for a GridVariable, the input itself otherwise.

    Raises:
      AssertionError: if grid-aware and other inputs are mixed, or if one of
        the consistency checks fails (its message is reused).
    """
    # Case 1: GridArray inputs.
    gridarray_flags = [isinstance(item, grids.GridArray) for item in arrays]
    if any(gridarray_flags):
      self.assertTrue(
          all(gridarray_flags), msg=f'arrays have mixed types: {arrays}')
      self._assert_offsets_and_grids_agree(arrays)
      arrays = tuple(item.data for item in arrays)

    # Case 2: GridVariable inputs. This is evaluated on the (possibly already
    # unwrapped) tuple from case 1, exactly in that order.
    gridvariable_flags = [
        isinstance(item, grids.GridVariable) for item in arrays
    ]
    if any(gridvariable_flags):
      self.assertTrue(
          all(gridvariable_flags), msg=f'arrays have mixed types: {arrays}')
      self._assert_offsets_and_grids_agree(arrays)
      try:
        grids.unique_boundary_conditions(*arrays)
      except grids.InconsistentBoundaryConditionsError as err:
        raise AssertionError(str(err)) from None
      arrays = tuple(item.array.data for item in arrays)

    return arrays

  def _assert_offsets_and_grids_agree(self, arrays):
    """Turn offset / grid inconsistencies among `arrays` into test failures.

    Args:
      arrays: tuple of grid-aware values accepted by `grids.consistent_offset`
        and `grids.consistent_grid`.

    Raises:
      AssertionError: carrying the message of the underlying
        `grids.InconsistentOffsetError` or `grids.InconsistentGridError`.
    """
    # `from None` drops the chained library exception so that the test report
    # shows a plain assertion failure.
    try:
      grids.consistent_offset(*arrays)
    except grids.InconsistentOffsetError as err:
      raise AssertionError(str(err)) from None
    try:
      grids.consistent_grid(*arrays)
    except grids.InconsistentGridError as err:
      raise AssertionError(str(err)) from None

  def assertArrayEqual(self, expected, actual, **kwargs):
    """Assert that `expected` and `actual` hold exactly equal data.

    Grid metadata (if any) is validated first and then stripped.

    Args:
      expected: reference array, GridArray or GridVariable.
      actual: value under test, of the same kind as `expected`.
      **kwargs: forwarded to `np.testing.assert_array_equal`.
    """
    stripped = self._check_and_remove_alignment_and_grid(expected, actual)
    np.testing.assert_array_equal(*stripped, **kwargs)

  def assertAllClose(self, expected, actual, **kwargs):
    """Assert that `expected` and `actual` hold numerically close data.

    Grid metadata (if any) is validated first and then stripped.

    Args:
      expected: reference array, GridArray or GridVariable.
      actual: value under test, of the same kind as `expected`.
      **kwargs: forwarded to `np.testing.assert_allclose` (e.g. rtol, atol).
    """
    # NOTE(review): `expected` is handed to numpy as its first positional
    # argument, which numpy names `actual`; the relative tolerance is therefore
    # scaled by this method's `actual` -- confirm this ordering is intended.
    stripped = self._check_and_remove_alignment_and_grid(expected, actual)
    np.testing.assert_allclose(*stripped, **kwargs)