Commit 6d80b901 authored by Jonas Kaufmann's avatar Jonas Kaufmann Committed by Antoine Kaufmann
Browse files

ExperimentDistributedRunner: replace use of generic with overriding the type in the subclass

parent 0e69529e
...@@ -25,6 +25,7 @@ import itertools ...@@ -25,6 +25,7 @@ import itertools
import shlex import shlex
import traceback import traceback
import typing as tp import typing as tp
from abc import ABC, abstractmethod
from simbricks.exectools import Executor, SimpleComponent from simbricks.exectools import Executor, SimpleComponent
from simbricks.experiment.experiment_environment import ExpEnv from simbricks.experiment.experiment_environment import ExpEnv
...@@ -56,7 +57,8 @@ class Experiment(object): ...@@ -56,7 +57,8 @@ class Experiment(object):
simulator shut down, a check point created, and finally restored in the simulator shut down, a check point created, and finally restored in the
accurate mode using this checkpoint.""" accurate mode using this checkpoint."""
self.no_simbricks = False self.no_simbricks = False
"""If `true`, no simbricks adapters are used in any of the simulators.""" """If `true`, no simbricks adapters are used in any of the
simulators."""
self.hosts: tp.List[HostSim] = [] self.hosts: tp.List[HostSim] = []
"""The host simulators to run.""" """The host simulators to run."""
self.pcidevs: tp.List[PCIDevSim] = [] self.pcidevs: tp.List[PCIDevSim] = []
...@@ -142,12 +144,9 @@ class DistributedExperiment(Experiment): ...@@ -142,12 +144,9 @@ class DistributedExperiment(Experiment):
return True return True
T = tp.TypeVar('T', bound=Experiment) class ExperimentBaseRunner(ABC):
def __init__(self, exp: Experiment, env: ExpEnv, verbose: bool):
class ExperimentBaseRunner(tp.Generic[T]):
def __init__(self, exp: T, env: ExpEnv, verbose: bool):
self.exp = exp self.exp = exp
self.env = env self.env = env
self.verbose = verbose self.verbose = verbose
...@@ -156,8 +155,9 @@ class ExperimentBaseRunner(tp.Generic[T]): ...@@ -156,8 +155,9 @@ class ExperimentBaseRunner(tp.Generic[T]):
self.sockets = [] self.sockets = []
self.wait_sims = [] self.wait_sims = []
@abstractmethod
def sim_executor(self, sim: Simulator) -> Executor: def sim_executor(self, sim: Simulator) -> Executor:
raise NotImplementedError('Please implement this method') pass
def sim_graph(self): def sim_graph(self):
sims = self.exp.all_simulators() sims = self.exp.all_simulators()
...@@ -313,10 +313,10 @@ class ExperimentBaseRunner(tp.Generic[T]): ...@@ -313,10 +313,10 @@ class ExperimentBaseRunner(tp.Generic[T]):
return self.out return self.out
class ExperimentSimpleRunner(ExperimentBaseRunner[Experiment]): class ExperimentSimpleRunner(ExperimentBaseRunner):
"""Simple experiment runner with just one executor.""" """Simple experiment runner with just one executor."""
def __init__(self, executor: Executor, *args, **kwargs): def __init__(self, executor: Executor, exp: Experiment, *args, **kwargs):
self.executor = executor self.executor = executor
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -324,12 +324,13 @@ class ExperimentSimpleRunner(ExperimentBaseRunner[Experiment]): ...@@ -324,12 +324,13 @@ class ExperimentSimpleRunner(ExperimentBaseRunner[Experiment]):
return self.executor return self.executor
class ExperimentDistributedRunner(ExperimentBaseRunner[DistributedExperiment]): class ExperimentDistributedRunner(ExperimentBaseRunner):
"""Simple experiment runner with just one executor.""" """Simple experiment runner with just one executor."""
def __init__(self, execs, *args, **kwargs): def __init__(self, execs, exp: DistributedExperiment, *args, **kwargs):
self.execs = execs self.execs = execs
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.exp = exp # overrides the type in the base class
assert self.exp.num_hosts <= len(execs) assert self.exp.num_hosts <= len(execs)
def sim_executor(self, sim): def sim_executor(self, sim):
......
...@@ -33,8 +33,8 @@ from simbricks.runtime.common import Run, Runtime ...@@ -33,8 +33,8 @@ from simbricks.runtime.common import Run, Runtime
class DistributedSimpleRuntime(Runtime): class DistributedSimpleRuntime(Runtime):
def __init__(self, executors, verbose=False): def __init__(self, executors, verbose=False):
self.runnable = [] self.runnable: tp.List[Run] = []
self.complete = [] self.complete: tp.List[Run] = []
self.verbose = verbose self.verbose = verbose
self.executors = executors self.executors = executors
...@@ -46,7 +46,11 @@ class DistributedSimpleRuntime(Runtime): ...@@ -46,7 +46,11 @@ class DistributedSimpleRuntime(Runtime):
async def do_run(self, run: Run): async def do_run(self, run: Run):
runner = exp.ExperimentDistributedRunner( runner = exp.ExperimentDistributedRunner(
self.executors, run.experiment, run.env, self.verbose self.executors,
# we ensure the correct type in add_run()
tp.cast(exp.DistributedExperiment, run.experiment),
run.env,
self.verbose
) )
for executor in self.executors: for executor in self.executors:
await run.prep_dirs(executor) await run.prep_dirs(executor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment