distributed_test.py 1.1 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import sys
import subprocess

import pytest


def run_dist_kernel(kernel):
    pytest.importorskip("mpi4py")

    here = os.path.dirname(__file__)
    return subprocess.check_call(
        [sys.executable, "-m", "mpi4py", os.path.join(here, kernel)], stderr=subprocess.STDOUT, timeout=300
    )


def test_gather():
    run_dist_kernel("gather_kernel.py")


def test_scatter():
    run_dist_kernel("scatter_kernel.py")


def test_acc():
    run_dist_kernel("acc_kernel.py")


@pytest.mark.parametrize("solver", ["scipy", "scipy_jax", "petsc"])
@pytest.mark.parametrize("streamfunction", [True, False])
def test_linear_solver(solver, streamfunction):
    from veros import runtime_settings

    if solver == "scipy_jax" and runtime_settings.backend != "jax":
        pytest.skip("scipy_jax solver requires JAX")

    kernel = "streamfunction_kernel.py" if streamfunction else "pressure_kernel.py"
    orig_solver = os.environ.get("VEROS_LINEAR_SOLVER", "best")
    try:
        os.environ["VEROS_LINEAR_SOLVER"] = solver
        run_dist_kernel(kernel)
    finally:
        os.environ["VEROS_LINEAR_SOLVER"] = orig_solver