streamfunction_kernel.py 2.7 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import sys

import numpy as onp
from mpi4py import MPI

from veros import runtime_settings as rs, runtime_state as rst

# Runtime settings are configured here, before the remaining veros imports below
# (note their `noqa: E402`), presumably because those modules read them at import time.
rs.diskless_mode = True

if rst.proc_num > 1:
    # Spawned-worker side: split the domain into a 2 x 2 grid of subdomains, which
    # needs exactly 4 MPI processes. Raise explicitly instead of using `assert` so
    # the check still runs when Python is started with -O.
    rs.num_proc = (2, 2)
    if rst.proc_num != 4:
        raise RuntimeError(f"expected 4 MPI processes for a (2, 2) decomposition, got {rst.proc_num}")

from veros.state import get_default_state, resize_dimension  # noqa: E402
from veros.distributed import gather  # noqa: E402
from veros.core.operators import numpy as npx, update, at  # noqa: E402
from veros.core.external.solvers import get_linear_solver  # noqa: E402


def get_inputs():
    """Set up a small streamfunction problem and return what the solver needs.

    The model is configured as a 100 x 40 x 1 grid, cyclic in x, with the
    streamfunction enabled and the "isle" dimension resized to 1. The fields are
    cut from global arrays of shape ``(nx + 4, ny + 4)``. Each process keeps only
    the slice belonging to it, so the serial parent and the 2 x 2 MPI workers of
    this script solve the same problem.

    Returns:
        A ``(state, rhs, x0)`` tuple:
        - the initialized veros state;
        - a right-hand side of ones, shaped like ``state.variables.hur``;
        - a zero initial guess, shaped like ``state.variables.hur``.
    """
    state = get_default_state()
    settings = state.settings

    with settings.unlock():
        settings.nx, settings.ny, settings.nz = 100, 40, 1
        settings.enable_cyclic_x = True
        settings.enable_streamfunction = True

    state.initialize_variables()
    resize_dimension(state, "isle", 1)

    vs = state.variables

    def _local_slice(axis, n_global):
        # Slice along `axis` for this process:
        # - an equal share of the n_global cells;
        # - widened by the 4 extra cells the padded global arrays carry
        #   (presumably halo/ghost cells -- confirm).
        share = n_global // rs.num_proc[axis]
        first = rst.proc_idx[axis] * share
        return slice(first, first + share + 4)

    local = (_local_slice(0, settings.nx), _local_slice(1, settings.ny), Ellipsis)
    padded_shape = (settings.nx + 4, settings.ny + 4)

    with vs.unlock():
        # Uniform value of 10e3 on every dx*/dy* metric field.
        for field in ("dxt", "dxu", "dyt", "dyu"):
            setattr(vs, field, update(getattr(vs, field), at[...], 10e3))

        # Reciprocal of a profile that ramps linearly from 500 to 2000 along the
        # first axis and is constant along the second. The same field feeds both
        # hur and hvr.
        profile = npx.linspace(500, 2000, settings.nx + 4)[:, None]
        inverse_profile = 1.0 / profile * npx.ones(padded_shape)
        vs.hur = inverse_profile[local]
        vs.hvr = inverse_profile[local]

        # cosu/cost (presumably cosine-of-latitude factors) set to 1 everywhere.
        for field in ("cosu", "cost"):
            setattr(vs, field, update(getattr(vs, field), at[...], 1))

        # Start from an all-True mask, then clear two regions:
        # - a strip covering the first 50 columns x first 2 rows;
        # - a 10 x 10 interior block.
        mask = npx.ones(padded_shape, dtype="bool")
        mask = update(mask, at[:50, :2], 0)
        mask = update(mask, at[20:30, 20:30], 0)
        vs.isle_boundary_mask = mask[local]

    return state, npx.ones_like(vs.hur), npx.zeros_like(vs.hur)


# Driver: the same problem is solved serially and in parallel, and the two results are compared.
# - Single process (the parent): spawn 4 copies of this script, solve the problem
#   locally, then check the local solution against the one the workers send back.
# - Spawned workers (proc_num > 1, 2 x 2 decomposition): solve on the subdomains,
#   gather the pieces, and have rank 0 send the result back to the parent.
if rst.proc_num == 1:
    # Launch the workers through `python -m mpi4py` (presumably so an uncaught worker
    # exception aborts the MPI job). sys.argv[-1] is assumed to be this script's
    # path -- TODO confirm when extra command-line arguments are passed.
    comm = MPI.COMM_SELF.Spawn(sys.executable, args=["-m", "mpi4py", sys.argv[-1]], maxprocs=4)

    try:
        state, rhs, x0 = get_inputs()
        sol = get_linear_solver(state)
        psi = sol.solve(state, rhs, x0)
    except Exception as exc:
        print(str(exc))
        # Abort through the intercommunicator so a failure here also takes down the
        # spawned workers, then re-raise the original error.
        comm.Abort(1)
        raise

    # Receive buffer shaped like the serial solution. This assumes worker rank 0
    # sends an array of the same shape and dtype -- confirm against the Send below.
    other_psi = onp.empty_like(psi)
    comm.Recv(other_psi, 0)

    onp.testing.assert_allclose(psi, other_psi)
else:
    state, rhs, x0 = get_inputs()
    sol = get_linear_solver(state)
    psi = sol.solve(state, rhs, x0)

    # Combine the per-process solutions along the ("xt", "yt") dimensions. Only
    # rank 0's result is used below, so gather presumably returns the full grid there.
    psi_global = gather(psi, state.dimensions, ("xt", "yt"))

    if rst.proc_rank == 0:
        # onp.array converts the result (possibly a backend array) into a plain
        # NumPy buffer for MPI, then sends it to rank 0 of the parent side.
        rs.mpi_comm.Get_parent().Send(onp.array(psi_global), 0)