operators.py 4.81 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import warnings
from contextlib import contextmanager

from veros import runtime_settings, runtime_state, veros_kernel


class Index:
    """Stateless helper that turns subscript syntax into an index object.

    Subscripting an instance returns the subscript itself, e.g.
    ``Index()[1:3]`` evaluates to ``slice(1, 3, None)`` and
    ``Index()[0, ...]`` to ``(0, Ellipsis)``.
    """

    # No per-instance attributes are needed.
    __slots__ = ()

    @staticmethod
    def __getitem__(key):
        """Return ``key`` exactly as it was passed in the brackets."""
        return key


def noop(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``."""
    return None


@contextmanager
def make_writeable(*arrs):
    """Context manager yielding writeable copies of the given arrays.

    Every input is copied and the copy's ``flags.writeable`` is set to
    ``True``; the inputs themselves are only read (their flag) and copied.

    Yields:
        The single copy when exactly one array was passed, otherwise a list
        with one copy per input (in input order).

    On exit, each copy's ``flags.writeable`` is reset to the value the
    corresponding input had on entry. A ``ValueError`` raised while
    resetting the flag is ignored.
    """
    initial_flags = [source.flags.writeable for source in arrs]
    copies = []
    try:
        for source in arrs:
            duplicate = source.copy()
            duplicate.flags.writeable = True
            copies.append(duplicate)

        # A lone array is handed out bare rather than wrapped in a list.
        yield copies[0] if len(copies) == 1 else copies

    finally:
        # zip() stops at the shorter sequence, so only copies that were
        # actually created get their flag reset.
        for duplicate, was_writeable in zip(copies, initial_flags):
            try:
                duplicate.flags.writeable = was_writeable
            except ValueError:
                pass


def update_numpy(arr, at, to):
    """Return a copy of ``arr`` in which ``arr[at]`` has been set to ``to``.

    The assignment is applied to the copy produced by ``make_writeable``;
    ``arr`` itself is not assigned to.
    """
    with make_writeable(arr) as updated:
        updated[at] = to
    return updated


def update_add_numpy(arr, at, to):
    """Return a copy of ``arr`` in which ``to`` has been added to ``arr[at]``.

    The in-place addition is applied to the copy produced by
    ``make_writeable``; ``arr`` itself is not assigned to.
    """
    with make_writeable(arr) as updated:
        updated[at] += to
    return updated


def update_multiply_numpy(arr, at, to):
    """Return a copy of ``arr`` in which ``arr[at]`` has been multiplied by ``to``.

    The in-place multiplication is applied to the copy produced by
    ``make_writeable``; ``arr`` itself is not assigned to.
    """
    with make_writeable(arr) as updated:
        updated[at] *= to
    return updated


def solve_tridiagonal_numpy(a, b, c, d, water_mask, edge_mask):
    """Solve a masked tridiagonal system with LAPACK's ``dgtsv``.

    Args:
        a: Sub-diagonal coefficients.
        b: Main-diagonal coefficients.
        c: Super-diagonal coefficients.
        d: Right-hand side.
        water_mask: Mask selecting the entries that take part in the solve.
        edge_mask: Mask of entries whose sub-diagonal coefficient is zeroed.

    Returns:
        An array with the shape and dtype of ``a`` holding the solution at
        the positions selected by ``water_mask`` and zeros elsewhere. If
        ``water_mask`` selects nothing, the all-zero array is returned
        without calling LAPACK.
    """
    import numpy as np
    from scipy.linalg import lapack

    solution = np.zeros(a.shape, dtype=a.dtype)

    if not np.any(water_mask):
        return solution

    # remove couplings between slices
    with make_writeable(a, c) as (lower, upper):
        lower[edge_mask] = 0
        upper[..., -1] = 0

    # The masked entries form one long system; the off-diagonals are one
    # element shorter than the main diagonal.
    sub_diagonal = lower[water_mask][1:]
    main_diagonal = b[water_mask]
    super_diagonal = upper[water_mask][:-1]
    rhs = d[water_mask]

    # Element 3 of the dgtsv result tuple is the solution vector.
    solution[water_mask] = lapack.dgtsv(sub_diagonal, main_diagonal, super_diagonal, rhs)[3]
    return solution


def fori_numpy(lower, upper, body_fun, init_val):
    """Run ``body_fun`` for every index in ``range(lower, upper)``.

    Args:
        lower: First index (inclusive).
        upper: End index (exclusive).
        body_fun: Callable ``(i, val) -> new_val``.
        init_val: Value passed to the first call.

    Returns:
        The value returned by the last call, or ``init_val`` when the range
        is empty.
    """
    state = init_val
    for step in range(lower, upper):
        state = body_fun(step, state)
    return state


def scan_numpy(f, init, xs, length=None):
    """Pure-Python scan: thread a carry through ``f`` and stack the outputs.

    Args:
        f: Callable ``(carry, x) -> (new_carry, y)``.
        init: Initial carry value.
        xs: Iterable of per-step inputs, or ``None`` to call ``f`` with
            ``x=None`` for ``length`` steps.
        length: Number of steps; only used when ``xs`` is ``None``.

    Returns:
        ``(final_carry, ys)`` where ``ys`` is ``np.stack`` of the per-step
        outputs. When no step is executed, ``ys`` is an empty array of shape
        ``(0,)`` (the per-step output shape cannot be known without calling
        ``f``) and ``final_carry`` is ``init``.

    Raises:
        ValueError: If both ``xs`` and ``length`` are ``None``.
    """
    import numpy as np

    if xs is None:
        if length is None:
            # Without this check the failure is an obscure TypeError from
            # ``[None] * None``.
            raise ValueError("scan_numpy got no values to scan over and `length` was not provided")
        xs = [None] * length

    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)

    if not ys:
        # np.stack raises ValueError on an empty list; a zero-length scan
        # should behave like a zero-iteration loop instead of crashing.
        return carry, np.empty((0,))

    return carry, np.stack(ys)


@veros_kernel
def solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask):
    """Solve masked tridiagonal systems, via the TDMA extension or pure JAX.

    The ``runtime_settings.use_special_tdma`` setting selects the path:

    * truthy: the compiled ``tdma`` extension is required; a failing import
      is re-raised and an extension that does not match
      ``runtime_settings.device`` raises ``RuntimeError``.
    * ``None``: the extension is used when available for the current device,
      otherwise a warning is issued and the pure JAX path is taken.
    * other falsy values: the pure JAX path is taken.

    The pure JAX path is the Thomas algorithm (forward sweep followed by
    back substitution) expressed as two ``jax.lax.scan`` calls along axis 2
    of the inputs.
    NOTE(review): the carry has shape ``a.shape[:-1]``, so the inputs are
    presumably 3-D with axis 2 last — confirm against callers.

    Args:
        a: Sub-diagonal coefficients.
        b: Main-diagonal coefficients.
        c: Super-diagonal coefficients.
        d: Right-hand side.
        water_mask: Mask of entries taking part in the solve; elsewhere the
            system is replaced by the identity row (``b = 1``, rest ``0``).
        edge_mask: Mask of entries whose sub-diagonal coefficient is zeroed.

    Returns:
        The solution array, with the scanned axis moved back to position 2.

    Raises:
        ImportError: If the extension is explicitly requested but cannot be
            imported.
        RuntimeError: If the extension is explicitly requested but is not
            available for the current device.
    """
    import jax.lax
    import jax.numpy as jnp

    want_ext = runtime_settings.use_special_tdma

    try:
        from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT
    except ImportError:
        # Only fatal when the extension was explicitly requested.
        if want_ext:
            raise
        ext_available = False
    else:
        ext_available = (HAS_CPU_EXT and runtime_settings.device == "cpu") or (
            HAS_GPU_EXT and runtime_settings.device == "gpu"
        )

    if want_ext is None:
        # Auto mode: use the extension if we can, otherwise fall back loudly.
        if ext_available:
            want_ext = True
        else:
            warnings.warn("Could not use custom TDMA implementation, falling back to pure JAX")
            want_ext = False

    if want_ext:
        if not ext_available:
            raise RuntimeError("Could not use custom TDMA implementation")
        return tdma(a, b, c, d, water_mask, edge_mask)

    # Outside the water mask the system becomes the identity (b = 1, rest 0);
    # the sub-diagonal is additionally cut at edge_mask.
    lower = water_mask * a * jnp.logical_not(edge_mask)
    diag = jnp.where(water_mask, b, 1.0)
    upper = water_mask * c
    rhs = water_mask * d

    def forward_sweep(previous, coeffs):
        prev_cp, prev_dp = previous
        lo_k, diag_k, up_k, rhs_k = coeffs
        pivot = diag_k - lo_k * prev_cp
        cp = up_k / pivot
        dp = (rhs_k - lo_k * prev_dp) / pivot
        current = (cp, dp)
        return current, current

    # Scan runs over the leading axis, so bring axis 2 to the front.
    scan_inputs = [jnp.moveaxis(coeff, 2, 0) for coeff in (lower, diag, upper, rhs)]
    zeros = jnp.zeros(lower.shape[:-1], dtype=lower.dtype)
    _, swept = jax.lax.scan(forward_sweep, (zeros, zeros), scan_inputs)

    def back_substitute(next_x, swept_k):
        cp, dp = swept_k
        x_k = dp - cp * next_x
        return x_k, x_k

    _, solution = jax.lax.scan(back_substitute, zeros, swept, reverse=True)
    return jnp.moveaxis(solution, 0, 2)


def update_jax(arr, at, to):
    """Return the result of ``arr.at[at].set(to)``."""
    target = arr.at[at]
    return target.set(to)


def update_add_jax(arr, at, to):
    """Return the result of ``arr.at[at].add(to)``."""
    target = arr.at[at]
    return target.add(to)


def update_multiply_jax(arr, at, to):
    """Return the result of ``arr.at[at].multiply(to)``."""
    target = arr.at[at]
    return target.multiply(to)


def flush_jax():
    """Wait for a trivial JAX computation to finish.

    A scalar is placed on the device with ``jax.device_put``, ``0.0`` is
    added to it, and ``block_until_ready`` is called on the result. Returns
    ``None``.
    """
    import jax

    sentinel = jax.device_put(0.0) + 0.0
    try:
        sentinel.block_until_ready()
    except AttributeError:
        # While jitting, the value is not a DeviceArray we can wait on, so
        # there is nothing to block for.
        pass


# Array module of the active backend, exposed under the name ``numpy``.
numpy = runtime_state.backend_module

# Bind the backend-neutral operator names to the implementations that match
# the configured backend. Both branches define the same set of names.
if runtime_settings.backend == "numpy":
    # indexing helper
    at = Index()

    # array updates (each returns a modified copy)
    update, update_add, update_multiply = (
        update_numpy,
        update_add_numpy,
        update_multiply_numpy,
    )

    # control flow
    for_loop = fori_numpy
    scan = scan_numpy

    # solvers and synchronisation
    solve_tridiagonal = solve_tridiagonal_numpy
    flush = noop

elif runtime_settings.backend == "jax":
    import jax.lax

    # indexing helper
    at = Index()

    # array updates
    update, update_add, update_multiply = (
        update_jax,
        update_add_jax,
        update_multiply_jax,
    )

    # control flow is delegated directly to jax.lax
    for_loop = jax.lax.fori_loop
    scan = jax.lax.scan

    # solvers and synchronisation
    solve_tridiagonal = solve_tridiagonal_jax
    flush = flush_jax

else:
    raise ValueError(f"Unrecognized backend {runtime_settings.backend}")