"src/targets/vscode:/vscode.git/clone" did not exist on "91dcfa3496ac9bb9eea0d81fce1e530b886fb4b3"
setup.py 17.7 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
from veros.core.operators import numpy as npx
import numpy as onp

import scipy.interpolate
import scipy.spatial


def interpolate(coords, var, interp_coords, missing_value=None, fill=True, kind="linear"):
    """Interpolate globally defined data to a different (regular) grid.

    Arguments:
       coords: Tuple of coordinate arrays for each dimension.
       var (:obj:`ndarray` of dim (nx1, ..., nxd)): Variable data to interpolate.
       interp_coords: Tuple of coordinate arrays to interpolate to.
       missing_value (optional): Value denoting cells of missing data in ``var``.
          Is replaced by `NaN` before interpolating. Defaults to `None`, which means
          no replacement is taking place.
       fill (bool, optional): Whether `NaN` values should be replaced by the nearest
          finite value after interpolating. Defaults to ``True``.
       kind (str, optional): Order of interpolation. Supported are `nearest` and
          `linear` (default).

    Returns:
       :obj:`ndarray` containing the interpolated values on the grid spanned by
       ``interp_coords``.

    """
    if len(coords) != len(interp_coords) or len(coords) != var.ndim:
        raise ValueError("Dimensions of coordinates and values do not match")

    if missing_value is not None:
        invalid_mask = npx.isclose(var, missing_value)
        var = npx.where(invalid_mask, npx.nan, var)

    if var.ndim > 1 and coords[0].ndim == 1:
        interp_grid = npx.rollaxis(npx.array(npx.meshgrid(*interp_coords, indexing="ij")), 0, len(interp_coords) + 1)
    else:
        interp_grid = interp_coords

    def as_floatarray(x):
        return onp.array(x, dtype="float64")

    coords = tuple(as_floatarray(c) for c in coords)
    var = scipy.interpolate.interpn(
        coords, as_floatarray(var), as_floatarray(interp_grid), bounds_error=False, fill_value=onp.nan, method=kind
    )
    var = npx.asarray(var)

    if fill:
        var = fill_holes(var)

    return var


def fill_holes(data):
    """A simple inpainting function that replaces NaN values in `data` with the
    nearest finite value.
    """
    data = onp.array(data)
    dim = data.ndim
    flag = ~onp.isnan(data)

    slcs = [slice(None)] * dim

    while onp.any(~flag):
        for i in range(dim):
            slcs1 = slcs[:]
            slcs2 = slcs[:]
            slcs1[i] = slice(0, -1)
            slcs2[i] = slice(1, None)

            slcs1 = tuple(slcs1)
            slcs2 = tuple(slcs2)

            # replace from the right
            repmask = onp.logical_and(~flag[slcs1], flag[slcs2])
            data[slcs1][repmask] = data[slcs2][repmask]
            flag[slcs1][repmask] = True

            # replace from the left
            repmask = onp.logical_and(~flag[slcs2], flag[slcs1])
            data[slcs2][repmask] = data[slcs1][repmask]
            flag[slcs2][repmask] = True

    return npx.asarray(data)


def get_periodic_interval(current_time, cycle_length, rec_spacing, n_rec):
    """Used for linear interpolation between periodic time intervals.

    One common application is the interpolation of external forcings that are defined
    at discrete times (e.g. one value per month of a standard year) to the current
    time step.

    Arguments:
       current_time (float): Time to interpolate to.
       cycle_length (float): Total length of one periodic cycle.
       rec_spacing (float): Time spacing between each data record.
       n_rec (int): Total number of records available.

    Returns:
       :obj:`tuple` containing (n1, f1), (n2, f2): Indices and weights for the interpolated
       record array.

    Example:
       The following interpolates a record array ``data`` containing 12 monthly values
       to the current time step:

       >>> year_in_seconds = 60. * 60. * 24. * 365.
       >>> current_time = 60. * 60. * 24. * 45. # mid-february
       >>> print(data.shape)
       (360, 180, 12)
       >>> (n1, f1), (n2, f2) = get_periodic_interval(current_time, year_in_seconds, year_in_seconds / 12, 12)
       >>> data_at_current_time = f1 * data[..., n1] + f2 * data[..., n2]

    """
    current_time = current_time % cycle_length
    # using npx.array works with both NumPy and JAX
    t_idx_1 = npx.array(current_time // rec_spacing, dtype="int")
    t_idx_2 = npx.array((1 + t_idx_1) % n_rec, dtype="int")
    weight_2 = (current_time - rec_spacing * t_idx_1) / rec_spacing
    weight_1 = 1.0 - weight_2
    return (t_idx_1, weight_1), (t_idx_2, weight_2)


def make_cyclic(longitude, array=None, wrap=360.0):
    """Create a cyclic version of a longitude array and (optionally) another array.

    Arguments:
        longitude (ndarray): Longitude array of shape (nlon, ...).
        array (ndarray): Another array that is to be made cyclic of shape (nlon, ...).
        wrap (float): Wrapping value, defaults to 360 (degrees).

    Returns:
        Tuple containing (cyclic_longitudes, cyclic_array) if `array` is given, otherwise
        just the ndarray cyclic_longitudes of shape (2 * nlon, ...).

    """
    lonsize = longitude.shape[0]
    cyclic_longitudes = npx.hstack(
        (longitude[lonsize // 2 :, ...] - wrap, longitude, longitude[: lonsize // 2, ...] + wrap)
    )
    if array is None:
        return cyclic_longitudes
    cyclic_array = npx.hstack((array[lonsize // 2 :, ...], array, array[: lonsize // 2, ...]))
    return cyclic_longitudes, cyclic_array


def get_coastline_distance(coords, coast_mask, spherical=False, radius=None, num_candidates=None):
    """Calculate the (approximate) distance of each water cell from the nearest coastline.

    Arguments:
        coords (tuple of ndarrays): Tuple containing x and y (longitude and latitude)
            coordinate arrays of shape (nx, ny).
        coast_mask (ndarray): Boolean mask indicating whether a cell is a land cell
            (must be same shape as coordinate arrays).
        spherical (bool): Use spherical instead of Cartesian coordinates.
            When this is `True`, cyclical boundary conditions are used, and the
            resulting distances are only approximate. Cells are pre-sorted by
            Euclidean lon-lat distance, and great circle distances are calculated for
            the first `num_candidates` elements. Defaults to `False`.
        radius (float): Radius of spherical coordinate system. Must be given when
            `spherical` is `True`.
        num_candidates (int): Number of candidates to calculate great circle distances
            for for each water cell. The higher this value, the more accurate the returned
            distances become when `spherical` is `True`. Defaults to the square root
            of the number of coastal cells.

    Returns:
        :obj:`ndarray` of shape (nx, ny) indicating the distance to the nearest land
        cell (0 if cell is land).

    Example:
        The following returns coastal distances of all T cells for a spherical Veros setup.

        >>> coords = npx.meshgrid(vs.xt[2:-2], vs.yt[2:-2], indexing='ij')
        >>> dist = tools.get_coastline_distance(coords, vs.kbot > 0, spherical=True, radius=settings.radius)

    """
    if not len(coords) == 2:
        raise ValueError("coords must be lon-lat tuple")
    if not all(c.shape == coast_mask.shape for c in coords):
        raise ValueError("coordinates must have same shape as coastal mask")
    if spherical and not radius:
        raise ValueError("radius must be given for spherical coordinates")

    watercoords = onp.array([c[~coast_mask] for c in coords]).T
    if spherical:
        coastcoords = onp.array(make_cyclic(coords[0][coast_mask], coords[1][coast_mask])).T
    else:
        coastcoords = onp.array((coords[0][coast_mask], coords[1][coast_mask])).T
    coast_kdtree = scipy.spatial.cKDTree(coastcoords)

    distance = onp.zeros(coords[0].shape)

    if spherical:

        def spherical_distance(coords1, coords2):
            """Calculate great circle distance from latitude and longitude"""
            coords1 *= onp.pi / 180.0
            coords2 *= onp.pi / 180.0
            lon1, lon2, lat1, lat2 = coords1[..., 0], coords2[..., 0], coords1[..., 1], coords2[..., 1]
            return radius * onp.arccos(
                onp.sin(lat1) * onp.sin(lat2) + onp.cos(lat1) * onp.cos(lat2) * onp.cos(lon1 - lon2)
            )

        if not num_candidates:
            num_candidates = int(onp.sqrt(onp.count_nonzero(~coast_mask)))

        i_nearest = coast_kdtree.query(watercoords, k=num_candidates)[1]
        approx_nearest = coastcoords[i_nearest]
        distance[~coast_mask] = onp.min(spherical_distance(approx_nearest, watercoords[..., onp.newaxis, :]), axis=-1)

    else:
        distance[~coast_mask] = coast_kdtree.query(watercoords)[0]

    return npx.asarray(distance)


def get_uniform_grid_steps(total_length, stepsize):
    """Get uniform grid step sizes in an interval.

    Arguments:
        total_length (float): total length of the resulting grid
        stepsize (float): grid step size

    Returns:
        :obj:`ndarray` of grid steps

    Example:
        >>> uniform_steps = uniform_grid_setup(6., 0.25)
        >>> uniform_steps
        [ 0.25,  0.25,  0.25,  0.25,  0.25,  0.25,  0.25,  0.25,  0.25,
          0.25,  0.25,  0.25,  0.25,  0.25,  0.25,  0.25,  0.25,  0.25,
          0.25,  0.25,  0.25,  0.25,  0.25,  0.25 ]

    """
    if total_length % stepsize:
        raise ValueError("total length must be an integer multiple of stepsize")
    return stepsize * npx.ones(int(total_length / stepsize))


def get_stretched_grid_steps(
    n_cells, total_length, minimum_stepsize, stretching_factor=2.5, two_sided_grid=False, refine_towards="upper"
):
    """Computes stretched grid steps for regional and global domains with either
    one or two-sided stretching using a hyperbolic tangent stretching function.

    Arguments:
        n_cells (int): Number of grid points.
        total_length (float): Length of the grid interval to be covered (sum of the
            resulting grid steps).
        minimum_stepsize (float): Grid step size on the lower end of the interval.
        stretching_factor (float, optional): Coefficient of the `tanh` stretching
            function. The higher this value, the more abrupt the step sizes change.
        two_sided_grid (bool, optional): If set to `True`, the resulting grid will be symmetrical
            around the center. Defaults to `False`.
        refine_towards ('upper' or 'lower', optional): The side of the interval that is to be refined.
            Defaults to 'upper'.

    Returns:
        :obj:`ndarray` of shape `(n_cells)` containing grid steps.

    Examples:
        >>> dyt = get_stretched_grid_steps(14, 180, 5)
        >>> dyt
        [  5.10517337   5.22522948   5.47813251   5.99673813   7.00386752
           8.76808565  11.36450896  14.34977676  16.94620006  18.71041819
          19.71754758  20.2361532   20.48905624  20.60911234]
        >>> dyt.sum()
        180.0

        >>> dyt = get_stretched_grid_steps(14, 180, 5, stretching_factor=4.)
        >>> dyt
        [  5.00526979   5.01802837   5.06155549   5.20877528   5.69251688
           7.14225176  10.51307232  15.20121339  18.57203395  20.02176884
          20.50551044  20.65273022  20.69625734  20.70901593]
        >>> dyt.sum()
        180.0

    """

    if refine_towards not in ("upper", "lower"):
        raise ValueError('refine_towards must be "upper" or "lower"')
    if two_sided_grid:
        if n_cells % 2:
            raise ValueError(f"number of grid points must be even integer number (given: {n_cells})")
        n_cells = n_cells / 2

    stretching_function = npx.tanh(stretching_factor * npx.linspace(-1, 1, n_cells))

    if refine_towards == "lower":
        stretching_function = stretching_function[::-1]
    if two_sided_grid:
        stretching_function = npx.concatenate((stretching_function[::-1], stretching_function))

    def normalize_sum(var, sum_value, minimum_value=0.0):
        if abs(var.sum()) < 1e-5:
            var += 1
        var *= (sum_value - len(var) * minimum_value) / var.sum()
        return var + minimum_value

    stretching_function = normalize_sum(stretching_function, total_length, minimum_stepsize)
    assert abs(1 - npx.sum(stretching_function) / total_length) < 1e-5, "precision error"
    return stretching_function


def get_vinokur_grid_steps(
    n_cells, total_length, lower_stepsize, upper_stepsize=None, two_sided_grid=False, refine_towards="upper"
):
    """Computes stretched grid steps for regional and global domains with either
    one or two-sided stretching using Vinokur stretching.

    This stretching function minimizes discretization errors on finite difference
    grids.

    Arguments:
        n_cells (int): Number of grid points.
        total_length (float): Length of the grid interval to be covered (sum of the
            resulting grid steps).
        lower_stepsize (float): Grid step size on the lower end of the interval.
        upper_stepsize (float or ``None``, optional): Grid step size on the upper end of the interval.
            If not given, the one-sided version of the algorithm is used (that enforces zero curvature
            on the upper end).
        two_sided_grid (bool, optional): If set to `True`, the resulting grid will be symmetrical
            around the center. Defaults to `False`.
        refine_towards ('upper' or 'lower', optional): The side of the interval that is to be refined.
            Defaults to 'upper'.

    Returns:
        :obj:`ndarray` of shape `(n_cells)` containing grid steps.

    Reference:
        Vinokur, Marcel, On One-Dimensional Stretching Functions for Finite-Difference Calculations,
        Journal of Computational Physics. 50, 215, 1983.

    Examples:
        >>> dyt = get_vinokur_grid_steps(14, 180, 5, two_sided_grid=True)
        >>> dyt
        [ 18.2451554   17.23915939  15.43744632  13.17358802  10.78720589
           8.53852027   6.57892471   6.57892471   8.53852027  10.78720589
          13.17358802  15.43744632  17.23915939  18.2451554 ]
        >>> dyt.sum()
        180.

        >>> dyt = get_vinokur_grid_steps(14, 180, 5, upper_stepsize=10)
        >>> dyt
        [  5.9818365    7.3645667    8.92544833  10.61326984  12.33841985
          13.97292695  15.36197306  16.3485688   16.80714121  16.67536919
          15.97141714  14.78881918  13.27136448  11.57887877 ]
        >>> dyt.sum()
        180.

    """
    if refine_towards not in ("upper", "lower"):
        raise ValueError('refine_towards must be "upper" or "lower"')
    if two_sided_grid:
        if n_cells % 2:
            raise ValueError(f"number of grid points must be an even integer (given: {n_cells})")
        n_cells = n_cells // 2

    n_cells += 1

    def approximate_sinc_inverse(y):
        """Approximate inverse of sin(y) / y"""
        if y < 0.26938972:
            inv = npx.pi * (
                1
                - y
                + y**2
                - (1 + npx.pi**2 / 6) * y**3
                + 6.794732 * y**4
                - 13.205501 * y**5
                + 11.726095 * y**6
            )
        else:
            ybar = 1.0 - y
            inv = npx.sqrt(6 * ybar) * (
                1
                + 0.15 * ybar
                + 0.057321429 * ybar**2
                + 0.048774238 * ybar**3
                - 0.053337753 * ybar**4
                + 0.075845134 * ybar**5
            )
        assert abs(1 - npx.sin(inv) / inv / y) < 1e-2, "precision error"
        return inv

    def approximate_sinhc_inverse(y):
        """Approximate inverse of sinh(y) / y"""
        if y < 2.7829681:
            ybar = y - 1.0
            inv = npx.sqrt(6 * ybar) * (
                1
                - 0.15 * ybar
                + 0.057321429 * ybar**2
                - 0.024907295 * ybar**3
                + 0.0077424461 * ybar**4
                - 0.0010794123 * ybar**5
            )
        else:
            v = npx.log(y)
            w = 1.0 / y - 0.028527431
            inv = (
                v
                + (1 + 1.0 / v) * npx.log(2 * v)
                - 0.02041793
                + 0.24902722 * w
                + 1.9496443 * w**2
                - 2.6294547 * w**3
                + 8.56795911 * w**4
            )
        assert abs(1 - npx.sinh(inv) / inv / y) < 1e-2, "precision error"
        return inv

    target_sum = total_length
    if two_sided_grid:
        target_sum *= 0.5

    s0 = float(target_sum) / float(lower_stepsize * n_cells)
    if upper_stepsize:
        s1 = float(target_sum) / float(upper_stepsize * n_cells)
        a, b = npx.sqrt(s1 / s0), npx.sqrt(s1 * s0)
        if b > 1:
            stretching_factor = approximate_sinhc_inverse(b)
            stretched_grid = 0.5 + 0.5 * npx.tanh(stretching_factor * npx.linspace(-0.5, 0.5, n_cells)) / npx.tanh(
                0.5 * stretching_factor
            )
        else:
            stretching_factor = approximate_sinc_inverse(b)
            stretched_grid = 0.5 + 0.5 * npx.tan(stretching_factor * npx.linspace(-0.5, 0.5, n_cells)) / npx.tan(
                0.5 * stretching_factor
            )
        stretched_grid = stretched_grid / (a + (1.0 - a) * stretched_grid)
    else:
        if s0 > 1:
            stretching_factor = approximate_sinhc_inverse(s0) * 0.5
            stretched_grid = 1 + npx.tanh(stretching_factor * npx.linspace(0.0, 1.0, n_cells)) / npx.tanh(
                stretching_factor
            )
        else:
            stretching_factor = approximate_sinc_inverse(s0) * 0.5
            stretched_grid = 1 + npx.tan(stretching_factor * npx.linspace(0.0, 1.0, n_cells)) / npx.tan(
                stretching_factor
            )

    stretched_grid_steps = npx.diff(stretched_grid * target_sum)
    if refine_towards == "upper":
        stretched_grid_steps = stretched_grid_steps[::-1]
    if two_sided_grid:
        stretched_grid_steps = npx.concatenate((stretched_grid_steps[::-1], stretched_grid_steps))

    assert abs(1 - npx.sum(stretched_grid_steps) / total_length) < 1e-5, "precision error"
    return stretched_grid_steps