"tests/models/gptj/test_modeling_tf_gptj.py" did not exist on "505f2d749eb52f4b8b803d8c9a5f04442446e6c2"
restart.py
import os

from veros import logger, runtime_settings, runtime_state
from veros.io_tools import hdf5 as h5tools
from veros.signals import do_not_disturb
from veros.distributed import get_chunk_slices, exchange_overlap
from veros.variables import get_shape


def read_from_h5(dimensions, var_meta, infile, groupname, enable_cyclic_x):
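    """Read one HDF5 group into a dict of arrays.

    Dimensionless entries are read as-is; array entries are scattered to the
    local chunk of each process, and overlap (ghost) cells are exchanged
    afterwards. Returns a tuple ``(attributes, variables)``.
    """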
    from veros.core.operators import numpy as npx, update, at

    variables = {}

    for key, var in infile[groupname].items():
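        # scalar entries (no dimensions) are read directly, without chunking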
        if not var_meta[key].dims:
            variables[key] = npx.array(var)
            continue

        local_shape = get_shape(dimensions, var_meta[key].dims, local=True, include_ghosts=True)
        gidx, lidx = get_chunk_slices(dimensions["xt"], dimensions["yt"], var_meta[key].dims, include_overlap=True)

        # pass dtype as str to prevent endianness from leaking into array
        variables[key] = npx.empty(local_shape, dtype=str(var.dtype))
        variables[key] = update(variables[key], at[lidx], var[gidx])
        variables[key] = exchange_overlap(variables[key], var_meta[key].dims, enable_cyclic_x)

    attributes = {key: var.item() for key, var in infile[groupname].attrs.items()}

    return attributes, variables


def write_to_h5(dimensions, var_meta, var_data, outfile, groupname, attributes=None):
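    """Write a dict of variables into one HDF5 group.

    Datasets are created with their global shape; every process writes the
    slice corresponding to its local chunk. Group attributes are set from the
    optional ``attributes`` mapping.
    """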
    if attributes is None:
        attributes = {}

    group = outfile.require_group(groupname)

    for key, var in var_data.items():
        var_dims = var_meta[key].dims
        if var_dims is None:
            var_dims = []

        global_shape = get_shape(dimensions, var_dims, local=False)
        gidx, lidx = get_chunk_slices(dimensions["xt"], dimensions["yt"], var_dims, include_overlap=True)

        kwargs = dict(
            exact=True,
        )

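        # chunk each dataset by the local (per-process) extent of its dimensions;
        # dimensions without a known extent get a chunk size of 1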
        if var_dims:
            chunksize = []
            for d in var_dims:
                if d in dimensions:
                    chunksize.append(get_shape(dimensions, (d,), local=True, include_ghosts=False)[0])
                else:
                    chunksize.append(1)

            kwargs.update(chunks=tuple(chunksize))

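            # compression is only applied in serial runs; parallel HDF5 writes
            # generally do not support compression filters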
            if runtime_settings.hdf5_gzip_compression and runtime_state.proc_num == 1:
                kwargs.update(compression="gzip", compression_opts=1)

        group.require_dataset(key, global_shape, var.dtype, **kwargs)
        group[key][gidx] = var[lidx]

    for key, val in attributes.items():
        group.attrs[key] = val


def read_restart(state):
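    """Read core and diagnostic restart data from ``settings.restart_input_filename``.

    Does nothing if no restart input file is configured.
    """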
    settings = state.settings

    if not settings.restart_input_filename:
        return

    if runtime_settings.force_overwrite:
        raise RuntimeError("To prevent data loss, force_overwrite cannot be used in restart runs")

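    # the restart file name may contain str.format placeholders, which are
    # filled in from the current settings and variables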
    statedict = dict(state.variables.items())
    statedict.update(state.settings.items())
    restart_filename = settings.restart_input_filename.format(**statedict)

    if not os.path.isfile(restart_filename):
        raise IOError(f"restart file {restart_filename} not found")

    logger.info(f"Reading restart data from {restart_filename}")

    with h5tools.threaded_io(restart_filename, "r") as infile, state.variables.unlock():
        # core restart
        restart_vars = {var: meta for var, meta in state.var_meta.items() if meta.write_to_restart and meta.active}
        _, restart_data = read_from_h5(state.dimensions, restart_vars, infile, "core", settings.enable_cyclic_x)

        for key in restart_vars.keys():
            try:
                var_data = restart_data[key]
            except KeyError:
                raise RuntimeError(f"No restart data found for variable {key} in {restart_filename}") from None

            setattr(state.variables, key, var_data)

        # diagnostic restarts
        for diag_name, diagnostic in state.diagnostics.items():
            if not diagnostic.var_meta:
                # nothing to do
                continue

            dimensions = dict(state.dimensions)
            if diagnostic.extra_dimensions:
                dimensions.update(diagnostic.extra_dimensions)

            restart_vars = {
                var: meta for var, meta in diagnostic.var_meta.items() if meta.write_to_restart and meta.active
            }
            _, restart_data = read_from_h5(dimensions, restart_vars, infile, diag_name, settings.enable_cyclic_x)

            for key in restart_vars.keys():
                try:
                    var_data = restart_data[key]
                except KeyError:
                    raise RuntimeError(
                        f'No restart data found for variable {key} in {restart_filename} (from diagnostic "{diag_name}")'
                    ) from None

                setattr(diagnostic.variables, key, var_data)

    return state


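# do_not_disturb defers termination signals while the restart file is being
# written, so that an interrupted run does not leave a truncated file behind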
@do_not_disturb
def write_restart(state, force=False):
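    """Write core and diagnostic restart data to ``settings.restart_output_filename``.

    A restart file is written when ``force`` is set or whenever the model time
    crosses a multiple of ``settings.restart_frequency``.
    """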
    vs = state.variables
    settings = state.settings

    if runtime_settings.diskless_mode:
        return

    if not settings.restart_output_filename:
        return

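    # vs.time % restart_frequency < dt_tracer is true exactly once per restart
    # interval, on the first time step after the interval has been crossed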
    write_now = force or (
        settings.restart_frequency and vs.itt > 0 and vs.time % settings.restart_frequency < settings.dt_tracer
    )

    if not write_now:
        return

    statedict = dict(state.variables.items())
    statedict.update(state.settings.items())
    restart_filename = settings.restart_output_filename.format(**statedict)

    logger.info(f"Writing restart file {restart_filename}")

    with h5tools.threaded_io(restart_filename, "w") as outfile:
        # core restart
        restart_vars = {var: meta for var, meta in state.var_meta.items() if meta.write_to_restart and meta.active}
        restart_data = {var: getattr(vs, var) for var in restart_vars}
        write_to_h5(state.dimensions, restart_vars, restart_data, outfile, "core")

        # diagnostic restarts
        for diag_name, diagnostic in state.diagnostics.items():
            if not diagnostic.var_meta:
                # nothing to do
                continue

            dimensions = dict(state.dimensions)
            if diagnostic.extra_dimensions:
                dimensions.update(diagnostic.extra_dimensions)

            restart_vars = {
                var: meta for var, meta in diagnostic.var_meta.items() if meta.write_to_restart and meta.active
            }
            restart_data = {var: getattr(diagnostic.variables, var) for var in restart_vars}
            write_to_h5(dimensions, restart_vars, restart_data, outfile, diag_name)