island.py 3.75 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import scipy.ndimage

from veros import veros_routine, logger
from veros.core import utilities
from veros.core.operators import numpy as npx

# fall back to vanilla NumPy for some operations
import numpy as onp


def _compute_isleperim(kmt, enable_cyclic_x):
    # TODO: remove this check after jax#6907 has landed
    if enable_cyclic_x:
        kmt = utilities.enforce_boundaries(kmt, enable_cyclic_x)

    kmt = onp.asarray(kmt)

    structure = onp.ones((3, 3))  # merge diagonally connected land masses

    # find all land masses
    labelled, _ = scipy.ndimage.label(kmt == 0, structure=structure)

    # find and set perimeter
    land_masses = labelled > 0
    inner = scipy.ndimage.binary_dilation(land_masses, structure=structure)

    perimeter = onp.logical_xor(inner, land_masses)
    labelled[perimeter] = -1

    # match wrapping periodic land masses
    if enable_cyclic_x:
        west_slice = onp.array(labelled[2])
        east_slice = onp.array(labelled[-2])

        for west_label in onp.unique(west_slice[west_slice > 0]):
            east_labels = onp.unique(east_slice[west_slice == west_label])
            east_labels = east_labels[~onp.isin(east_labels, [west_label, -1])]
            if not east_labels.size:
                # already labelled correctly
                continue
            assert len(onp.unique(east_labels)) == 1, (west_label, east_labels)
            labelled[labelled == east_labels[0]] = west_label

    # TODO: remove this check after jax#6907 has landed
    if enable_cyclic_x:
        labelled = utilities.enforce_boundaries(labelled, enable_cyclic_x)

    labelled = onp.asarray(labelled)

    # label landmasses in a way that is consistent with pyom
    labels = onp.unique(labelled[labelled > 0])

    label_idx = {}
    for label in labels:
        # find index of first island cell, scanning west to east, north to south
        label_idx[label] = onp.argmax(labelled[:, ::-1].T == label)

    sorted_labels = list(sorted(labels, key=lambda i: label_idx[i]))

    # ensure labels are numbered consecutively
    relabelled = onp.array(labelled)
    for new_label, label in enumerate(sorted_labels, 1):
        if label == new_label:
            continue
        relabelled[labelled == label] = new_label

    return npx.asarray(relabelled)


@veros_routine(dist_safe=False, local_variables=("kbot", "land_map"))
def isleperim(state):
    vs = state.variables
    settings = state.settings

    logger.debug(" Determining number of land masses")
    vs.land_map = _compute_isleperim(vs.kbot, settings.enable_cyclic_x)

    if vs.land_map.size < 10_000:
        logger.debug(_ascii_map(vs.land_map))


def _ascii_map(boundary_map):
    def _get_char(c):
        if c == 0:
            return "."
        if c < 0:
            return "#"
        return str(c % 10)

    boundary_map = onp.array(boundary_map)
    nx, ny = boundary_map.shape

    map_string = ""
    linewidth = 100
    iremain = nx
    istart = 0
    map_string += "\n"
    map_string += " " * (5 + min(linewidth, nx) // 2 - 13) + "Land mass and perimeter"
    map_string += "\n"
    for _ in range(1, nx // linewidth + 2):
        iline = min(iremain, linewidth)
        iremain = iremain - iline
        if iline > 0:
            map_string += "\n"
            map_string += "".join([f"{istart + i + 1 - 2:5d}" for i in range(1, iline + 1, 5)])
            map_string += "\n"
            for j in range(ny - 1, -1, -1):
                map_string += f"{j:3d} "
                map_string += "".join([_get_char(boundary_map[istart + i - 2, j]) for i in range(2, iline + 2)])
                map_string += "\n"
            map_string += "".join([f"{istart + i + 1 - 2:5d}" for i in range(1, iline + 1, 5)])
            map_string += "\n"
            istart = istart + iline
    map_string += "\n"

    return map_string