import functools

from veros import runtime_settings as rs, runtime_state as rst
from veros.routines import CURRENT_CONTEXT

# grid dimensions that are distributed across processes: (x-like dims, y-like dims)
SCATTERED_DIMENSIONS = (("xt", "xu"), ("yt", "yu"))


# decorator that turns distributed primitives into no-ops when running on a
# single process or outside a distribution-safe context
def dist_context_only(function=None, *, noop_return_arg=None):
    def decorator(function):
        @functools.wraps(function)
        def dist_context_only_wrapper(*args, **kwargs):
            if rst.proc_num == 1 or not CURRENT_CONTEXT.is_dist_safe:
                # no-op for sequential execution
                if noop_return_arg is None:
                    return None

                # return input array unchanged
                return args[noop_return_arg]

            return function(*args, **kwargs)

        return dist_context_only_wrapper

    if function is not None:
        return decorator(function)

    return decorator


def send(buf, dest, comm, tag=None):
    kwargs = {}
    if tag is not None:
        kwargs.update(tag=tag)

    if rs.backend == "jax":
        from mpi4jax import send

        # mpi4jax threads a token through successive calls to preserve operation order
        token = CURRENT_CONTEXT.mpi4jax_token
        new_token = send(buf, dest=dest, comm=comm, token=token, **kwargs)
        CURRENT_CONTEXT.mpi4jax_token = new_token
    else:
        comm.Send(ascontiguousarray(buf), dest=dest, **kwargs)


def recv(buf, source, comm, tag=None):
    kwargs = {}
    if tag is not None:
        kwargs.update(tag=tag)

    if rs.backend == "jax":
        from mpi4jax import recv

        token = CURRENT_CONTEXT.mpi4jax_token
        buf, new_token = recv(buf, source=source, comm=comm, token=token, **kwargs)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return buf

    buf = buf.copy()
    comm.Recv(buf, source=source, **kwargs)
    return buf


def sendrecv(sendbuf, recvbuf, source, dest, comm, sendtag=None, recvtag=None):
    kwargs = {}
    if sendtag is not None:
        kwargs.update(sendtag=sendtag)
    if recvtag is not None:
        kwargs.update(recvtag=recvtag)

    if rs.backend == "jax":
        from mpi4jax import sendrecv

        token = CURRENT_CONTEXT.mpi4jax_token
        recvbuf, new_token = sendrecv(sendbuf, recvbuf, source=source, dest=dest, comm=comm, token=token, **kwargs)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return recvbuf

    recvbuf = recvbuf.copy()
    comm.Sendrecv(sendbuf=ascontiguousarray(sendbuf), recvbuf=recvbuf, source=source, dest=dest, **kwargs)
    return recvbuf


def bcast(buf, comm, root=0):
    if rs.backend == "jax":
        from mpi4jax import bcast

        token = CURRENT_CONTEXT.mpi4jax_token
        buf, new_token = bcast(buf, root=root, comm=comm, token=token)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return buf

    return comm.bcast(buf, root=root)


def allreduce(buf, op, comm):
    if rs.backend == "jax":
        from mpi4jax import allreduce

        token = CURRENT_CONTEXT.mpi4jax_token
        buf, new_token = allreduce(buf, op=op, comm=comm, token=token)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return buf

    from veros.core.operators import numpy as npx

    recvbuf = npx.empty_like(buf)
    comm.Allreduce(ascontiguousarray(buf), recvbuf, op=op)
    return recvbuf


def ascontiguousarray(arr):
    # mpi4py needs C-contiguous buffers; only used with the NumPy backend
    assert rs.backend == "numpy"
    import numpy

    return numpy.ascontiguousarray(arr)


def validate_decomposition(dimensions):
    nx, ny = dimensions["xt"], dimensions["yt"]

    if rs.mpi_comm is None:
        if rs.num_proc[0] > 1 or rs.num_proc[1] > 1:
            raise RuntimeError("mpi4py is required for distributed execution")
        return

    comm_size = rs.mpi_comm.Get_size()
    proc_num = rs.num_proc[0] * rs.num_proc[1]
    if proc_num != comm_size:
        raise RuntimeError(f"number of processes ({proc_num}) does not match size of communicator ({comm_size})")

    if nx % rs.num_proc[0]:
        raise ValueError("processes do not divide domain evenly in x-direction")

    if ny % rs.num_proc[1]:
        raise ValueError("processes do not divide domain evenly in y-direction")


def get_chunk_size(nx, ny):
    return (nx // rs.num_proc[0], ny // rs.num_proc[1])
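
# Illustrative sketch (not part of the original module): how a valid decomposition
# looks in numbers. Assumes a hypothetical 360 x 160 global domain on a 4 x 2
# process grid, i.e. rs.num_proc == (4, 2) with 8 MPI ranks; the function below is
# only an example and is never called by Veros itself.
def _example_chunk_size():
    dimensions = {"xt": 360, "yt": 160}
    # raises if the process grid does not match the communicator size
    # or does not divide the domain evenly
    validate_decomposition(dimensions)
    # each process owns a 90 x 80 chunk (plus a 2-cell halo on every side)
    return get_chunk_size(dimensions["xt"], dimensions["yt"])  # -> (90, 80)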

def proc_rank_to_index(rank):
    return (rank % rs.num_proc[0], rank // rs.num_proc[0])


def proc_index_to_rank(ix, iy):
    return ix + iy * rs.num_proc[0]


def get_chunk_slices(nx, ny, dim_grid, proc_idx=None, include_overlap=False):
    if not dim_grid:
        return Ellipsis, Ellipsis

    if proc_idx is None:
        proc_idx = proc_rank_to_index(rst.proc_rank)

    px, py = proc_idx
    nxl, nyl = get_chunk_size(nx, ny)

    if include_overlap:
        sxl = 0 if px == 0 else 2
        sxu = nxl + 4 if (px + 1) == rs.num_proc[0] else nxl + 2
        syl = 0 if py == 0 else 2
        syu = nyl + 4 if (py + 1) == rs.num_proc[1] else nyl + 2
    else:
        sxl = syl = 0
        sxu = nxl
        syu = nyl

    global_slice, local_slice = [], []

    for dim in dim_grid:
        if dim in SCATTERED_DIMENSIONS[0]:
            global_slice.append(slice(sxl + px * nxl, sxu + px * nxl))
            local_slice.append(slice(sxl, sxu))
        elif dim in SCATTERED_DIMENSIONS[1]:
            global_slice.append(slice(syl + py * nyl, syu + py * nyl))
            local_slice.append(slice(syl, syu))
        else:
            global_slice.append(slice(None))
            local_slice.append(slice(None))

    return tuple(global_slice), tuple(local_slice)


def get_process_neighbors(cyclic=False):
    this_x, this_y = proc_rank_to_index(rst.proc_rank)

    if this_x == 0:
        if cyclic:
            west = rs.num_proc[0] - 1
        else:
            west = None
    else:
        west = this_x - 1

    if this_x == rs.num_proc[0] - 1:
        if cyclic:
            east = 0
        else:
            east = None
    else:
        east = this_x + 1

    south = this_y - 1 if this_y != 0 else None
    north = this_y + 1 if this_y != (rs.num_proc[1] - 1) else None

    neighbors = dict(
        # direct neighbors
        west=(west, this_y),
        south=(this_x, south),
        east=(east, this_y),
        north=(this_x, north),
        # corners
        southwest=(west, south),
        southeast=(east, south),
        northeast=(east, north),
        northwest=(west, north),
    )

    global_neighbors = {k: proc_index_to_rank(*i) if None not in i else None for k, i in neighbors.items()}
    return global_neighbors
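
# Illustrative sketch (not part of the original module): rank <-> process-index
# mapping and neighbor lookup, again assuming a hypothetical 4 x 2 process grid
# (rs.num_proc == (4, 2)). Ranks are laid out row-major along x, so rank 5 sits at
# process index (1, 1); since it is in the last process row, its "north" neighbor
# is None when the domain is not cyclic.
def _example_rank_mapping():
    assert proc_rank_to_index(5) == (1, 1)
    assert proc_index_to_rank(1, 1) == 5
    # on rank 5 this returns {"west": 4, "east": 6, "south": 1, "north": None, ...}
    return get_process_neighbors(cyclic=False)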

@dist_context_only(noop_return_arg=0)
def exchange_overlap(arr, var_grid, cyclic):
    from veros.core.operators import numpy as npx, update, at

    # start west, go clockwise
    send_order = (
        "west",
        "northwest",
        "north",
        "northeast",
        "east",
        "southeast",
        "south",
        "southwest",
    )

    # start east, go clockwise
    recv_order = (
        "east",
        "southeast",
        "south",
        "southwest",
        "west",
        "northwest",
        "north",
        "northeast",
    )

    if len(var_grid) < 2:
        d1, d2 = var_grid[0], None
    else:
        d1, d2 = var_grid[:2]

    if d1 not in SCATTERED_DIMENSIONS[0] and d1 not in SCATTERED_DIMENSIONS[1] and d2 not in SCATTERED_DIMENSIONS[1]:
        # neither x nor y dependent, nothing to do
        return arr

    proc_neighbors = get_process_neighbors(cyclic)

    if d1 in SCATTERED_DIMENSIONS[0] and d2 in SCATTERED_DIMENSIONS[1]:
        overlap_slices_from = dict(
            west=(slice(2, 4), slice(0, None), Ellipsis),
            south=(slice(0, None), slice(2, 4), Ellipsis),
            east=(slice(-4, -2), slice(0, None), Ellipsis),
            north=(slice(0, None), slice(-4, -2), Ellipsis),
            southwest=(slice(2, 4), slice(2, 4), Ellipsis),
            southeast=(slice(-4, -2), slice(2, 4), Ellipsis),
            northeast=(slice(-4, -2), slice(-4, -2), Ellipsis),
            northwest=(slice(2, 4), slice(-4, -2), Ellipsis),
        )

        overlap_slices_to = dict(
            west=(slice(0, 2), slice(0, None), Ellipsis),
            south=(slice(0, None), slice(0, 2), Ellipsis),
            east=(slice(-2, None), slice(0, None), Ellipsis),
            north=(slice(0, None), slice(-2, None), Ellipsis),
            southwest=(slice(0, 2), slice(0, 2), Ellipsis),
            southeast=(slice(-2, None), slice(0, 2), Ellipsis),
            northeast=(slice(-2, None), slice(-2, None), Ellipsis),
            northwest=(slice(0, 2), slice(-2, None), Ellipsis),
        )
    else:
        if d1 in SCATTERED_DIMENSIONS[0]:
            send_order = ("west", "east")
            recv_order = ("east", "west")
        elif d1 in SCATTERED_DIMENSIONS[1]:
            send_order = ("north", "south")
            recv_order = ("south", "north")
        else:
            raise NotImplementedError()

        overlap_slices_from = dict(
            west=(slice(2, 4), Ellipsis),
            south=(slice(2, 4), Ellipsis),
            east=(slice(-4, -2), Ellipsis),
            north=(slice(-4, -2), Ellipsis),
        )

        overlap_slices_to = dict(
            west=(slice(0, 2), Ellipsis),
            south=(slice(0, 2), Ellipsis),
            east=(slice(-2, None), Ellipsis),
            north=(slice(-2, None), Ellipsis),
        )

    for send_dir, recv_dir in zip(send_order, recv_order):
        send_proc = proc_neighbors[send_dir]
        recv_proc = proc_neighbors[recv_dir]

        if send_proc is None and recv_proc is None:
            continue

        recv_idx = overlap_slices_to[recv_dir]
        recv_arr = npx.empty_like(arr[recv_idx])

        send_idx = overlap_slices_from[send_dir]
        send_arr = arr[send_idx]

        if send_proc is None:
            recv_arr = recv(recv_arr, recv_proc, rs.mpi_comm)
            arr = update(arr, at[recv_idx], recv_arr)
        elif recv_proc is None:
            send(send_arr, send_proc, rs.mpi_comm)
        else:
            recv_arr = sendrecv(send_arr, recv_arr, source=recv_proc, dest=send_proc, comm=rs.mpi_comm)
            arr = update(arr, at[recv_idx], recv_arr)

    return arr


def _memoize(function):
    cached = {}

    @functools.wraps(function)
    def memoized(*args):
        from mpi4py import MPI

        # MPI Comms are not hashable, so we use the underlying handle instead
        cache_args = tuple(MPI._handleof(arg) if isinstance(arg, MPI.Comm) else arg for arg in args)

        if cache_args not in cached:
            cached[cache_args] = function(*args)

        return cached[cache_args]

    return memoized


@_memoize
def _mpi_comm_along_axis(comm, procs, rank):
    return comm.Split(procs, rank)


@dist_context_only(noop_return_arg=0)
def _reduce(arr, op, axis=None):
    from veros.core.operators import numpy as npx

    if axis is None:
        comm = rs.mpi_comm
    else:
        assert axis in (0, 1)
        pi = proc_rank_to_index(rst.proc_rank)
        other_axis = 1 - axis
        comm = _mpi_comm_along_axis(rs.mpi_comm, pi[other_axis], rst.proc_rank)

    if npx.isscalar(arr):
        squeeze = True
        arr = npx.array([arr])
    else:
        squeeze = False

    res = allreduce(arr, op=op, comm=comm)

    if squeeze:
        res = res[0]

    return res


@dist_context_only(noop_return_arg=0)
def global_and(arr, axis=None):
    from mpi4py import MPI

    return _reduce(arr, MPI.LAND, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_or(arr, axis=None):
    from mpi4py import MPI

    return _reduce(arr, MPI.LOR, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_max(arr, axis=None):
    from mpi4py import MPI

    return _reduce(arr, MPI.MAX, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_min(arr, axis=None):
    from mpi4py import MPI

    return _reduce(arr, MPI.MIN, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_sum(arr, axis=None):
    from mpi4py import MPI

    return _reduce(arr, MPI.SUM, axis=axis)
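
# Illustrative sketch (not part of the original module): the global_* reductions
# accept scalars or arrays. Passing `axis` restricts the reduction to one axis of
# the process grid (axis=0 reduces across processes that differ in their x-index,
# axis=1 across processes that differ in their y-index). The variable names below
# are hypothetical.
def _example_reductions(local_energy, local_cfl):
    total_energy = global_sum(local_energy)  # sum over every process
    max_cfl = global_max(local_cfl)  # maximum over every process
    row_energy = global_sum(local_energy, axis=0)  # sum along x only
    return total_energy, max_cfl, row_energy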

@dist_context_only(noop_return_arg=2)
def _gather_1d(nx, ny, arr, dim):
    from veros.core.operators import numpy as npx, update, at

    assert dim in (0, 1)

    otherdim = 1 - dim
    pi = proc_rank_to_index(rst.proc_rank)
    if pi[otherdim] != 0:
        return arr

    dim_grid = ["xt" if dim == 0 else "yt"] + [None] * (arr.ndim - 1)
    gidx, idx = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    sendbuf = arr[idx]

    if rst.proc_rank == 0:
        buffer_list = []
        for proc in range(1, rst.proc_num):
            pi = proc_rank_to_index(proc)
            if pi[otherdim] != 0:
                continue

            idx_g, idx_l = get_chunk_slices(nx, ny, dim_grid, include_overlap=True, proc_idx=pi)
            recvbuf = npx.empty_like(arr[idx_l])
            recvbuf = recv(recvbuf, source=proc, tag=20, comm=rs.mpi_comm)
            buffer_list.append((idx_g, recvbuf))

        out_shape = ((nx + 4, ny + 4)[dim],) + arr.shape[1:]
        out = npx.empty(out_shape, dtype=arr.dtype)
        out = update(out, at[gidx], sendbuf)

        for idx, val in buffer_list:
            out = update(out, at[idx], val)

        return out
    else:
        send(sendbuf, dest=0, tag=20, comm=rs.mpi_comm)
        return arr


@dist_context_only(noop_return_arg=2)
def _gather_xy(nx, ny, arr):
    from veros.core.operators import numpy as npx, update, at

    nxi, nyi = get_chunk_size(nx, ny)
    assert arr.shape[:2] == (nxi + 4, nyi + 4), arr.shape

    dim_grid = ["xt", "yt"] + [None] * (arr.ndim - 2)
    gidx, idx = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    sendbuf = arr[idx]

    if rst.proc_rank == 0:
        buffer_list = []
        for proc in range(1, rst.proc_num):
            idx_g, idx_l = get_chunk_slices(nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc))
            recvbuf = npx.empty_like(arr[idx_l])
            recvbuf = recv(recvbuf, source=proc, tag=30, comm=rs.mpi_comm)
            buffer_list.append((idx_g, recvbuf))

        out_shape = (nx + 4, ny + 4) + arr.shape[2:]
        out = npx.empty(out_shape, dtype=arr.dtype)
        out = update(out, at[gidx], sendbuf)

        for idx, val in buffer_list:
            out = update(out, at[idx], val)

        return out

    send(sendbuf, dest=0, tag=30, comm=rs.mpi_comm)
    return arr


@dist_context_only(noop_return_arg=0)
def gather(arr, dimensions, var_grid):
    nx, ny = dimensions["xt"], dimensions["yt"]

    if var_grid is None:
        return arr

    if len(var_grid) < 2:
        d1, d2 = var_grid[0], None
    else:
        d1, d2 = var_grid[:2]

    if d1 not in SCATTERED_DIMENSIONS[0] and d1 not in SCATTERED_DIMENSIONS[1] and d2 not in SCATTERED_DIMENSIONS[1]:
        # neither x nor y dependent, nothing to do
        return arr

    if d1 in SCATTERED_DIMENSIONS[0] and d2 not in SCATTERED_DIMENSIONS[1]:
        # only x dependent
        return _gather_1d(nx, ny, arr, 0)
    elif d1 in SCATTERED_DIMENSIONS[1]:
        # only y dependent
        return _gather_1d(nx, ny, arr, 1)
    elif d1 in SCATTERED_DIMENSIONS[0] and d2 in SCATTERED_DIMENSIONS[1]:
        # x and y dependent
        return _gather_xy(nx, ny, arr)
    else:
        raise NotImplementedError()
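
# Illustrative sketch (not part of the original module): collect a distributed
# field on rank 0. The names are hypothetical; `var_grid` decides whether the
# x-axis, the y-axis, or both are gathered, and any trailing dimensions (e.g.
# depth) are passed through untouched.
def _example_gather(local_salt, dimensions):
    # on rank 0 this returns the full (nx + 4, ny + 4, ...) array;
    # every other rank gets its local chunk back unchanged
    return gather(local_salt, dimensions, ("xt", "yt", "zt"))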

@dist_context_only(noop_return_arg=0)
def _scatter_constant(arr):
    return bcast(arr, rs.mpi_comm, root=0)


@dist_context_only(noop_return_arg=2)
def _scatter_1d(nx, ny, arr, dim):
    from veros.core.operators import numpy as npx, update, at

    assert dim in (0, 1)

    out_nx = get_chunk_size(nx, ny)[dim]
    dim_grid = ["xt" if dim == 0 else "yt"] + [None] * (arr.ndim - 1)
    _, local_slice = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)

    if rst.proc_rank == 0:
        recvbuf = arr[local_slice]

        for proc in range(1, rst.proc_num):
            global_slice, _ = get_chunk_slices(
                nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc)
            )
            sendbuf = arr[global_slice]
            send(sendbuf, dest=proc, tag=40, comm=rs.mpi_comm)

        # arr changes shape in main process
        arr = npx.zeros((out_nx + 4,) + arr.shape[1:], dtype=arr.dtype)
    else:
        recvbuf = recv(arr[local_slice], source=0, tag=40, comm=rs.mpi_comm)

    arr = update(arr, at[local_slice], recvbuf)
    arr = exchange_overlap(arr, ["xt" if dim == 0 else "yt"], cyclic=False)

    return arr


@dist_context_only(noop_return_arg=2)
def _scatter_xy(nx, ny, arr):
    from veros.core.operators import numpy as npx, update, at

    nxi, nyi = get_chunk_size(nx, ny)

    dim_grid = ["xt", "yt"] + [None] * (arr.ndim - 2)
    _, local_slice = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)

    if rst.proc_rank == 0:
        recvbuf = arr[local_slice]

        for proc in range(1, rst.proc_num):
            global_slice, _ = get_chunk_slices(
                nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc)
            )
            sendbuf = arr[global_slice]
            send(sendbuf, dest=proc, tag=50, comm=rs.mpi_comm)

        # arr changes shape in main process
        arr = npx.empty((nxi + 4, nyi + 4) + arr.shape[2:], dtype=arr.dtype)
    else:
        recvbuf = npx.empty_like(arr[local_slice])
        recvbuf = recv(recvbuf, source=0, tag=50, comm=rs.mpi_comm)

    arr = update(arr, at[local_slice], recvbuf)
    arr = exchange_overlap(arr, ["xt", "yt"], cyclic=False)

    return arr


@dist_context_only(noop_return_arg=0)
def scatter(arr, dimensions, var_grid):
    from veros.core.operators import numpy as npx

    if var_grid is None:
        return _scatter_constant(arr)

    nx, ny = dimensions["xt"], dimensions["yt"]

    if len(var_grid) < 2:
        d1, d2 = var_grid[0], None
    else:
        d1, d2 = var_grid[:2]

    arr = npx.asarray(arr)

    if d1 not in SCATTERED_DIMENSIONS[0] and d1 not in SCATTERED_DIMENSIONS[1] and d2 not in SCATTERED_DIMENSIONS[1]:
        # neither x nor y dependent
        return _scatter_constant(arr)

    if d1 in SCATTERED_DIMENSIONS[0] and d2 not in SCATTERED_DIMENSIONS[1]:
        # only x dependent
        return _scatter_1d(nx, ny, arr, 0)
    elif d1 in SCATTERED_DIMENSIONS[1]:
        # only y dependent
        return _scatter_1d(nx, ny, arr, 1)
    elif d1 in SCATTERED_DIMENSIONS[0] and d2 in SCATTERED_DIMENSIONS[1]:
        # x and y dependent
        return _scatter_xy(nx, ny, arr)
    else:
        raise NotImplementedError("unreachable")


@dist_context_only
def barrier():
    rs.mpi_comm.barrier()


@dist_context_only
def abort():
    rs.mpi_comm.Abort()
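
# Illustrative sketch (not part of the original module): the usual life cycle of a
# distributed variable, with hypothetical names. A global array is scattered from
# rank 0 into per-process chunks, halos are refreshed after local updates, and the
# field is gathered back for output.
def _example_scatter_gather_roundtrip(global_field, dimensions):
    var_grid = ("xt", "yt", "zt")
    local_field = scatter(global_field, dimensions, var_grid)
    # ... compute on the local chunk, then refresh the 2-cell halos ...
    local_field = exchange_overlap(local_field, ("xt", "yt"), cyclic=True)
    return gather(local_field, dimensions, var_grid)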