"models-2.13.1/official/nlp/modeling/layers/util.py" did not exist on "e53ccd8010d8359bd9c3a641101bb0f8c364007d"
pressure_solver_benchmark.py 4.38 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
from time import perf_counter

from benchmark_base import benchmark_cli

from veros import logger, runtime_settings
from veros.pyom_compat import load_pyom, pyom_from_state

# Variables that get_dummy_state() copies from the input file into the state.
VARIABLES_USED = [
    "cost",
    "cosu",
    "dxt",
    "dxu",
    "dyt",
    "dyu",
    "hu",
    "hv",
    "maskT",
]

# Variables the dummy state is created with: everything in VARIABLES_USED plus
# three more that are allocated but not filled from the input file.
VARIABLES_ALLOCATED = [*VARIABLES_USED, "hur", "hvr", "isle_boundary_mask"]

# Settings that get_dummy_state() takes over from the input file's "settings" attributes.
SETTINGS_USED = ["nx", "ny", "nz", "dt_tracer", "dt_mom", "enable_streamfunction", "enable_cyclic_x"]


def get_dummy_state(infile):
    """Create a reduced Veros state and load every dataset of an HDF5 input file.

    The state only carries the variables named in ``VARIABLES_ALLOCATED``. The
    settings listed in ``SETTINGS_USED`` are read from the attributes of the
    file's ``"settings"`` entry, after which the state variables are initialized.

    Every other top-level entry of the file is read into a local array: the
    chunk selected by ``get_chunk_slices`` is written into an array of the
    local (ghost-cell including) shape and passed through ``exchange_overlap``.
    Entries unknown to the state's variable metadata are treated as having
    dimensions ``("xt", "yt")``. Floating point datasets are stored with
    ``runtime_settings.float_type``; all others keep the dtype of the file.

    Args:
        infile: Path (or whatever ``h5py.File`` accepts) of the input file.

    Returns:
        Tuple ``(state, input_data)`` where ``input_data`` maps each dataset
        name to its local array. The entries named in ``VARIABLES_USED`` are
        additionally assigned to ``state.variables``.
    """
    import h5py
    from veros.state import VerosState
    from veros.distributed import get_chunk_slices, exchange_overlap
    from veros.variables import VARIABLES, DIM_TO_SHAPE_VAR, get_shape
    from veros.settings import SETTINGS
    from veros.core.operators import numpy as npx, update, at

    allocated_meta = {name: VARIABLES[name] for name in VARIABLES_ALLOCATED}
    state = VerosState(var_meta=allocated_meta, setting_meta=SETTINGS, dimensions=DIM_TO_SHAPE_VAR)

    input_data = {}

    with h5py.File(infile, "r") as h5file:
        file_settings = dict(h5file["settings"].attrs.items())

        with state.settings.unlock():
            state.settings.update({name: file_settings[name] for name in SETTINGS_USED})

        # Settings must be in place before this call, hence the ordering.
        state.initialize_variables()

        dimensions = state.dimensions
        var_meta = state.var_meta

        for name in h5file.keys():
            if name == "settings":
                continue

            dataset = h5file[name]

            # Entries without metadata (not part of the allocated subset) fall back to 2D.
            dims = var_meta[name].dims if name in var_meta else ("xt", "yt")

            local_shape = get_shape(dimensions, dims, local=True, include_ghosts=True)
            global_idx, local_idx = get_chunk_slices(
                dimensions["xt"], dimensions["yt"], dims, include_overlap=True
            )

            is_float = npx.issubdtype(dataset.dtype, npx.floating)
            dtype = runtime_settings.float_type if is_float else str(dataset.dtype)

            local_arr = npx.empty(local_shape, dtype=dtype)
            local_arr = update(local_arr, at[local_idx], dataset[global_idx])
            local_arr = exchange_overlap(local_arr, dims, state.settings.enable_cyclic_x)
            input_data[name] = local_arr

    with state.variables.unlock():
        for name in VARIABLES_USED:
            setattr(state.variables, name, input_data[name])

    return state, input_data


@benchmark_cli
def main(pyom2_lib, timesteps, size):
    """Time the surface pressure solver on a pre-computed input problem.

    The input file is picked from the "bench-external" assets according to the
    total number of grid cells ``size[0] * size[1] * size[2]``. A dummy state is
    built from it and the linear solver is set up. The solve is then repeated
    ``timesteps`` times; the duration of each repetition is logged at debug
    level.

    Without ``pyom2_lib`` the Veros solver is timed and its last result is
    compared to the reference solution stored in the input file. With
    ``pyom2_lib`` the PyOM2 routine ``congrad_surf_press`` is timed instead and
    no comparison is made.

    Args:
        pyom2_lib: Path to a PyOM2 library, or a falsy value to benchmark Veros.
        timesteps: Number of timed repetitions. If it is not positive nothing is
            run and the result check is skipped.
        size: Indexable holding three grid extents.
            NOTE(review): arguments are presumably supplied by ``benchmark_cli``.

    Raises:
        ValueError: If the total size falls outside all supported ranges.
        AssertionError: If the relative RMS error of the Veros solution with
            respect to the reference is not below ``1e-3``.
    """
    from veros.tools import get_assets
    from veros.distributed import barrier
    from veros.core.operators import flush, numpy as npx
    from veros.core.external.solve_pressure import get_linear_solver

    here = os.path.dirname(__file__)
    assets = get_assets("bench-external", os.path.join(here, "bench-external-assets.json"))

    total_size = size[0] * size[1] * size[2]

    # Sizes between the listed ranges are deliberately rejected.
    if 1e8 <= total_size <= 1e9:
        infile = assets["01deg-press"]
    elif 1e6 <= total_size <= 1e7:
        infile = assets["1deg-press"]
    elif 1e4 <= total_size <= 1e5:
        infile = assets["4deg-press"]
    else:
        raise ValueError(
            "Pressure solver benchmark only support 4deg, 1deg, and 0.1deg resolution"
            " (n = 5e4, 5e6, 5e8, respectively)."
        )

    state, input_data = get_dummy_state(infile)
    solver = get_linear_solver(state)

    if not pyom2_lib:

        def reset():
            # The initial guess is passed to solve() on every call; nothing to restore.
            pass

        def run():
            return solver.solve(state, input_data["rhs"], input_data["x0"])

    else:
        pyom_obj = load_pyom(pyom2_lib)
        pyom_obj = pyom_from_state(state, pyom_obj, init_streamfunction=False)
        m = pyom_obj.main_module
        m.taum1, m.tau, m.taup1 = 1, 2, 3

        def reset():
            # Put the initial guess back into the slot that run() returns.
            # NOTE(review): the "- 1" presumably maps a 1-based time index to a 0-based one.
            m.psi[..., m.taup1 - 1] = input_data["x0"]

        def run():
            pyom_obj.congrad_surf_press(
                m.is_pe - m.onx, m.ie_pe + m.onx, m.js_pe - m.onx, m.je_pe + m.onx, input_data["rhs"], 0
            )
            return m.psi[..., m.taup1 - 1]

    # Stays None when the loop below does not execute (timesteps <= 0).
    res = None

    for _ in range(timesteps):
        reset()
        # NOTE(review): flush/barrier presumably make sure outstanding work on all
        # processes is done before the clock is read - confirm.
        flush()
        barrier()

        start = perf_counter()
        res = run()
        flush()
        barrier()
        end = perf_counter()

        logger.debug(f"Time step took {end - start}s")

    # Only validate when the solver actually ran; otherwise there is no result.
    if not pyom2_lib and timesteps > 0:
        # monitor residual to expected solution, with generous margin
        def rms(arr):
            return npx.sqrt(npx.mean(arr**2))

        # The two outermost cells on each side are excluded from the comparison.
        expected = input_data["res"][2:-2, 2:-2]
        rms_err = rms(res[2:-2, 2:-2] - expected) / rms(expected)

        # Explicit raise instead of an assert statement so that the check is
        # not removed when Python runs with -O. "not (a < b)" also catches NaN.
        if not (rms_err < 1e-3):
            raise AssertionError(rms_err)


if __name__ == "__main__":
    main()