Commit 6d80b901 authored by Jonas Kaufmann's avatar Jonas Kaufmann Committed by Antoine Kaufmann
Browse files

ExperimentDistributedRunner: replace use of generic with overriding the type in the subclass

parent 0e69529e
......@@ -25,6 +25,7 @@ import itertools
import shlex
import traceback
import typing as tp
from abc import ABC, abstractmethod
from simbricks.exectools import Executor, SimpleComponent
from simbricks.experiment.experiment_environment import ExpEnv
......@@ -56,7 +57,8 @@ class Experiment(object):
simulator shut down, a check point created, and finally restored in the
accurate mode using this checkpoint."""
self.no_simbricks = False
"""If `true`, no simbricks adapters are used in any of the simulators."""
"""If `true`, no simbricks adapters are used in any of the
simulators."""
self.hosts: tp.List[HostSim] = []
"""The host simulators to run."""
self.pcidevs: tp.List[PCIDevSim] = []
......@@ -142,12 +144,9 @@ class DistributedExperiment(Experiment):
return True
T = tp.TypeVar('T', bound=Experiment)
class ExperimentBaseRunner(ABC):
class ExperimentBaseRunner(tp.Generic[T]):
def __init__(self, exp: T, env: ExpEnv, verbose: bool):
def __init__(self, exp: Experiment, env: ExpEnv, verbose: bool):
self.exp = exp
self.env = env
self.verbose = verbose
......@@ -156,8 +155,9 @@ class ExperimentBaseRunner(tp.Generic[T]):
self.sockets = []
self.wait_sims = []
@abstractmethod
def sim_executor(self, sim: Simulator) -> Executor:
raise NotImplementedError('Please implement this method')
pass
def sim_graph(self):
sims = self.exp.all_simulators()
......@@ -313,10 +313,10 @@ class ExperimentBaseRunner(tp.Generic[T]):
return self.out
class ExperimentSimpleRunner(ExperimentBaseRunner[Experiment]):
class ExperimentSimpleRunner(ExperimentBaseRunner):
"""Simple experiment runner with just one executor."""
def __init__(self, executor: Executor, *args, **kwargs):
def __init__(self, executor: Executor, exp: Experiment, *args, **kwargs):
self.executor = executor
super().__init__(*args, **kwargs)
......@@ -324,12 +324,13 @@ class ExperimentSimpleRunner(ExperimentBaseRunner[Experiment]):
return self.executor
class ExperimentDistributedRunner(ExperimentBaseRunner[DistributedExperiment]):
class ExperimentDistributedRunner(ExperimentBaseRunner):
"""Simple experiment runner with just one executor."""
def __init__(self, execs, *args, **kwargs):
def __init__(self, execs, exp: DistributedExperiment, *args, **kwargs):
self.execs = execs
super().__init__(*args, **kwargs)
self.exp = exp # overrides the type in the base class
assert self.exp.num_hosts <= len(execs)
def sim_executor(self, sim):
......
......@@ -33,8 +33,8 @@ from simbricks.runtime.common import Run, Runtime
class DistributedSimpleRuntime(Runtime):
def __init__(self, executors, verbose=False):
self.runnable = []
self.complete = []
self.runnable: tp.List[Run] = []
self.complete: tp.List[Run] = []
self.verbose = verbose
self.executors = executors
......@@ -46,7 +46,11 @@ class DistributedSimpleRuntime(Runtime):
async def do_run(self, run: Run):
runner = exp.ExperimentDistributedRunner(
self.executors, run.experiment, run.env, self.verbose
self.executors,
# we ensure the correct type in add_run()
tp.cast(exp.DistributedExperiment, run.experiment),
run.env,
self.verbose
)
for executor in self.executors:
await run.prep_dirs(executor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment