benchmark_base.py 1.31 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import functools
import os

import click


def benchmark_cli(func):
    """Wrap a benchmark entry point in a click command with shared Veros options.

    The returned command parses the options listed below and applies the
    runtime-related ones to ``veros.runtime_settings``. It then calls ``func``
    with the remaining, benchmark-specific options as keyword arguments.

    Options consumed by the wrapper:
        -b/--backend    "numpy" (default) or "jax"
        -d/--device     "cpu" (default) or "gpu"
        -n/--nproc      two ints, passed on as ``num_proc`` (default ``(1, 1)``)
        --float-type    "float64" (default) or "float32"
        -v/--loglevel   "debug" (default) or "trace"
        --profile-mode  boolean flag

    Options forwarded to ``func`` as keyword arguments:
        --size          three ints, required           -> ``size``
        --timesteps     int, required                  -> ``timesteps``
        -f/--pyom2-lib  readable file path, or None    -> ``pyom2_lib``

    Args:
        func: Benchmark function that accepts ``size``, ``timesteps`` and
            ``pyom2_lib`` keyword arguments. Its name and docstring become the
            command's name and ``--help`` text.

    Returns:
        A ``click.Command`` whose callback returns whatever ``func`` returns.
    """

    # ``click.command()`` is outermost so every option is collected first and
    # attached in declaration order, which is the order ``--help`` lists them.
    @click.command()
    @click.option("--size", type=int, nargs=3, required=True)
    @click.option("--timesteps", type=int, required=True)
    @click.option("-f", "--pyom2-lib", type=click.Path(readable=True, dir_okay=False), default=None)
    @click.option("-b", "--backend", type=click.Choice(["numpy", "jax"]), default="numpy")
    @click.option("-d", "--device", type=click.Choice(["cpu", "gpu"]), default="cpu")
    @click.option("-n", "--nproc", type=int, nargs=2, default=(1, 1))
    @click.option("--float-type", type=click.Choice(["float64", "float32"]), default="float64")
    @click.option("-v", "--loglevel", type=click.Choice(["debug", "trace"]), default="debug")
    @click.option("--profile-mode", is_flag=True)
    # Copy func's __name__/__doc__ so click derives the command name and help
    # text from the benchmark, not from this wrapper.
    @functools.wraps(func)
    def inner(backend, device, nproc, float_type, loglevel, profile_mode, **kwargs):
        # veros is imported only when the command actually runs, not when this
        # module is imported.
        from veros import runtime_settings, runtime_state

        runtime_settings.update(
            backend=backend,
            device=device,
            float_type=float_type,
            num_proc=nproc,
            loglevel=loglevel,
            profile_mode=profile_mode,
        )

        # With several processes on GPU, expose only the device whose index
        # equals this process's rank.
        # NOTE(review): presumably assumes one GPU per rank on a single node,
        # and that no GPU backend was initialized before this point -- confirm.
        if device == "gpu" and runtime_state.proc_num > 1:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(runtime_state.proc_rank)

        # kwargs now holds only the benchmark-specific options
        # (size, timesteps, pyom2_lib).
        return func(**kwargs)

    return inner