tridiag_test.py 2.14 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import pytest
import numpy as np

from veros import runtime_settings
from veros.pyom_compat import load_pyom


@pytest.mark.skipif(runtime_settings.backend != "jax", reason="Must use JAX backend")
@pytest.mark.parametrize("use_ext", [True, False])
def test_solve_tridiag_jax(pyom2_lib, use_ext):
    """Compare ``solve_tridiagonal_jax`` against the PyOM2 reference solver.

    The test runs once with ``runtime_settings.use_special_tdma`` enabled and
    once with it disabled (``use_ext``). The global setting is restored
    afterwards, so the test does not leak state into tests that run later.
    Inputs are drawn from a fixed seed, so any mismatch is reproducible.
    """
    from veros.core.operators import solve_tridiagonal_jax
    from veros.core.utilities import create_water_masks

    pyom_obj = load_pyom(pyom2_lib)

    nx, ny, nz = 70, 60, 50
    # Fixed seed: an unseeded draw would make a tolerance failure impossible to reproduce.
    rng = np.random.default_rng(17)
    a, b, c, d = (rng.standard_normal((nx, ny, nz)) for _ in range(4))
    # kbot == 0 marks a column that is skipped below. Otherwise the column
    # solved is k = kbot - 1 .. nz - 1.
    kbot = rng.integers(0, nz, size=(nx, ny))

    # Reference result: solve every active column independently with PyOM2.
    out_pyom = np.zeros((nx, ny, nz))
    ke = nz
    for i, j in np.ndindex(nx, ny):
        ks = kbot[i, j] - 1
        if ks < 0:
            continue

        out_pyom[i, j, ks:ke] = pyom_obj.solve_tridiag(
            a=a[i, j, ks:ke], b=b[i, j, ks:ke], c=c[i, j, ks:ke], d=d[i, j, ks:ke], n=ke - ks
        )

    _, water_mask, edge_mask = create_water_masks(kbot, nz)

    previous_use_ext = runtime_settings.use_special_tdma
    object.__setattr__(runtime_settings, "use_special_tdma", use_ext)
    try:
        out_vs = solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask)
    finally:
        # runtime_settings is module-global: put the original value back so
        # subsequent tests are not affected by this parametrization.
        object.__setattr__(runtime_settings, "use_special_tdma", previous_use_ext)

    np.testing.assert_allclose(out_pyom, out_vs)


@pytest.mark.skipif(runtime_settings.backend != "numpy", reason="Must use NumPy backend")
def test_solve_tridiag_numpy(pyom2_lib):
    """Compare ``solve_tridiagonal_numpy`` against the PyOM2 reference solver.

    Inputs are drawn from a fixed seed, so any mismatch is reproducible.
    """
    from veros.core.operators import solve_tridiagonal_numpy
    from veros.core.utilities import create_water_masks

    pyom_obj = load_pyom(pyom2_lib)

    nx, ny, nz = 70, 60, 50
    # Fixed seed: an unseeded draw would make a tolerance failure impossible to reproduce.
    rng = np.random.default_rng(17)
    a, b, c, d = (rng.standard_normal((nx, ny, nz)) for _ in range(4))
    # kbot == 0 marks a column that is skipped below. Otherwise the column
    # solved is k = kbot - 1 .. nz - 1.
    kbot = rng.integers(0, nz, size=(nx, ny))

    # Reference result: solve every active column independently with PyOM2.
    out_pyom = np.zeros((nx, ny, nz))
    ke = nz
    for i, j in np.ndindex(nx, ny):
        ks = kbot[i, j] - 1
        if ks < 0:
            continue

        out_pyom[i, j, ks:ke] = pyom_obj.solve_tridiag(
            a=a[i, j, ks:ke], b=b[i, j, ks:ke], c=c[i, j, ks:ke], d=d[i, j, ks:ke], n=ke - ks
        )

    _, water_mask, edge_mask = create_water_masks(kbot, nz)
    out_vs = solve_tridiagonal_numpy(a, b, c, d, water_mask, edge_mask)

    np.testing.assert_allclose(out_pyom, out_vs)