test_base.py 2.28 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
from textwrap import indent

from veros.variables import remove_ghosts
from veros.pyom_compat import state_from_pyom, VEROS_TO_PYOM_SETTING, VEROS_TO_PYOM_VAR


def _normalize(*arrays):
    if any(a.size == 0 for a in arrays):
        return arrays

    norm = np.nanmax(np.abs(arrays[0]))

    if norm == 0.0:
        return arrays

    return tuple(a / norm for a in arrays)


def compare_state(
    vs_state,
    pyom_obj,
    atol=1e-10,
    rtol=1e-8,
    include_ghosts=False,
    allowed_failures=None,
    normalize=False,
):
    """Assert that a Veros state matches the state held by a PyOM object.

    ``pyom_obj`` is converted with ``state_from_pyom``. Every Veros setting
    and variable is then compared against its PyOM counterpart. Names are
    translated through ``VEROS_TO_PYOM_SETTING`` / ``VEROS_TO_PYOM_VAR``.
    Entries that map to ``None``, settings in ``IGNORE_SETTINGS``, and
    variables missing from the PyOM state are skipped.

    Each mismatch is printed under its Veros name, so that all differences
    are reported in one run. A single ``AssertionError`` is raised at the end.

    Args:
        vs_state: Veros state to check.
        pyom_obj: PyOM object to compare against.
        atol: absolute tolerance for ``np.testing.assert_allclose``.
        rtol: relative tolerance for ``np.testing.assert_allclose``.
        include_ghosts: if False, ghost cells are stripped with
            ``remove_ghosts`` before comparing arrays.
        allowed_failures: Veros names whose mismatches are tolerated
            (neither printed nor counted as failures).
        normalize: if True, both arrays are divided by the largest absolute
            value of the Veros array before comparing (see ``_normalize``).

    Raises:
        AssertionError: if any setting or variable not listed in
            ``allowed_failures`` differs.
    """
    IGNORE_SETTINGS = ("congr_max_iterations",)

    if allowed_failures is None:
        allowed_failures = []

    pyom_state = state_from_pyom(pyom_obj)

    def assert_setting(setting):
        vs_val = vs_state.settings.get(setting)
        setting = VEROS_TO_PYOM_SETTING.get(setting, setting)
        if setting is None or setting in IGNORE_SETTINGS:
            return

        pyom_val = pyom_state.settings.get(setting)
        assert vs_val == pyom_val, (vs_val, pyom_val)

    def assert_var(vs_name):
        # Keep the Veros and PyOM names separate: each state's metadata must
        # be looked up with the name used in that state.
        vs_val = vs_state.variables.get(vs_name)

        pyom_name = VEROS_TO_PYOM_VAR.get(vs_name, vs_name)
        if pyom_name is None or pyom_name not in pyom_state.variables:
            return

        pyom_val = pyom_state.variables.get(pyom_name)

        if pyom_name in ("tau", "taup1", "taum1"):
            # PyOM's time-level indices are expected to be one greater than
            # Veros' (presumably Fortran-style 1-based indexing).
            assert pyom_val == vs_val + 1, (vs_val, pyom_val)
            return

        if not include_ghosts:
            vs_val = remove_ghosts(vs_val, vs_state.var_meta[vs_name].dims)
            pyom_val = remove_ghosts(pyom_val, pyom_state.var_meta[pyom_name].dims)

        if normalize:
            vs_val, pyom_val = _normalize(vs_val, pyom_val)

        np.testing.assert_allclose(vs_val, pyom_val, atol=atol, rtol=rtol)

    def check_all(names, checker):
        # Run every check (do not stop at the first failure) and report each
        # non-allowed mismatch; return True only if none occurred.
        ok = True
        for name in names:
            try:
                checker(name)
            except AssertionError as exc:
                if name not in allowed_failures:
                    print(f"{name}:{indent(str(exc), ' ' * 4)}")
                    ok = False
        return ok

    # Evaluate both groups unconditionally so all mismatches get printed.
    settings_ok = check_all(vs_state.settings.fields(), assert_setting)
    variables_ok = check_all(vs_state.variables.fields(), assert_var)

    assert settings_ok and variables_ok, "Veros and PyOM states differ (see mismatches printed above)"