routines.py 11.6 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import functools
import inspect
import threading
from contextlib import ExitStack, contextmanager

from veros import logger

from veros.state import VerosState


# stack helpers


class RoutineStack:
    """Nested record of the routines that are currently being executed.

    Entries are stored as ``[value, children]`` pairs, where ``children`` is a
    list of entries of the same form. ``_current_idx`` holds, for every open
    nesting level, the position of the active entry within its parent's
    ``children`` list.

    If ``keep_full_stack`` is true, entries are left in place when they are
    popped, so the complete call tree is retained in ``_stack``.
    """

    def __init__(self):
        self.keep_full_stack = False
        self._stack = []
        self._current_idx = []

    def _children_at(self, path):
        """Follow ``path`` (a sequence of indices) down from the root and return the list found there."""
        entries = self._stack
        for position in path:
            entries = entries[position][1]
        return entries

    @property
    def stack_level(self):
        """Number of entries that have been appended but not yet popped."""
        return len(self._current_idx)

    def append(self, val):
        """Open a new entry holding ``val`` as a child of the currently active entry."""
        siblings = self._children_at(self._current_idx)
        self._current_idx.append(len(siblings))
        siblings.append([val, []])

    def pop(self):
        """Close the currently active entry and return the value it holds."""
        siblings = self._children_at(self._current_idx[:-1])

        # the active entry is always the last one among its siblings
        entry = siblings[-1] if self.keep_full_stack else siblings.pop()

        self._current_idx.pop()
        return entry[0]


# global context

class _RoutineContext(threading.local):
    """Thread-local bookkeeping shared by all routines and kernels.

    Attributes assigned to a plain ``threading.local()`` at import time exist
    only in the importing thread; every other thread would hit an
    ``AttributeError`` on first access. Setting the defaults in ``__init__``
    fixes that, because ``threading.local`` runs ``__init__`` once in each
    thread that uses the object.
    """

    def __init__(self):
        # True unless execution is currently inside a routine declared with dist_safe=False
        self.is_dist_safe = True
        # nested record of the routines / kernels that are currently running in this thread
        self.routine_stack = RoutineStack()
        # token passed in and out of JIT-compiled kernels; created lazily by VerosKernel
        self.mpi4jax_token = None


CURRENT_CONTEXT = _RoutineContext()


@contextmanager
def nullcontext():
    """Context manager that does nothing on entry or exit and binds ``None``."""
    yield None


@contextmanager
def enter_routine(name, routine_obj, timer=None, dist_safe=True):
    """Context manager that wraps the execution of a single routine or kernel.

    On entry, ``routine_obj`` is pushed onto the routine stack of the current
    context and a trace message is logged. On exit (regular or via exception),
    it is popped again and a second trace message is logged.

    Parameters:
        name (str): Name used in the trace messages.

        routine_obj: Object to record on the routine stack.

        timer: Optional context manager that is entered around the routine body.
            Its ``last_time`` attribute is read afterwards for the trace message.

        dist_safe (bool): If False while more than one process is in use, the
            context is flagged as not distributed-safe for the duration of the
            block (unless an enclosing routine already did so).
    """
    from veros import runtime_state as rst
    from veros.distributed import abort

    call_stack = CURRENT_CONTEXT.routine_stack

    logger.trace("{}> {}", "-" * call_stack.stack_level, name)
    call_stack.append(routine_obj)

    # only the outermost unsafe routine flips the flag, and only that one restores it
    entered_unsafe_section = CURRENT_CONTEXT.is_dist_safe and not dist_safe and rst.proc_num > 1
    if entered_unsafe_section:
        CURRENT_CONTEXT.is_dist_safe = False

    if timer is None:
        timer_ctx = nullcontext()
    else:
        timer_ctx = timer

    try:
        with timer_ctx:
            yield

    except BaseException:
        # any failure inside a section this call marked as unsafe triggers abort()
        # before the exception is passed on
        if entered_unsafe_section:
            abort()
        raise

    finally:
        if entered_unsafe_section:
            CURRENT_CONTEXT.is_dist_safe = True

        popped = call_stack.pop()
        assert popped is routine_obj

        if timer is None:
            exec_time = ""
        else:
            exec_time = f"({timer.last_time:.3f}s)"

        logger.trace("<{} {} {}", "-" * call_stack.stack_level, name, exec_time)


# helper functions


def _get_func_name(function):
    return f"{inspect.getmodule(function).__name__}:{function.__qualname__}"


def _is_method(function):
    if inspect.ismethod(function):
        return True

    # hack for unbound methods: check if first argument is called "self"
    spec = inspect.getfullargspec(function)
    return spec.args and spec.args[0] == "self"


# routine


def veros_routine(function=None, *, dist_safe=True, local_variables=()):
    """Decorator that turns a function into a Veros routine.

    Can be used both bare (``@veros_routine``) and with keyword arguments
    (``@veros_routine(dist_safe=False, local_variables=[...])``).

    .. note::

      This decorator should be applied to all functions that access the Veros state object
      (even when subclassing :class:`veros.VerosSetup`).

    The decorated function must receive a VerosState instance as its first argument
    (after ``self`` for methods).

    Veros routines cannot return anything. All changes must be applied to the passed state object.

    Parameters:
        dist_safe (bool): If set to False, all variables specified in local_variables are synced
            to the root process before execution and synced back after. This means that the routine
            will only be executed on rank 0. Has no effect in non-distributed contexts.

        local_variables (Tuple[str]): List of variable names to be synced if dist_safe=False. This
            must include all variables retrieved from the state object throughout the routine (inputs
            *and* outputs).

    Raises:
        TypeError: If the decorated function takes no argument that could hold the state.

    Example:
       >>> from veros import VerosSetup, veros_routine
       >>>
       >>> class MyModel(VerosSetup):
       >>>     @veros_routine
       >>>     def set_topography(self, state):
       >>>         vs = state.variables
       >>>         settings = state.settings
       >>>         vs.kbot = npx.random.randint(0, settings.nz, size=vs.kbot.shape)

    """

    def decorate(target):
        # for methods the state comes second, right after "self"
        if _is_method(target):
            state_position = 1
        else:
            state_position = 0

        param_count = len(inspect.signature(target).parameters)
        if param_count <= state_position:
            raise TypeError("Veros routines must take at least one argument")

        wrapped = VerosRoutine(
            target, state_argnum=state_position, dist_safe=dist_safe, local_variables=local_variables
        )
        # copy name, docstring etc. of the original function onto the routine object
        return functools.update_wrapper(wrapped, target)

    if function is None:
        # called with keyword arguments only: hand back the actual decorator
        return decorate

    return decorate(function)


class VerosRoutine:
    """Callable wrapper around a function decorated with ``veros_routine``.

    Do not instantiate directly!

    Attributes are plain copies of the constructor arguments, plus ``name``,
    which is derived from the wrapped function and is used as the key into the
    state's profile timers and in log messages.
    """

    def __init__(self, function, dist_safe=True, local_variables=(), state_argnum=0):
        # allow a single variable name to be given as a bare string
        if isinstance(local_variables, str):
            local_variables = (local_variables,)

        self.function = function
        self.dist_safe = dist_safe
        self.local_variables = local_variables
        self.state_argnum = state_argnum
        self.name = _get_func_name(self.function)

    def __call__(self, *args, **kwargs):
        """Run the wrapped function on the state found at position ``state_argnum`` of ``args``.

        The return value of the wrapped function is discarded (with a warning
        if it is not None); this method always returns None.

        Raises:
            TypeError: If the argument at ``state_argnum`` is not a VerosState.
        """
        from veros import runtime_state as rst
        from veros.state import VerosState, DistSafeVariableWrapper
        from veros.core.operators import flush

        veros_state = args[self.state_argnum]

        if not isinstance(veros_state, VerosState):
            raise TypeError(f"Argument {self.state_argnum} to this Veros routine must be a VerosState object")

        timer = veros_state.profile_timers[self.name]

        with ExitStack() as es:
            vars_initialized = veros_state._variables is not None

            # variables stay unlocked for the whole duration of the routine
            if vars_initialized:
                es.enter_context(veros_state.variables.unlock())

            execute = True
            restore_vars = False

            if not self.dist_safe:
                orig_vars = veros_state._variables
                # only the outermost unsafe routine wraps (and later restores) the variables
                if not isinstance(orig_vars, DistSafeVariableWrapper):
                    restore_vars = True
                    veros_state._variables = DistSafeVariableWrapper(orig_vars, self.local_variables)
                    veros_state._variables._gather_variables()

                # the function body itself only runs on rank 0
                execute = rst.proc_rank == 0

            routine_ctx = enter_routine(name=self.name, routine_obj=self, timer=timer, dist_safe=self.dist_safe)

            out = None
            try:
                with routine_ctx:
                    if execute:
                        out = self.function(*args, **kwargs)

            finally:
                # undo the wrapping even if the routine raised
                if restore_vars:
                    veros_state._variables._scatter_variables()
                    veros_state._variables = orig_vars

                flush()

        if out is not None:
            logger.warning(
                f"Routine {self.name} returned object of type {type(out)}. Return objects are silently dropped."
            )

    def __get__(self, instance, _):
        """Descriptor hook that makes routines usable as methods.

        Accessed through an instance, the routine is bound to that instance.
        Accessed through the class, the routine itself is returned, so that
        ``Class.routine(instance, state)`` works instead of receiving ``None``
        as an extra leading argument.
        """
        if instance is None:
            return self

        return functools.partial(self.__call__, instance)

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"


# kernel


def veros_kernel(function=None, *, static_args=()):
    """Decorator that marks a function as a kernel that can be JIT compiled if supported
    by the backend.

    Can be used both bare (``@veros_kernel``) and with keyword arguments
    (``@veros_kernel(static_args=(...))``).

    Kernels cannot modify the Veros state object. Instead, all modifications have to be
    returned explicitly.

    Parameters:
        static_args (Tuple[str]): Names of kernel arguments that should be static.

    Example:
        >>> from veros import veros_kernel, KernelOutput
        >>>
        >>> @veros_kernel
        >>> def double_psi(state):
        >>>     vs = state.variables
        >>>     vs.psi = 2 * vs.psi
        >>>     return KernelOutput(psi=vs.psi)

    """

    def decorate(target):
        wrapped = VerosKernel(target, static_args=static_args)
        # copy name, docstring etc. of the original function onto the kernel object
        return functools.update_wrapper(wrapped, target)

    if function is None:
        # called with keyword arguments only: hand back the actual decorator
        return decorate

    return decorate(function)


class VerosKernel:
    """Callable wrapper around a function decorated with ``veros_kernel``.

    Do not instantiate directly!

    On the JAX backend the wrapped function is replaced by a jitted version on
    the first call; ``self.function`` is mutated in place to hold it.
    """

    def __init__(self, function, static_args=()):
        """Do some parameter introspection.

        Parameters:
            function: The function to wrap.

            static_args: Name (or sequence of names) of parameters of ``function``
                that are translated to positional indices in ``static_argnums``.

        Raises:
            ValueError: If ``function`` has ``*args``, ``**kwargs`` or keyword-only
                parameters, or if a name in ``static_args`` is not a parameter.
        """

        # make sure function signature is in the form we need
        self.name = _get_func_name(function)
        self.func_sig = inspect.signature(function)

        func_params = self.func_sig.parameters

        allowed_param_types = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)

        if any(p.kind not in allowed_param_types for p in func_params.values()):
            raise ValueError(f"Veros kernels do not support *args, **kwargs, or keyword-only parameters ({self.name})")

        # parse static args
        # (a single name may be given as a bare string)
        if isinstance(static_args, str):
            static_args = (static_args,)

        func_argnames = list(func_params.keys())

        self.static_argnums = []
        for static_arg in static_args:
            try:
                arg_index = func_argnames.index(static_arg)
            except ValueError:
                raise ValueError(
                    f'Veros kernel {self.name} has no argument "{static_arg}", but it is given in static_args'
                ) from None

            self.static_argnums.append(arg_index)

        self.function = function

    def __call__(self, *args, **kwargs):
        """Call the kernel and return its output.

        Arguments are bound against the original signature, defaults are
        filled in, and everything is passed on positionally. If one of the
        arguments is a VerosState, its variables are unlocked for the duration
        of the call and its profile timer for this kernel is used.
        """
        from veros import runtime_settings, runtime_state
        from veros.core.operators import flush

        # tokens are only threaded through kernels on the JAX backend with more than one process
        inject_tokens = runtime_settings.backend == "jax" and runtime_state.proc_num > 1

        # apply JIT
        if runtime_settings.backend == "jax":
            import jax

            # type of jitted callables; used to detect whether self.function was wrapped already
            CompiledFunction = type(jax.jit(lambda: None))

            if not isinstance(self.function, CompiledFunction):
                if inject_tokens:
                    function = self.function

                    @functools.wraps(function)
                    def token_wrapper(*args):
                        # last positional argument is the token, everything else belongs to the kernel
                        inputs = args[:-1]
                        token = args[-1]
                        CURRENT_CONTEXT.mpi4jax_token = token
                        out = function(*inputs)
                        # read the token back from the context after the kernel body ran
                        token = CURRENT_CONTEXT.mpi4jax_token
                        return out, token

                    if CURRENT_CONTEXT.mpi4jax_token is None:
                        CURRENT_CONTEXT.mpi4jax_token = jax.lax.create_token()

                    self.function = token_wrapper

                self.function = jax.jit(self.function, static_argnums=self.static_argnums)

        # JAX only accepts positional args when using static_argnums
        # so convert everything to positional for consistency
        bound_args = self.func_sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # use the first VerosState found among the arguments (if any)
        veros_state = None
        for argval in bound_args.arguments.values():
            if isinstance(argval, VerosState):
                veros_state = argval
                break

        called_with_state = veros_state is not None

        # when profiling, make sure all inputs are ready before starting the timer
        if runtime_settings.profile_mode:
            flush()

        # timers live on the state object, so there is none without a state argument
        if called_with_state:
            timer = veros_state.profile_timers[self.name]
        else:
            timer = None

        with ExitStack() as es:
            if called_with_state:
                es.enter_context(veros_state.variables.unlock())

            args = list(bound_args.arguments.values())

            # matches the extra trailing argument expected by token_wrapper
            if inject_tokens:
                args.append(CURRENT_CONTEXT.mpi4jax_token)

            with enter_routine(self.name, self, timer):
                out = self.function(*args)

                # flush inside the timed block when profiling
                if runtime_settings.profile_mode:
                    flush()

            # token_wrapper returns (out, token); store the new token for the next kernel call
            if inject_tokens:
                out, token = out
                CURRENT_CONTEXT.mpi4jax_token = token

        return out

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"


def is_veros_routine(func):
    """Return True if ``func`` is a VerosRoutine, possibly hidden behind a partial and / or bound method.

    A ``functools.partial`` is replaced by the callable it wraps, and a bound
    method by the object it is bound to, before the type check is made.
    """
    candidate = func.func if isinstance(func, functools.partial) else func

    if inspect.ismethod(candidate):
        candidate = candidate.__self__

    return isinstance(candidate, VerosRoutine)