Commit 6b3fc27f authored by Jonas Kaufmann's avatar Jonas Kaufmann Committed by Antoine Kaufmann
Browse files

experiments/run.py: change imports to respect Python style guide

parent 0b3de18a
...@@ -30,25 +30,14 @@ import importlib.util ...@@ -30,25 +30,14 @@ import importlib.util
import json import json
import os import os
import pickle import pickle
import signal
import sys import sys
import typing as tp import typing as tp
from signal import SIGINT, signal
from simbricks.orchestration import exectools from simbricks.orchestration import exectools
from simbricks.orchestration.exectools import LocalExecutor, RemoteExecutor from simbricks.orchestration import experiments as exps
from simbricks.orchestration.experiment.experiment_environment import ExpEnv from simbricks.orchestration import runtime
from simbricks.orchestration.experiments import ( from simbricks.orchestration.experiment import experiment_environment
DistributedExperiment, Experiment
)
from simbricks.orchestration.runtime import common as rt_common
from simbricks.orchestration.runtime.common import Run
from simbricks.orchestration.runtime.distributed import (
DistributedSimpleRuntime, auto_dist
)
from simbricks.orchestration.runtime.local import (
LocalParallelRuntime, LocalSimpleRuntime
)
from simbricks.orchestration.runtime.slurm import SlurmRuntime
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
...@@ -238,9 +227,9 @@ def load_executors(path: str) -> tp.List[exectools.Executor]: ...@@ -238,9 +227,9 @@ def load_executors(path: str) -> tp.List[exectools.Executor]:
exs = [] exs = []
for h in hosts: for h in hosts:
if h['type'] == 'local': if h['type'] == 'local':
ex = LocalExecutor() ex = exectools.LocalExecutor()
elif h['type'] == 'remote': elif h['type'] == 'remote':
ex = RemoteExecutor(h['host'], h['workdir']) ex = exectools.RemoteExecutor(h['host'], h['workdir'])
if 'ssh_args' in h: if 'ssh_args' in h:
ex.ssh_extra_args += h['ssh_args'] ex.ssh_extra_args += h['ssh_args']
if 'scp_args' in h: if 'scp_args' in h:
...@@ -260,12 +249,11 @@ def warn_multi_exec(executors: tp.List[exectools.Executor]): ...@@ -260,12 +249,11 @@ def warn_multi_exec(executors: tp.List[exectools.Executor]):
) )
# pylint: disable=redefined-outer-name
def add_exp( def add_exp(
e: Experiment, e: exps.Experiment,
rt: rt_common.Runtime, rt: runtime.Runtime,
run: int, run: int,
prereq: tp.Optional[Run], prereq: tp.Optional[runtime.Run],
create_cp: bool, create_cp: bool,
restore_cp: bool, restore_cp: bool,
no_simbricks: bool, no_simbricks: bool,
...@@ -281,7 +269,7 @@ def add_exp( ...@@ -281,7 +269,7 @@ def add_exp(
if args.shmdir is not None: if args.shmdir is not None:
shmdir = f'{args.shmdir}/{e.name}/{run}' shmdir = f'{args.shmdir}/{e.name}/{run}'
env = ExpEnv(args.repo, workdir, cpdir) env = experiment_environment.ExpEnv(args.repo, workdir, cpdir)
env.create_cp = create_cp env.create_cp = create_cp
env.restore_cp = restore_cp env.restore_cp = restore_cp
env.no_simbricks = no_simbricks env.no_simbricks = no_simbricks
...@@ -291,7 +279,7 @@ def add_exp( ...@@ -291,7 +279,7 @@ def add_exp(
if args.shmdir is not None: if args.shmdir is not None:
env.shm_base = os.path.abspath(shmdir) env.shm_base = os.path.abspath(shmdir)
run = Run(e, run, env, outpath, prereq) run = runtime.Run(e, run, env, outpath, prereq)
rt.add_run(run) rt.add_run(run)
return run return run
...@@ -306,19 +294,21 @@ def main(): ...@@ -306,19 +294,21 @@ def main():
# initialize runtime # initialize runtime
if args.runtime == 'parallel': if args.runtime == 'parallel':
warn_multi_exec(executors) warn_multi_exec(executors)
rt = LocalParallelRuntime( rt = runtime.LocalParallelRuntime(
cores=args.cores, cores=args.cores,
mem=args.mem, mem=args.mem,
verbose=args.verbose, verbose=args.verbose,
executor=executors[0] executor=executors[0]
) )
elif args.runtime == 'slurm': elif args.runtime == 'slurm':
rt = SlurmRuntime(args.slurmdir, args, verbose=args.verbose) rt = runtime.SlurmRuntime(args.slurmdir, args, verbose=args.verbose)
elif args.runtime == 'dist': elif args.runtime == 'dist':
rt = DistributedSimpleRuntime(executors, verbose=args.verbose) rt = runtime.DistributedSimpleRuntime(executors, verbose=args.verbose)
else: else:
warn_multi_exec(executors) warn_multi_exec(executors)
rt = LocalSimpleRuntime(verbose=args.verbose, executor=executors[0]) rt = runtime.LocalSimpleRuntime(
verbose=args.verbose, executor=executors[0]
)
# load experiments # load experiments
if not args.pickled: if not args.pickled:
...@@ -345,8 +335,8 @@ def main(): ...@@ -345,8 +335,8 @@ def main():
sys.exit(0) sys.exit(0)
for e in experiments: for e in experiments:
if args.auto_dist and not isinstance(e, DistributedExperiment): if args.auto_dist and not isinstance(e, exps.DistributedExperiment):
e = auto_dist(e, executors, args.proxy_type) e = runtime.auto_dist(e, executors, args.proxy_type)
# apply filter if any specified # apply filter if any specified
if (args.filter) and (len(args.filter) > 0): if (args.filter) and (len(args.filter) > 0):
match = False match = False
...@@ -379,7 +369,7 @@ def main(): ...@@ -379,7 +369,7 @@ def main():
rt.add_run(pickle.load(f)) rt.add_run(pickle.load(f))
# register interrupt handler # register interrupt handler
signal(SIGINT, lambda *_: rt.interrupt()) signal.signal(signal.SIGINT, lambda *_: rt.interrupt())
# invoke runtime to run experiments # invoke runtime to run experiments
asyncio.run(rt.start()) asyncio.run(rt.start())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment