Commit ef746cfa authored by mashun1's avatar mashun1
Browse files

veros

parents
Pipeline #1302 canceled with stages
[tool.black]
line-length = 120
target-version = ['py37', 'py38', 'py39', 'py310']
[tool.pytest.ini_options]
testpaths = [
"test",
]
[build-system]
requires = ["setuptools>=65.0.0", "wheel", "cython"]
build-backend = "setuptools.build_meta"
click==8.1.7
entrypoints==0.4
requests==2.32.3
numpy==2.0.0
scipy==1.13.1
h5netcdf==1.3.0
h5py==3.11.0
pillow==10.3.0
loguru==0.7.2
tqdm==4.66.4
xarray
matplotlib
netcdf4
cmocean
cython
\ No newline at end of file
#! /usr/bin/env python
import sys
import os
import subprocess
import multiprocessing
import importlib.util
import re
import time
import math
import itertools
import json
import click
import numpy as np
"""
Runs selected Veros benchmarks back to back and writes timing results to a JSON file.
"""
# Directory containing the *_benchmark.py scripts, located next to this file.
TESTDIR = os.path.join(os.path.dirname(__file__), os.path.relpath("benchmarks"))

# All numerical backend components that can be benchmarked.
COMPONENTS = ["numpy", "numpy-mpi", "jax", "jax-gpu", "jax-mpi", "jax-gpu-mpi", "fortran", "fortran-mpi"]

# Settings appended to every benchmark invocation.
STATIC_SETTINGS = " --size {nx} {ny} {nz} --timesteps {timesteps} --float-type {float_type}"

# Command templates for direct execution. `{filename}` is the benchmark script
# path (restored here: the templates were missing the placeholder even though
# `run` supplies a "filename" value to .format()).
BENCHMARK_COMMANDS = {
    "numpy": "{python} {filename}" + STATIC_SETTINGS,
    "numpy-mpi": "OMP_NUM_THREADS=1 {mpiexec} -n {nproc} {python} {filename} --nproc {decomp}" + STATIC_SETTINGS,
    "jax": "{python} {filename} -b jax" + STATIC_SETTINGS,
    "jax-gpu": "{python} {filename} -b jax --device gpu" + STATIC_SETTINGS,
    "jax-mpi": "OMP_NUM_THREADS=1 {mpiexec} -n {nproc} {python} {filename} -b jax --nproc {decomp}" + STATIC_SETTINGS,
    "jax-gpu-mpi": "OMP_NUM_THREADS=1 {mpiexec} -n {nproc} {python} {filename} -b jax --device gpu --nproc {decomp}"
    + STATIC_SETTINGS,
    "fortran": "{python} {filename} --pyom2-lib {pyom2_lib}" + STATIC_SETTINGS,
    "fortran-mpi": "{mpiexec} -n {nproc} {python} {filename} --pyom2-lib {pyom2_lib} --nproc {decomp}"
    + STATIC_SETTINGS,
}

# Command templates when scheduling through SLURM (mpiexec is srun there).
SLURM_COMMANDS = {
    "numpy": "{mpiexec} --ntasks 1 --cpus-per-task {nproc} -- {python} {filename} -b numpy" + STATIC_SETTINGS,
    "numpy-mpi": "{mpiexec} --ntasks {nproc} --cpus-per-task 1 -- {python} {filename} -b numpy --nproc {decomp}"
    + STATIC_SETTINGS,
    "jax": "{mpiexec} --ntasks 1 --cpus-per-task {nproc} -- {python} {filename} -b jax" + STATIC_SETTINGS,
    "jax-gpu": "{mpiexec} --ntasks 1 --cpus-per-task {nproc} -- {python} {filename} -b jax --device gpu"
    + STATIC_SETTINGS,
    "jax-mpi": "{mpiexec} --ntasks {nproc} --cpus-per-task 1 -- {python} {filename} -b jax --nproc {decomp}"
    + STATIC_SETTINGS,
    "jax-gpu-mpi": "{mpiexec} --ntasks {nproc} --cpus-per-task 1 -- {python} {filename} -b jax --device gpu --nproc {decomp}"
    + STATIC_SETTINGS,
    "fortran": "{mpiexec} --ntasks 1 -- {python} {filename} --pyom2-lib {pyom2_lib}" + STATIC_SETTINGS,
    "fortran-mpi": "{mpiexec} --ntasks {nproc} --cpus-per-task 1 -- {python} {filename} --pyom2-lib {pyom2_lib} --nproc {decomp}"
    + STATIC_SETTINGS,
}

AVAILABLE_BENCHMARKS = [f for f in os.listdir(TESTDIR) if f.endswith("_benchmark.py")]

# Matches the per-iteration timing lines printed by the benchmark scripts,
# capturing the elapsed seconds (plain or scientific notation).
TIME_PATTERN = r"Time step took ([-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)s"
def check_arguments(pyom2_lib, components, float_type, burnin, timesteps, **kwargs):
    """Validate the benchmark CLI options, raising click.UsageError on bad combinations."""
    fortran_version = check_pyom2_lib(pyom2_lib)
    wants_fortran = "fortran" in components or "fortran-mpi" in components

    if wants_fortran:
        if not pyom2_lib:
            raise click.UsageError("Path to fortran library must be given when running fortran components")
        if not fortran_version:
            raise click.UsageError("Fortran library failed to import")
        if fortran_version != "parallel" and "fortran-mpi" in components:
            raise click.UsageError("Fortran library must be compiled with MPI support for fortran-mpi component")

    if float_type != "float64" and wants_fortran:
        raise click.UsageError('Can run Fortran components only with "float64" float type')

    if burnin >= timesteps:
        raise click.UsageError("burnin must be smaller than number of timesteps")
def check_pyom2_lib(path):
    """Probe a PyOM2 shared library.

    Returns "sequential" or "parallel" depending on which entry module can be
    imported from `path`, or None when no path was given or the import fails.
    """
    if not path:
        return None

    def _importable(module_name):
        # importing the extension is the only reliable way to check it works
        spec = importlib.util.spec_from_file_location(module_name, path)
        try:
            importlib.util.module_from_spec(spec)
        except ImportError:
            return False
        return True

    for module_name, flavor in (("pyOM_code", "sequential"), ("pyOM_code_MPI", "parallel")):
        if _importable(module_name):
            return flavor

    return None
def _factorize(num):
j = 2
while num > 1:
for i in range(j, int(math.sqrt(num + 0.05)) + 1):
if num % i == 0:
num /= i
j = i
yield i
break
else:
if num > 1:
yield num
break
def _decompose_num(num, into=2):
    """Distribute the prime factors of `num` over `into` buckets round-robin.

    Returns a tuple of `into` integers whose product is `num` (for num >= 1).
    """
    buckets = [1] * into
    slots = itertools.cycle(range(into))
    for factor in _factorize(num):
        buckets[next(slots)] *= factor
    return tuple(int(b) for b in buckets)
def _round_to_multiple(num, divisor):
return int(round(num / divisor) * divisor)
@click.command("veros-benchmarks", help="Run Veros benchmarks")
@click.option("-f", "--pyom2-lib", type=str, help="Path to PyOM2 fortran library")
@click.option(
    "-s", "--sizes", multiple=True, type=float, required=True, help="Problem sizes to test (total number of elements)"
)
@click.option(
    "-c",
    "--components",
    multiple=True,
    type=click.Choice(COMPONENTS),
    default=["numpy"],
    metavar="COMPONENT",
    help="Numerical backend components to benchmark (possible values: {})".format(", ".join(COMPONENTS)),
)
@click.option(
    "-n",
    "--nproc",
    type=int,
    default=multiprocessing.cpu_count(),
    help="Number of processes / threads for parallel execution",
)
@click.option(
    "-o",
    "--outfile",
    type=click.Path(exists=False),
    # NOTE: evaluated once at import time, so every run in the same process
    # shares the same default file name
    default="benchmark_{}.json".format(time.time()),
    help="JSON file to write timings to",
)
@click.option("-t", "--timesteps", default=100, type=int, help="Number of time steps that each benchmark is run for")
@click.option(
    "--only",
    multiple=True,
    default=AVAILABLE_BENCHMARKS,
    help="Run only these benchmarks (possible values: {})".format(", ".join(AVAILABLE_BENCHMARKS)),
    type=click.Choice(AVAILABLE_BENCHMARKS),
    required=False,
    metavar="BENCHMARK",
)
@click.option("--mpiexec", default=None, help="Executable used for calling MPI (e.g. mpirun, mpiexec)")
@click.option("--slurm", is_flag=True, help="Run benchmarks using SLURM scheduling command (srun)")
@click.option("--debug", is_flag=True, help="Additionally print each command that is executed")
@click.option("--float-type", default="float64", help="Data type for floating point arrays in Veros components")
@click.option("--burnin", default=3, type=int, help="Number of iterations to exclude in timings")
def run(**kwargs):
    """Run every selected benchmark for each requested size and component.

    Each benchmark is executed in its own subprocess; per-iteration timings
    are scraped from its output. Results are written to the JSON outfile even
    when individual runs fail, and the exit status is non-zero if any failed.
    """
    check_arguments(**kwargs)

    # split nproc into a 2D domain decomposition, e.g. 8 -> (4, 2)
    proc_decom = _decompose_num(kwargs["nproc"], 2)

    settings = kwargs.copy()
    settings["decomp"] = f"{proc_decom[0]} {proc_decom[1]}"

    out_data = {}
    all_passed = True
    try:
        for f in kwargs["only"]:
            out_data[f] = []
            click.echo(f"running benchmark {f}")

            for size in kwargs["sizes"]:
                # choose nz ~ 0.5 * size^(1/3), clamped to [2, 120], then pick
                # nx = ny so that nx * ny * nz approximates the requested size
                nz = min(max(math.ceil(0.5 * size ** (1 / 3)), 2), 120)
                n = math.ceil((size / nz) ** (1 / 2))
                # round horizontal extents to multiples of the decomposition
                # so the domain divides evenly across processes
                nx = _round_to_multiple(n, proc_decom[0])
                ny = _round_to_multiple(n, proc_decom[1])
                real_size = nx * ny * nz

                click.echo(f" current size: {real_size}")

                cmd_args = settings.copy()
                cmd_args.update(
                    {
                        "python": sys.executable,
                        "filename": os.path.realpath(os.path.join(TESTDIR, f)),
                        "nx": nx,
                        "ny": ny,
                        "nz": nz,
                    }
                )

                # default MPI launcher depends on the scheduling mode
                if cmd_args["mpiexec"] is None:
                    if kwargs["slurm"]:
                        cmd_args["mpiexec"] = "srun"
                    else:
                        cmd_args["mpiexec"] = "mpirun"

                for comp in kwargs["components"]:
                    cmd = (SLURM_COMMANDS[comp] if kwargs["slurm"] else BENCHMARK_COMMANDS[comp]).format(**cmd_args)

                    if kwargs["debug"]:
                        click.echo(f" $ {cmd}")

                    sys.stdout.write(f" {comp:<15} ... ")
                    sys.stdout.flush()

                    try:  # must run each benchmark in its own Python subprocess to reload the Fortran library
                        output = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT)
                    except subprocess.CalledProcessError as e:
                        # record the failure but keep running the other benchmarks
                        click.echo("failed")
                        click.echo(e.output.decode("utf-8"))
                        all_passed = False
                        continue

                    output = output.decode("utf-8")
                    # drop the first `burnin` iterations from the timings
                    iteration_times = list(map(float, re.findall(TIME_PATTERN, output)))[kwargs["burnin"] :]

                    if not iteration_times:
                        raise RuntimeError("could not extract iteration times from output")

                    total_elapsed = sum(iteration_times)
                    click.echo(f"{total_elapsed:>6.2f}s")

                    out_data[f].append(
                        {
                            "component": comp,
                            "size": real_size,
                            "wall_time": total_elapsed,
                            "per_iteration": {
                                "best": float(np.min(iteration_times)),
                                "worst": float(np.max(iteration_times)),
                                "mean": float(np.mean(iteration_times)),
                                "stdev": float(np.std(iteration_times)),
                            },
                        }
                    )
    finally:
        # always persist whatever results were collected, even on failure
        with open(kwargs["outfile"], "w") as f:
            json.dump({"benchmarks": out_data, "settings": settings}, f, indent=4, sort_keys=True)

    raise SystemExit(int(not all_passed))


if __name__ == "__main__":
    run()
[metadata]
description-file = README.md
[versioneer]
VCS = git
style = pep440
versionfile_source = veros/_version.py
versionfile_build = veros/_version.py
tag_prefix = v
[flake8]
exclude = veros/tools/filelock.py
max-line-length = 120
select = C,E,F,W,B,B950
extend-ignore = E203,E501,W503
#!/usr/bin/env python
# coding=utf-8
from setuptools import setup, find_packages
from setuptools.extension import Extension
from codecs import open
import os
import re
import sys
from Cython.Build import cythonize
here = os.path.abspath(os.path.dirname(__file__))
sys.path.append(here)
import versioneer # noqa: E402
import cuda_ext # noqa: E402
# Trove classifiers advertised on PyPI (one per non-empty line).
CLASSIFIERS = """
Development Status :: 4 - Beta
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Programming Language :: Python :: 3
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: Implementation :: CPython
Topic :: Scientific/Engineering
Operating System :: Microsoft :: Windows
Operating System :: POSIX
Operating System :: Unix
Operating System :: MacOS
"""

# Lower version bounds injected into the parsed requirements (see
# parse_requirements below).
MINIMUM_VERSIONS = {
    "numpy": "1.13",
    "requests": "2.18",
    "jax": "0.2.10",
}

# Console entry points installed with the package.
CONSOLE_SCRIPTS = [
    "veros = veros.cli.veros:cli",
    "veros-run = veros.cli.veros_run:cli",
    "veros-copy-setup = veros.cli.veros_copy_setup:cli",
    "veros-resubmit = veros.cli.veros_resubmit:cli",
    "veros-create-mask = veros.cli.veros_create_mask:cli",
]

# Non-Python files shipped inside the veros package.
PACKAGE_DATA = ["setups/*/assets.json", "setups/*/*.npy", "setups/*/*.png"]

# Long description for PyPI, taken verbatim from the README.
with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
    long_description = f.read()
def parse_requirements(reqfile):
    """Parse a pip requirements file into install_requires-style specifiers.

    Exact pins ("==") are relaxed to upper bounds ("<="), and packages listed
    in MINIMUM_VERSIONS additionally get a lower bound appended.

    Fix: blank lines and "#" comment lines are now skipped — previously
    ``re.match`` returned None for them and ``.group(1)`` raised
    AttributeError (requirements.txt in this commit even ends without a
    trailing newline, making stray lines likely).
    """
    requirements = []
    with open(os.path.join(here, reqfile), encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # skip blanks and comments, which carry no requirement
            if not line or line.startswith("#"):
                continue
            match = re.match(r"(\w+)\b.*", line)
            if match is None:
                continue
            pkg = match.group(1)
            if pkg in MINIMUM_VERSIONS:
                line = "".join([line, ",>=", MINIMUM_VERSIONS[pkg]])
            # treat exact pins as upper bounds to stay installable over time
            line = line.replace("==", "<=")
            requirements.append(line)
    return requirements
INSTALL_REQUIRES = parse_requirements("requirements.txt")

# The jax extra is currently disabled; kept for reference.
# jax_req = parse_requirements("requirements_jax.txt")
# for line in jax_req:  # inject jaxlib requirement
#     if line.startswith("jax"):
#         jax_req.append(line.replace("jax", "jaxlib"))
#     break

EXTRAS_REQUIRE = {
    "test": ["pytest", "pytest-cov", "pytest-forked", "xarray"],
    # "jax": jax_req,
}
def get_extensions(require_cython_ext, require_cuda_ext):
    """Build the list of Cython / CUDA extension modules.

    Extensions are marked optional unless the corresponding require_* flag is
    set, so a failed native build does not abort installation.
    """
    cuda_info = cuda_ext.cuda_info

    # module name -> source files (relative to the module's directory)
    extension_modules = {
        "veros.core.special.tdma_cython_": ["tdma_cython_.pyx"],
        "veros.core.special.tdma_cuda_": ["tdma_cuda_.pyx", "cuda_tdma_kernels.cu"],
    }

    def is_cuda_ext(sources):
        # an extension counts as CUDA-based if any source is a .cu file
        return any(source.endswith(".cu") for source in sources)

    extensions = []
    for module, sources in extension_modules.items():
        extension_dir = os.path.join(*module.split(".")[:-1])

        kwargs = dict()
        if is_cuda_ext(sources):
            # link against the CUDA runtime; paths come from cuda_ext probing
            kwargs.update(
                library_dirs=cuda_info["lib64"],
                libraries=["cudart"],
                runtime_library_dirs=cuda_info["lib64"],
                include_dirs=cuda_info["include"],
            )

        ext = Extension(
            name=module,
            sources=[os.path.join(extension_dir, f) for f in sources],
            # per-compiler flag dict — presumably consumed by
            # cuda_ext.custom_build_ext (not visible here); TODO confirm
            extra_compile_args={
                "gcc": [],
                "nvcc": cuda_info["cflags"],
            },
            **kwargs,
        )
        extensions.append(ext)

    # exclude_failures keeps going when an individual extension fails to build
    extensions = cythonize(extensions, language_level=3, exclude_failures=True)

    for ext in extensions:
        is_required = (not is_cuda_ext(ext.sources) and require_cython_ext) or (
            is_cuda_ext(ext.sources) and require_cuda_ext
        )
        if not is_required:
            # optional extensions are allowed to fail at build time
            ext.optional = True

    return extensions
# Combine versioneer's build_ext with the CUDA-aware build_ext so both
# version stamping and nvcc compilation hook into the build.
cmdclass = versioneer.get_cmdclass()
build_ext = type("custom_build_ext", (cuda_ext.custom_build_ext, cmdclass["build_ext"]), {})
cmdclass.update(build_ext=build_ext)
def _env_to_bool(envvar):
return os.environ.get(envvar, "").lower() in ("1", "true", "on")
# Build the extension list; the VEROS_REQUIRE_* environment flags make the
# respective native builds mandatory instead of best-effort.
extensions = get_extensions(
    require_cython_ext=_env_to_bool("VEROS_REQUIRE_CYTHON_EXT"),
    require_cuda_ext=_env_to_bool("VEROS_REQUIRE_CUDA_EXT"),
)
# Package metadata and build configuration; all values are defined above.
setup(
    name="veros",
    license="MIT",
    author="Dion Häfner (NBI Copenhagen)",
    author_email="dion.haefner@nbi.ku.dk",
    keywords="oceanography python parallel numpy multi-core geophysics ocean-model mpi4py jax",
    description="The versatile ocean simulator, in pure Python, powered by JAX.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://veros.readthedocs.io",
    python_requires=">=3.8",
    version=versioneer.get_version(),
    cmdclass=cmdclass,
    packages=find_packages(),
    install_requires=INSTALL_REQUIRES,
    extras_require=EXTRAS_REQUIRE,
    ext_modules=extensions,
    entry_points={"console_scripts": CONSOLE_SCRIPTS, "veros.setup_dirs": ["base = veros.setups"]},
    package_data={"veros": PACKAGE_DATA},
    # drop the blank lines from the classifier block
    classifiers=[c for c in CLASSIFIERS.split("\n") if c],
    zip_safe=False,
)
import os
import sys
import filecmp
import fnmatch
import pkg_resources
import subprocess
from textwrap import dedent
from click.testing import CliRunner
import pytest
import veros.cli
# Setup templates shipped with Veros that veros-copy-setup should handle.
SETUPS = (
    "acc",
    "acc_basic",
    "global_4deg",
    "global_1deg",
    "global_flexible",
    "north_atlantic",
)
@pytest.fixture(scope="module")
def runner():
    """Provide a single click CliRunner shared by all tests in this module."""
    cli_runner = CliRunner()
    return cli_runner
@pytest.mark.parametrize("setup", SETUPS)
def test_veros_copy_setup(setup, runner, tmpdir):
    """veros-copy-setup must copy a setup template without extra or missing files."""
    result = runner.invoke(veros.cli.veros_copy_setup.cli, [setup, "--to", os.path.join(tmpdir, setup)])
    assert result.exit_code == 0, setup
    assert not result.output

    outpath = os.path.join(tmpdir, setup)
    srcpath = pkg_resources.resource_filename("veros", f"setups/{setup}")

    # files matching the CLI's ignore patterns are excluded from the comparison
    ignore = [
        f
        for f in os.listdir(srcpath)
        if any(fnmatch.fnmatch(f, pattern) for pattern in veros.cli.veros_copy_setup.IGNORE_PATTERNS)
    ]
    comparer = filecmp.dircmp(outpath, srcpath, ignore=ignore)
    assert not comparer.left_only and not comparer.right_only

    # the copied setup file is stamped with the Veros version marker
    with open(os.path.join(outpath, f"{setup}.py"), "r") as f:
        setup_content = f.read()
    assert "VEROS_VERSION" in setup_content
def test_veros_run(runner, tmpdir):
    """Copy the acc setup and run it end-to-end through the veros-run CLI."""
    from veros import runtime_settings as rs

    setup = "acc"

    with runner.isolated_filesystem(tmpdir):
        result = runner.invoke(veros.cli.veros_copy_setup.cli, [setup])

        # veros-run mutates global runtime settings; snapshot them, unlock,
        # and restore afterwards so other tests see the original state
        old_rs = {key: getattr(rs, key) for key in rs.__settings__}
        object.__setattr__(rs, "__locked__", False)

        try:
            result = runner.invoke(
                veros.cli.veros_run.cli, [os.path.join(setup, f"{setup}.py"), "--backend", rs.backend]
            )
        finally:
            # restore old settings
            for key, val in old_rs.items():
                object.__setattr__(rs, key, val)

        # only the veros-run invocation is asserted on (the copy result above
        # is overwritten intentionally)
        assert result.exit_code == 0
def test_import_isolation(tmpdir):
    """Importing veros.cli in a fresh interpreter must not pull in the full package or MPI."""
    probe_script = dedent(
        """
        import sys
        import veros.cli
        for mod in sys.modules:
            print(mod)
        """
    )

    script_path = tmpdir / "isolation.py"
    with open(script_path, "w") as f:
        f.write(probe_script)

    proc = subprocess.run([sys.executable, script_path], check=True, capture_output=True, text=True)
    loaded_modules = proc.stdout.split()

    # only the CLI subpackage (plus version info) may have been imported
    for mod in loaded_modules:
        if mod.startswith("veros."):
            assert mod.startswith("veros.cli") or mod == "veros._version"

    # make sure using the CLI does not initialize MPI
    assert "mpi4py" not in loaded_modules
import os
import pytest
def pytest_addoption(parser):
    """Register the command line options used by the Veros test suite."""
    parser.addoption(
        "--pyom2-lib", default=None, help="Path to PyOM2 library (must be given for consistency tests)"
    )
    parser.addoption(
        "--backend", choices=["numpy", "jax"], default="numpy", help="Numerical backend to test"
    )
def pytest_configure(config):
    """Propagate the selected backend to Veros through the environment."""
    os.environ["VEROS_BACKEND"] = config.getoption("--backend")
def pytest_collection_modifyitems(config, items):
    """Skip tests needing the PyOM2 library unless --pyom2-lib was given."""
    if config.getoption("--pyom2-lib"):
        return
    skip_marker = pytest.mark.skip(reason="need --pyom2-lib option to run")
    for item in items:
        if "pyom2_lib" in item.fixturenames:
            item.add_marker(skip_marker)
def pytest_generate_tests(metafunc):
option_value = metafunc.config.option.pyom2_lib
if "pyom2_lib" in metafunc.fixturenames and option_value is not None:
metafunc.parametrize("pyom2_lib", [option_value])
@pytest.fixture(autouse=True)
def set_random_seed():
    """Seed NumPy's global RNG before every test for reproducibility."""
    import numpy as np

    np.random.seed(17)
# Distributed consistency kernel: the parent process runs the ACC setup
# serially, spawns four MPI workers that run it with a (2, 2) decomposition,
# and checks that both produce the same barotropic streamfunction.
import sys
import numpy as np
from mpi4py import MPI
from veros import runtime_settings as rs, runtime_state as rst
from veros.distributed import gather

rs.linear_solver = "scipy"
rs.diskless_mode = True  # avoid writing output files during the test

if rst.proc_num > 1:
    rs.num_proc = (2, 2)
    assert rst.proc_num == 4

from veros.setups.acc import ACCSetup  # noqa: E402

sim = ACCSetup(
    override=dict(
        runlen=86400 * 10,
    )
)

if rst.proc_num == 1:
    # parent: spawn 4 workers executing this same script via `python -m mpi4py`
    comm = MPI.COMM_SELF.Spawn(sys.executable, args=["-m", "mpi4py", sys.argv[-1]], maxprocs=4)
    try:
        sim.setup()
        sim.run()
    except Exception as exc:
        # abort the spawned workers too, otherwise they would hang waiting
        print(str(exc))
        comm.Abort(1)
        raise
    # compare against the distributed result sent back by worker rank 0
    other_psi = np.empty_like(sim.state.variables.psi)
    comm.Recv(other_psi, 0)
    np.testing.assert_allclose(sim.state.variables.psi, other_psi)
else:
    # spawned worker: run distributed, gather psi, and send it to the parent
    sim.setup()
    sim.run()
    psi_global = gather(sim.state.variables.psi, sim.state.dimensions, ("xt", "yt"))
    if rst.proc_rank == 0:
        rs.mpi_comm.Get_parent().Send(np.array(psi_global), 0)
import os
import sys
import subprocess
import pytest
def run_dist_kernel(kernel):
    """Execute a distributed test kernel via `python -m mpi4py` and require a clean exit."""
    pytest.importorskip("mpi4py")
    kernel_path = os.path.join(os.path.dirname(__file__), kernel)
    return subprocess.check_call(
        [sys.executable, "-m", "mpi4py", kernel_path],
        stderr=subprocess.STDOUT,
        timeout=300,
    )
def test_gather():
    # consistency of the distributed gather primitive (see gather_kernel.py)
    run_dist_kernel("gather_kernel.py")


def test_scatter():
    # consistency of the distributed scatter primitive (see scatter_kernel.py)
    run_dist_kernel("scatter_kernel.py")


def test_acc():
    # serial vs. distributed consistency of the full ACC setup
    run_dist_kernel("acc_kernel.py")
@pytest.mark.parametrize("solver", ["scipy", "scipy_jax", "petsc"])
@pytest.mark.parametrize("streamfunction", [True, False])
def test_linear_solver(solver, streamfunction):
    """Run the distributed solver kernels with each linear solver backend.

    The solver is selected via the VEROS_LINEAR_SOLVER environment variable.
    Fix: the variable is restored to its exact previous state afterwards —
    the old code used ``os.environ.get(..., "best")`` and therefore leaked a
    spurious ``VEROS_LINEAR_SOLVER=best`` entry when it was originally unset.
    """
    from veros import runtime_settings

    if solver == "scipy_jax" and runtime_settings.backend != "jax":
        pytest.skip("scipy_jax solver requires JAX")

    kernel = "streamfunction_kernel.py" if streamfunction else "pressure_kernel.py"

    orig_solver = os.environ.get("VEROS_LINEAR_SOLVER")
    try:
        os.environ["VEROS_LINEAR_SOLVER"] = solver
        run_dist_kernel(kernel)
    finally:
        if orig_solver is None:
            # the variable was unset before; remove it again
            os.environ.pop("VEROS_LINEAR_SOLVER", None)
        else:
            os.environ["VEROS_LINEAR_SOLVER"] = orig_solver
# Distributed consistency kernel for `gather`: four spawned ranks each fill a
# local array with their rank number; rank 0 sends the gathered global array
# back to the parent, which checks the expected (2, 2) quadrant layout.
import numpy as np
from mpi4py import MPI
from veros import runtime_settings as rs, runtime_state as rst
from veros.distributed import gather

if rst.proc_num == 1:
    import sys

    # parent: spawn 4 workers executing this same file
    comm = MPI.COMM_SELF.Spawn(sys.executable, args=["-m", "mpi4py", sys.argv[-1]], maxprocs=4)
    res = np.empty((8, 8))
    comm.Recv(res, 0)
    # each quadrant holds the rank number of the process that owned it
    np.testing.assert_array_equal(
        res,
        np.array(
            [
                [0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
                [0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
                [0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
                [0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
                [1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
                [1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
                [1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
                [1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
            ]
        ),
    )
else:
    # spawned worker: build the local (6, 6) block (4x4 interior plus halo
    # points — presumably 2 per side; verify against veros conventions)
    rs.num_proc = (2, 2)
    assert rst.proc_num == 4

    from veros.core.operators import numpy as npx

    dimensions = dict(xt=4, yt=4)
    a = rst.proc_rank * npx.ones((6, 6))
    b = gather(a, dimensions, ("xt", "yt"))

    if rst.proc_rank == 0:
        rs.mpi_comm.Get_parent().Send(np.array(b), 0)
# Distributed consistency kernel for the linear solver: the parent runs the
# solver serially, spawns four MPI workers that solve the same system with a
# (2, 2) decomposition, and checks both solutions agree.
import sys
import numpy as onp
from mpi4py import MPI
from veros import runtime_settings as rs, runtime_state as rst

rs.diskless_mode = True

if rst.proc_num > 1:
    rs.num_proc = (2, 2)
    assert rst.proc_num == 4

from veros.state import get_default_state, resize_dimension  # noqa: E402
from veros.distributed import gather  # noqa: E402
from veros.core.operators import numpy as npx, update, at  # noqa: E402
from veros.core.external.solvers import get_linear_solver  # noqa: E402


def get_inputs():
    """Build a small idealized model state plus right-hand side and initial guess."""
    state = get_default_state()
    settings = state.settings

    with settings.unlock():
        settings.nx = 100
        settings.ny = 40
        settings.nz = 1
        settings.dt_tracer = 1800
        settings.dt_mom = 1800
        settings.enable_cyclic_x = True
        settings.enable_streamfunction = False

    state.initialize_variables()
    resize_dimension(state, "isle", 1)

    vs = state.variables

    # global index slice owned by this rank (local extent + 4 ghost points)
    nx_local, ny_local = settings.nx // rs.num_proc[0], settings.ny // rs.num_proc[1]
    idx_global = (
        slice(rst.proc_idx[0] * nx_local, (rst.proc_idx[0] + 1) * nx_local + 4),
        slice(rst.proc_idx[1] * ny_local, (rst.proc_idx[1] + 1) * ny_local + 4),
        Ellipsis,
    )

    with vs.unlock():
        # uniform 10 km grid spacing
        vs.dxt = update(vs.dxt, at[...], 10e3)
        vs.dxu = update(vs.dxu, at[...], 10e3)
        vs.dyt = update(vs.dyt, at[...], 10e3)
        vs.dyu = update(vs.dyu, at[...], 10e3)

        # depth increasing linearly in x, identical in every rank's local view
        h_global = npx.linspace(500, 2000, settings.nx + 4)[:, None] * npx.ones((settings.nx + 4, settings.ny + 4))
        vs.hu = h_global[idx_global]
        vs.hv = h_global[idx_global]

        vs.cosu = update(vs.cosu, at[...], 1)
        vs.cost = update(vs.cost, at[...], 1)

        # carve two land regions into the global mask, then take the local slice
        boundary_mask = npx.ones((settings.nx + 4, settings.ny + 4, settings.nz), dtype="bool")
        boundary_mask = update(boundary_mask, at[:50, :2], 0)
        boundary_mask = update(boundary_mask, at[20:30, 20:30], 0)
        vs.maskT = boundary_mask[idx_global]

    rhs = npx.ones_like(vs.hur)
    x0 = npx.ones_like(vs.hur)
    return state, rhs, x0


if rst.proc_num == 1:
    # parent: spawn 4 workers executing this same script, solve serially,
    # then compare against the distributed solution sent back by rank 0
    comm = MPI.COMM_SELF.Spawn(sys.executable, args=["-m", "mpi4py", sys.argv[-1]], maxprocs=4)
    try:
        state, rhs, x0 = get_inputs()
        sol = get_linear_solver(state)
        psi = sol.solve(state, rhs, x0)
    except Exception as exc:
        # abort the workers too, otherwise they would hang waiting
        print(str(exc))
        comm.Abort(1)
        raise
    other_psi = onp.empty_like(psi)
    comm.Recv(other_psi, 0)
    onp.testing.assert_allclose(psi, other_psi)
else:
    # spawned worker: solve distributed, gather, and report back to the parent
    state, rhs, x0 = get_inputs()
    sol = get_linear_solver(state)
    psi = sol.solve(state, rhs, x0)
    psi_global = gather(psi, state.dimensions, ("xt", "yt"))
    if rst.proc_rank == 0:
        rs.mpi_comm.Get_parent().Send(onp.array(psi_global), 0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment