Unverified Commit 40d36eb0 authored by Jakob Görgen's avatar Jakob Görgen
Browse files

symphony/runtime: added output listeners to command_executor Component

parent aef72737
...@@ -20,11 +20,12 @@ ...@@ -20,11 +20,12 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import abc import abc
import asyncio import asyncio
import os import os
import pathlib import pathlib
import re
import shlex import shlex
import shutil import shutil
import signal import signal
...@@ -32,35 +33,89 @@ import typing as tp ...@@ -32,35 +33,89 @@ import typing as tp
from asyncio.subprocess import Process from asyncio.subprocess import Process
class OutputListener:
    """Base type for objects that are fed the output lines of a command.

    A listener exposes two coroutines, one for stdout batches and one for
    stderr batches. The ``abc.abstractmethod`` markers document that
    subclasses are expected to override them; because this class does not
    use ``abc.ABCMeta`` they are not enforced, and the base implementations
    simply do nothing.
    """

    def __init__(self):
        # Command line of the component the listener is attached to; starts
        # empty and is overwritten by Component.subscribe().
        self.cmd_parts: list[str] = []

    @abc.abstractmethod
    async def handel_err(self, lines: list[str]) -> None:
        """Receive a batch of stderr lines (no-op in the base class)."""

    @abc.abstractmethod
    async def handel_out(self, lines: list[str]) -> None:
        """Receive a batch of stdout lines (no-op in the base class)."""

    def toJSON(self) -> dict:
        """Return a dict with the command parts under the ``"cmd"`` key."""
        return dict(cmd=self.cmd_parts)
class LegacyOutputListener(OutputListener):
    """Listener that keeps every received output line in memory.

    Lines are stored per stream (``stdout`` / ``stderr``) and additionally in
    ``merged_output``, which holds the lines of both streams in the order in
    which the batches were delivered to this listener.
    """

    def __init__(self):
        super().__init__()
        self.stdout: list[str] = []
        self.stderr: list[str] = []
        self.merged_output: list[str] = []

    def _add_to_lists(self, extend: list[str], to_add_to: list[str]) -> None:
        """Append ``extend`` to ``to_add_to`` and to ``merged_output``.

        Raises:
            TypeError: if ``extend`` is not a list. A bare ``str`` is rejected
                on purpose: extending a list with it would add one entry per
                character.
        """
        if not isinstance(extend, list):
            raise TypeError(
                "LegacyOutputListener: can only add list[str] to outputs, "
                f"got {type(extend).__name__}"
            )
        to_add_to.extend(extend)
        self.merged_output.extend(extend)

    async def handel_out(self, lines: list[str]) -> None:
        """Record a batch of stdout lines."""
        self._add_to_lists(extend=lines, to_add_to=self.stdout)

    async def handel_err(self, lines: list[str]) -> None:
        """Record a batch of stderr lines."""
        self._add_to_lists(extend=lines, to_add_to=self.stderr)

    def toJSON(self) -> dict:
        """Return the base description extended by the recorded output."""
        json_obj = super().toJSON()
        json_obj.update(
            {
                "stdout": self.stdout,
                "stderr": self.stderr,
                "merged_output": self.merged_output,
            }
        )
        return json_obj
class Component(object): class Component(object):
def __init__(self, cmd_parts: tp.List[str], with_stdin=False): def __init__(self, cmd_parts: tp.List[str], with_stdin=False):
self.is_ready = False self.is_ready = False
self.stdout: tp.List[str] = []
self.stdout_buf = bytearray() self.stdout_buf = bytearray()
self.stderr: tp.List[str] = []
self.stderr_buf = bytearray() self.stderr_buf = bytearray()
self.cmd_parts = cmd_parts self.cmd_parts: list[str] = cmd_parts
#print(cmd_parts) self._output_handler: list[OutputListener] = []
self.with_stdin = with_stdin self.with_stdin: bool = with_stdin
self._proc: Process self._proc: Process
self._terminate_future: asyncio.Task self._terminate_future: asyncio.Task
def subscribe(self, listener: OutputListener) -> None:
    """Register ``listener`` to receive this component's output.

    The listener's ``cmd_parts`` is pointed at this component's command list
    (the same list object, not a copy).
    """
    self._output_handler.append(listener)
    listener.cmd_parts = self.cmd_parts
def _parse_buf(self, buf: bytearray, data: bytes) -> tp.List[str]:
    """Append ``data`` to ``buf`` and split off all complete lines.

    Args:
        buf: Per-stream accumulation buffer; modified in place so that it
            only holds the trailing, not yet newline-terminated bytes.
        data: Newly read bytes. An empty ``bytes`` object signals end of
            stream; ``None`` means "nothing new" and is not treated as EOF.

    Returns:
        The decoded (UTF-8) complete lines without their newline. At end of
        stream a non-empty unterminated remainder is returned as a final
        line and removed from ``buf``, so a repeated EOF call yields nothing.
    """
    if data:
        buf.extend(data)

    # Everything before the last newline is complete; the last piece (which
    # may be empty) is the partial line that has to wait for more data.
    *complete, remainder = buf.split(b"\n")
    lines = [piece.decode("utf-8") for piece in complete]
    buf[:] = remainder

    at_eof = data is not None and len(data) == 0
    if at_eof and buf:
        lines.append(buf.decode("utf-8"))
        buf.clear()
    return lines
async def _consume_out(self, data: bytes) -> None: async def _consume_out(self, data: bytes) -> None:
...@@ -68,14 +123,16 @@ class Component(object): ...@@ -68,14 +123,16 @@ class Component(object):
ls = self._parse_buf(self.stdout_buf, data) ls = self._parse_buf(self.stdout_buf, data)
if len(ls) > 0 or eof: if len(ls) > 0 or eof:
await self.process_out(ls, eof=eof) await self.process_out(ls, eof=eof)
self.stdout.extend(ls) for h in self._output_handler:
await h.handel_out(ls)
async def _consume_err(self, data: bytes) -> None: async def _consume_err(self, data: bytes) -> None:
eof = len(data) == 0 eof = len(data) == 0
ls = self._parse_buf(self.stderr_buf, data) ls = self._parse_buf(self.stderr_buf, data)
if len(ls) > 0 or eof: if len(ls) > 0 or eof:
await self.process_err(ls, eof=eof) await self.process_err(ls, eof=eof)
self.stderr.extend(ls) for h in self._output_handler:
await h.handel_err(ls)
async def _read_stream(self, stream: asyncio.StreamReader, fn): async def _read_stream(self, stream: asyncio.StreamReader, fn):
while True: while True:
...@@ -87,12 +144,8 @@ class Component(object): ...@@ -87,12 +144,8 @@ class Component(object):
return return
async def _waiter(self) -> None: async def _waiter(self) -> None:
stdout_handler = asyncio.create_task( stdout_handler = asyncio.create_task(self._read_stream(self._proc.stdout, self._consume_out))
self._read_stream(self._proc.stdout, self._consume_out) stderr_handler = asyncio.create_task(self._read_stream(self._proc.stderr, self._consume_err))
)
stderr_handler = asyncio.create_task(
self._read_stream(self._proc.stderr, self._consume_err)
)
rc = await self._proc.wait() rc = await self._proc.wait()
await asyncio.gather(stdout_handler, stderr_handler) await asyncio.gather(stdout_handler, stderr_handler)
await self.terminated(rc) await self.terminated(rc)
...@@ -150,22 +203,14 @@ class Component(object): ...@@ -150,22 +203,14 @@ class Component(object):
return return
# before Python 3.11, asyncio.wait_for() throws asyncio.TimeoutError -_- # before Python 3.11, asyncio.wait_for() throws asyncio.TimeoutError -_-
except (TimeoutError, asyncio.TimeoutError): except (TimeoutError, asyncio.TimeoutError):
print( print(f"terminating component {self.cmd_parts[0]} " f"pid {self._proc.pid}", flush=True)
f'terminating component {self.cmd_parts[0]} '
f'pid {self._proc.pid}',
flush=True
)
await self.terminate() await self.terminate()
try: try:
await asyncio.wait_for(self._proc.wait(), delay) await asyncio.wait_for(self._proc.wait(), delay)
return return
except (TimeoutError, asyncio.TimeoutError): except (TimeoutError, asyncio.TimeoutError):
print( print(f"killing component {self.cmd_parts[0]} " f"pid {self._proc.pid}", flush=True)
f'killing component {self.cmd_parts[0]} '
f'pid {self._proc.pid}',
flush=True
)
await self.kill() await self.kill()
await self._proc.wait() await self._proc.wait()
...@@ -189,15 +234,7 @@ class Component(object): ...@@ -189,15 +234,7 @@ class Component(object):
class SimpleComponent(Component): class SimpleComponent(Component):
def __init__( def __init__(self, label: str, cmd_parts: tp.List[str], *args, verbose=True, canfail=False, **kwargs) -> None:
self,
label: str,
cmd_parts: tp.List[str],
*args,
verbose=True,
canfail=False,
**kwargs
) -> None:
self.label = label self.label = label
self.verbose = verbose self.verbose = verbose
self.canfail = canfail self.canfail = canfail
...@@ -207,109 +244,18 @@ class SimpleComponent(Component): ...@@ -207,109 +244,18 @@ class SimpleComponent(Component):
async def process_out(self, lines: tp.List[str], eof: bool) -> None:
    """Echo stdout lines when verbose output is enabled.

    Each line is printed on its own, prefixed with the component label and
    ``OUT:``. ``eof`` is accepted for interface compatibility and not used.
    """
    if not self.verbose:
        return
    for line in lines:
        # print the individual line, not the whole batch once per line
        print(self.label, "OUT:", line, flush=True)
async def process_err(self, lines: tp.List[str], eof: bool) -> None:
    """Echo stderr lines when verbose output is enabled.

    Each line is printed on its own, prefixed with the component label and
    ``ERR:``. ``eof`` is accepted for interface compatibility and not used.
    """
    if not self.verbose:
        return
    for line in lines:
        # print the individual line, not the whole batch once per line
        print(self.label, "ERR:", line, flush=True)
async def terminated(self, rc: int) -> None:
    """React to the command having exited with return code ``rc``.

    Prints a ``TERMINATED`` line when verbose. A non-zero return code raises
    ``RuntimeError`` unless the component was created with ``canfail``.
    """
    if self.verbose:
        print(self.label, "TERMINATED:", rc, flush=True)

    tolerated = self.canfail
    if tolerated or rc == 0:
        return
    raise RuntimeError("Command Failed: " + str(self.cmd_parts))
class SimpleRemoteComponent(SimpleComponent):
    """Command executed on a remote host through an ``ssh`` invocation.

    The command is wrapped so that the remote shell first prints
    ``PID <pid>`` and then ``exec``s the actual command. The PID line is
    filtered out of stdout and used later to deliver signals through a
    separate ``ssh ... kill`` call.
    """

    # Matches the "PID <number>" line printed by the remote wrapper. Compiled
    # once for the class instead of on every output batch.
    _PID_RE = re.compile(r'^PID\s+(\d+)\s*$')

    def __init__(
        self,
        host_name: str,
        label: str,
        cmd_parts: tp.List[str],
        *args,
        cwd: tp.Optional[str] = None,
        ssh_extra_args: tp.Optional[tp.List[str]] = None,
        **kwargs
    ) -> None:
        """Build the ssh command line wrapping ``cmd_parts``.

        Args:
            host_name: Host argument handed to ``ssh``.
            label: Label forwarded to ``SimpleComponent``.
            cmd_parts: Command and arguments to run remotely; each part is
                shell-quoted before being sent.
            cwd: Optional directory to ``cd`` into first. It is inserted
                unquoted, so it is interpreted by the remote shell.
            ssh_extra_args: Extra arguments placed before the host name.
            *args, **kwargs: Forwarded to ``SimpleComponent``.
        """
        if ssh_extra_args is None:
            ssh_extra_args = []

        self.host_name = host_name
        self.extra_flags = ssh_extra_args
        # add a wrapper to print the PID
        remote_parts = ['echo', 'PID', '$$', '&&']

        if cwd is not None:
            # if necessary add a CD command
            remote_parts += ['cd', cwd, '&&']

        # escape actual command parts
        cmd_parts = list(map(shlex.quote, cmd_parts))

        # use exec to make sure the command we run keeps the PIDS
        remote_parts += ['exec'] + cmd_parts

        # wrap up command in ssh invocation
        parts = self._ssh_cmd(remote_parts)

        super().__init__(label, parts, *args, **kwargs)

        # Resolved with the remote PID once the wrapper's PID line was seen;
        # created in start().
        self._pid_fut: tp.Optional[asyncio.Future] = None

    def _ssh_cmd(self, parts: tp.List[str]) -> tp.List[str]:
        """SSH invocation of command for this host."""
        return [
            'ssh',
            '-o',
            'UserKnownHostsFile=/dev/null',
            '-o',
            'StrictHostKeyChecking=no'
        ] + self.extra_flags + [self.host_name, '--'] + parts

    async def start(self) -> None:
        """Start this command (includes waiting for its pid)."""
        self._pid_fut = asyncio.get_running_loop().create_future()
        await super().start()
        await self._pid_fut

    async def process_out(self, lines: tp.List[str], eof: bool) -> None:
        """Scans output and set PID future once PID line found.

        Only the first matching line is consumed as the PID. Any later line
        that happens to look like a PID line is passed on as regular output
        instead of resolving the future a second time (which would raise
        ``asyncio.InvalidStateError``).
        """
        if not self._pid_fut.done():
            newlines = []
            for l in lines:
                m = None
                if not self._pid_fut.done():
                    m = self._PID_RE.match(l)
                if m:
                    self._pid_fut.set_result(int(m.group(1)))
                else:
                    newlines.append(l)
            lines = newlines

            if eof and not self._pid_fut.done():
                # cancel PID future if it's not going to happen
                print('PID not found but EOF already found:', self.label)
                self._pid_fut.cancel()
        await super().process_out(lines, eof)

    async def _kill_cmd(self, sig: str) -> None:
        """Send signal to command by running ssh kill -$sig $PID.

        ``Future.result()`` raises if the PID was never received (future not
        done or cancelled), so this only works after a successful start.
        """
        cmd_parts = self._ssh_cmd([
            'kill', '-' + sig, str(self._pid_fut.result())
        ])
        proc = await asyncio.create_subprocess_exec(*cmd_parts)
        await proc.wait()

    async def interrupt(self) -> None:
        """Send SIGINT to the remote command."""
        await self._kill_cmd('INT')

    async def terminate(self) -> None:
        """Send SIGTERM to the remote command."""
        await self._kill_cmd('TERM')

    async def kill(self) -> None:
        """Send SIGKILL to the remote command."""
        await self._kill_cmd('KILL')
class Executor(abc.ABC): class Executor(abc.ABC):
...@@ -318,9 +264,7 @@ class Executor(abc.ABC): ...@@ -318,9 +264,7 @@ class Executor(abc.ABC):
self.ip = None self.ip = None
@abc.abstractmethod @abc.abstractmethod
def create_component( def create_component(self, label: str, parts: tp.List[str], **kwargs) -> SimpleComponent:
self, label: str, parts: tp.List[str], **kwargs
) -> SimpleComponent:
pass pass
@abc.abstractmethod @abc.abstractmethod
...@@ -340,14 +284,10 @@ class Executor(abc.ABC): ...@@ -340,14 +284,10 @@ class Executor(abc.ABC):
pass pass
# runs the list of commands as strings sequentially # runs the list of commands as strings sequentially
async def run_cmdlist( async def run_cmdlist(self, label: str, cmds: tp.List[str], verbose=True) -> None:
self, label: str, cmds: tp.List[str], verbose=True
) -> None:
i = 0 i = 0
for cmd in cmds: for cmd in cmds:
cmd_c = self.create_component( cmd_c = self.create_component(label + "." + str(i), shlex.split(cmd), verbose=verbose)
label + '.' + str(i), shlex.split(cmd), verbose=verbose
)
await cmd_c.start() await cmd_c.start()
await cmd_c.wait() await cmd_c.wait()
...@@ -362,16 +302,12 @@ class Executor(abc.ABC): ...@@ -362,16 +302,12 @@ class Executor(abc.ABC):
class LocalExecutor(Executor): class LocalExecutor(Executor):
def create_component( def create_component(self, label: str, parts: list[str], **kwargs) -> SimpleComponent:
self, label: str, parts: tp.List[str], **kwargs
) -> SimpleComponent:
return SimpleComponent(label, parts, **kwargs) return SimpleComponent(label, parts, **kwargs)
async def await_file( async def await_file(self, path: str, delay=0.05, verbose=False, timeout=30) -> None:
self, path: str, delay=0.05, verbose=False, timeout=30
) -> None:
if verbose: if verbose:
print(f'await_file({path})') print(f"await_file({path})")
t = 0 t = 0
while not os.path.exists(path): while not os.path.exists(path):
if t >= timeout: if t >= timeout:
...@@ -391,86 +327,3 @@ class LocalExecutor(Executor): ...@@ -391,86 +327,3 @@ class LocalExecutor(Executor):
shutil.rmtree(path, ignore_errors=True) shutil.rmtree(path, ignore_errors=True)
elif os.path.exists(path): elif os.path.exists(path):
os.unlink(path) os.unlink(path)
class RemoteExecutor(Executor):
    """Executor whose components run on a remote host via ``ssh``."""

    def __init__(self, host_name: str, workdir: str) -> None:
        """Remember the target host and the remote working directory."""
        super().__init__()

        self.host_name = host_name
        self.cwd = workdir
        # Extra command line arguments inserted into ssh / scp invocations.
        self.ssh_extra_args: tp.List[str] = []
        self.scp_extra_args: tp.List[str] = []

    def create_component(
        self, label: str, parts: tp.List[str], **kwargs
    ) -> SimpleRemoteComponent:
        """Create a component running ``parts`` on the remote host in ``cwd``."""
        return SimpleRemoteComponent(
            self.host_name,
            label,
            parts,
            cwd=self.cwd,
            ssh_extra_args=self.ssh_extra_args,
            **kwargs
        )

    async def await_file(
        self, path: str, delay=0.05, verbose=False, timeout=30
    ) -> None:
        """Wait until ``path`` exists on the remote host.

        A small ``sh`` polling loop is run remotely: it checks for the path
        every ``delay`` seconds and exits with status 1 after roughly
        ``timeout`` seconds. Since the component is created with
        ``canfail=False``, that failure surfaces as an error from ``wait()``.
        """
        if verbose:
            print(f'{self.host_name}.await_file({path}) started')

        # The shell test needs an integer iteration bound.
        to_its = round(timeout / delay)
        # The script is handed to ``sh -c`` as a single argument, so the path
        # has to be quoted for that inner shell.
        quoted_path = shlex.quote(path)
        loop_cmd = (
            f'i=0 ; while [ ! -e {quoted_path} ] ; do '
            f'if [ $i -ge {to_its:d} ] ; then exit 1 ; fi ; '
            f'sleep {delay} ; '
            'i=$(($i+1)) ; done; exit 0'
        )
        parts = ['/bin/sh', '-c', loop_cmd]
        sc = self.create_component(
            f"{self.host_name}.await_file('{path}')",
            parts,
            canfail=False,
            verbose=verbose
        )
        await sc.start()
        await sc.wait()

    # TODO: Implement optimized await_files()

    async def send_file(self, path: str, verbose=False) -> None:
        """Copy local ``path`` to the same path on the remote host via scp.

        The ``scp`` command is run through a plain ``SimpleComponent`` rather
        than through the ssh wrapper.
        """
        parts = [
            'scp',
            '-o',
            'UserKnownHostsFile=/dev/null',
            '-o',
            'StrictHostKeyChecking=no'
        ] + self.scp_extra_args + [path, f'{self.host_name}:{path}']
        sc = SimpleComponent(
            f'{self.host_name}.send_file("{path}")',
            parts,
            canfail=False,
            verbose=verbose
        )
        await sc.start()
        await sc.wait()

    async def mkdir(self, path: str, verbose=False) -> None:
        """Run ``mkdir -p path`` on the remote host."""
        sc = self.create_component(
            f"{self.host_name}.mkdir('{path}')", ['mkdir', '-p', path],
            canfail=False,
            verbose=verbose
        )
        await sc.start()
        await sc.wait()

    async def rmtree(self, path: str, verbose=False) -> None:
        """Run ``rm -rf path`` on the remote host."""
        sc = self.create_component(
            f'{self.host_name}.rmtree("{path}")', ['rm', '-rf', path],
            canfail=False,
            verbose=verbose
        )
        await sc.start()
        await sc.wait()
...@@ -31,17 +31,21 @@ from simbricks.runtime import command_executor ...@@ -31,17 +31,21 @@ from simbricks.runtime import command_executor
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from simbricks.orchestration.simulation import base as sim_base from simbricks.orchestration.simulation import base as sim_base
class SimulationOutput: class SimulationOutput:
"""Manages an experiment's output.""" """Manages an experiment's output."""
def __init__(self, sim: sim_base.Simulation) -> None: def __init__(self, sim: sim_base.Simulation) -> None:
self._sim_name: str = sim.name self._sim_name: str = sim.name
self._start_time: float = None self._start_time: float | None = None
self._end_time: float = None self._end_time: float | None = None
self._success: bool = True self._success: bool = True
self._interrupted: bool = False self._interrupted: bool = False
self._metadata = sim.metadata self._metadata = sim.metadata
self._sims: dict[str, dict[str, str | list[str]]] = {} self._sims: dict[sim_base.Simulator, command_executor.OutputListener] = {}
def is_ended(self) -> bool:
    """Return True once an end time was recorded or the run was interrupted.

    The end time is compared against ``None`` so that the result is always a
    real ``bool`` instead of leaking the timestamp (or ``None``) to callers.
    """
    return self._end_time is not None or self._interrupted
def set_start(self) -> None: def set_start(self) -> None:
self._start_time = time.time() self._start_time = time.time()
...@@ -56,22 +60,38 @@ class SimulationOutput: ...@@ -56,22 +60,38 @@ class SimulationOutput:
self._success = False self._success = False
self._interrupted = True self._interrupted = True
def add_sim(self, sim: sim_base.Simulator, comp: command_executor.Component) -> None: def add_mapping(self, sim: sim_base.Simulator, output_handel: command_executor.OutputListener) -> None:
obj = { assert sim not in self._sims
"class": sim.__class__.__name__, self._sims[sim] = output_handel
"cmd": comp.cmd_parts,
"stdout": comp.stdout, def get_output_listener(self, sim: sim_base.Simulator) -> command_executor.OutputListener:
"stderr": comp.stderr, if sim not in self._sims:
} raise Exception("not output handel for simulator found")
self._sims[sim.full_name()] = obj return self._sims[sim]
def get_all_listeners(self) -> list[command_executor.OutputListener]:
    """Return a new list with the listeners of all registered simulators."""
    return [*self._sims.values()]
def toJSON(self) -> dict:
    """Build the JSON-serializable representation of this output.

    The result holds the run's bookkeeping fields under their attribute
    names plus one entry per simulator, keyed by ``sim.full_name()``. Each
    simulator entry is the listener's own ``toJSON()`` dict extended by a
    ``"class"`` key with the simulator's class name.
    """
    json_obj = {
        "_sim_name": self._sim_name,
        "_start_time": self._start_time,
        "_end_time": self._end_time,
        "_success": self._success,
        "_interrupted": self._interrupted,
        "_metadata": self._metadata,
    }
    for sim, out in self._sims.items():
        # The class name belongs to the simulator's entry; writing it to the
        # top level would keep only the last simulator's class.
        json_obj[sim.full_name()] = {
            **out.toJSON(),
            "class": sim.__class__.__name__,
        }
    return json_obj
def dump(self, outpath: str) -> None: def dump(self, outpath: str) -> None:
json_obj = self.toJSON()
pathlib.Path(outpath).parent.mkdir(parents=True, exist_ok=True) pathlib.Path(outpath).parent.mkdir(parents=True, exist_ok=True)
with open(outpath, "w", encoding="utf-8") as file: with open(outpath, "w", encoding="utf-8") as file:
json.dump(self.__dict__, file, indent=4) json.dump(json_obj, file, indent=4)
def load(self, file: str) -> None:
with open(file, "r", encoding="utf-8") as fp:
for k, v in json.load(fp).items():
self.__dict__[k] = v
# def load(self, file: str) -> None:
# with open(file, "r", encoding="utf-8") as fp:
# for k, v in json.load(fp).items():
# self.__dict__[k] = v
...@@ -26,7 +26,7 @@ from __future__ import annotations ...@@ -26,7 +26,7 @@ from __future__ import annotations
import itertools import itertools
import abc import abc
from simbricks.orchestration.simulation import output from simbricks.runtime import output
from simbricks.orchestration.instantiation import base as inst_base from simbricks.orchestration.instantiation import base as inst_base
......
...@@ -24,6 +24,7 @@ from __future__ import annotations ...@@ -24,6 +24,7 @@ from __future__ import annotations
import asyncio import asyncio
from simbricks.utils import artifatcs as art
from simbricks.runtime import simulation_executor from simbricks.runtime import simulation_executor
from simbricks.runtime import command_executor from simbricks.runtime import command_executor
from simbricks.runtime.runs import base as run_base from simbricks.runtime.runs import base as run_base
...@@ -49,13 +50,14 @@ class LocalSimpleRuntime(run_base.Runtime): ...@@ -49,13 +50,14 @@ class LocalSimpleRuntime(run_base.Runtime):
async def do_run(self, run: run_base.Run) -> None: async def do_run(self, run: run_base.Run) -> None:
"""Actually executes `run`.""" """Actually executes `run`."""
try: try:
runner = simulation_executor.SimulationSimpleRunner( runner = simulation_executor.SimulationSimpleRunner(self._executor, run.instantiation, self._verbose)
self._executor, run.instantiation, self._verbose
)
if self._profile_int: if self._profile_int:
runner.profile_int = self.profile_int runner.profile_int = self.profile_int
await runner.prepare() await runner.prepare()
for sim in run.instantiation.simulation.all_simulators():
runner.add_listener(sim, command_executor.LegacyOutputListener())
except asyncio.CancelledError: except asyncio.CancelledError:
# it is safe to just exit here because we are not running any # it is safe to just exit here because we are not running any
# simulators yet # simulators yet
...@@ -68,8 +70,13 @@ class LocalSimpleRuntime(run_base.Runtime): ...@@ -68,8 +70,13 @@ class LocalSimpleRuntime(run_base.Runtime):
if self._verbose: if self._verbose:
print(f"Writing collected output of run {run.name()} to JSON file ...") print(f"Writing collected output of run {run.name()} to JSON file ...")
# dump output into a file and then, before cleanup, create an artifact
output_path = run.instantiation.get_simulation_output_path() output_path = run.instantiation.get_simulation_output_path()
run._output.dump(outpath=output_path) run._output.dump(outpath=output_path)
if run.instantiation.create_artifact:
art.create_artifact(
artifact_name=run.instantiation.artifact_name, paths_to_include=run.instantiation.artifact_paths
)
await runner.cleanup() await runner.cleanup()
...@@ -115,10 +122,7 @@ class LocalParallelRuntime(run_base.Runtime): ...@@ -115,10 +122,7 @@ class LocalParallelRuntime(run_base.Runtime):
if run.instantiation.simulation.resreq_cores() > self._cores: if run.instantiation.simulation.resreq_cores() > self._cores:
raise RuntimeError("Not enough cores available for run") raise RuntimeError("Not enough cores available for run")
if ( if self._mem is not None and run.instantiation.simulation.resreq_mem() > self._mem:
self._mem is not None
and run.instantiation.simulation.resreq_mem() > self._mem
):
raise RuntimeError("Not enough memory available for run") raise RuntimeError("Not enough memory available for run")
if run._prereq is None: if run._prereq is None:
...@@ -129,12 +133,12 @@ class LocalParallelRuntime(run_base.Runtime): ...@@ -129,12 +133,12 @@ class LocalParallelRuntime(run_base.Runtime):
async def do_run(self, run: run_base.Run) -> run_base.Run | None: async def do_run(self, run: run_base.Run) -> run_base.Run | None:
"""Actually executes `run`.""" """Actually executes `run`."""
try: try:
runner = simulation_executor.SimulationSimpleRunner( runner = simulation_executor.SimulationSimpleRunner(self._executor, run.instantiation, self._verbose)
self._executor, run.instantiation, self._verbose
)
if self._profile_int is not None: if self._profile_int is not None:
runner._profile_int = self._profile_int runner._profile_int = self._profile_int
await runner.prepare() await runner.prepare()
for sim in run.instantiation.simulation.all_simulators():
runner.add_listener(sim, command_executor.LegacyOutputListener())
except asyncio.CancelledError: except asyncio.CancelledError:
# it is safe to just exit here because we are not running any # it is safe to just exit here because we are not running any
# simulators yet # simulators yet
...@@ -159,9 +163,7 @@ class LocalParallelRuntime(run_base.Runtime): ...@@ -159,9 +163,7 @@ class LocalParallelRuntime(run_base.Runtime):
"""Wait for any run to terminate and return.""" """Wait for any run to terminate and return."""
assert self._pending_jobs assert self._pending_jobs
done, self._pending_jobs = await asyncio.wait( done, self._pending_jobs = await asyncio.wait(self._pending_jobs, return_when=asyncio.FIRST_COMPLETED)
self._pending_jobs, return_when=asyncio.FIRST_COMPLETED
)
for r_awaitable in done: for r_awaitable in done:
run = await r_awaitable run = await r_awaitable
...@@ -171,9 +173,7 @@ class LocalParallelRuntime(run_base.Runtime): ...@@ -171,9 +173,7 @@ class LocalParallelRuntime(run_base.Runtime):
def enough_resources(self, run: run_base.Run) -> bool: def enough_resources(self, run: run_base.Run) -> bool:
"""Check if enough cores and mem are available for the run.""" """Check if enough cores and mem are available for the run."""
simulation = ( simulation = run.instantiation.simulation # pylint: disable=redefined-outer-name
run.instantiation.simulation
) # pylint: disable=redefined-outer-name
if self._cores is not None: if self._cores is not None:
enough_cores = (self._cores - self._cores_used) >= simulation.resreq_cores() enough_cores = (self._cores - self._cores_used) >= simulation.resreq_cores()
......
...@@ -27,8 +27,7 @@ import shlex ...@@ -27,8 +27,7 @@ import shlex
import traceback import traceback
import abc import abc
from simbricks.utils import artifatcs as art from simbricks.runtime import output
from simbricks.orchestration.simulation import output
from simbricks.orchestration.simulation import base as sim_base from simbricks.orchestration.simulation import base as sim_base
from simbricks.orchestration.instantiation import base as inst_base from simbricks.orchestration.instantiation import base as inst_base
from simbricks.runtime import command_executor from simbricks.runtime import command_executor
...@@ -45,10 +44,9 @@ class SimulationBaseRunner(abc.ABC): ...@@ -45,10 +44,9 @@ class SimulationBaseRunner(abc.ABC):
self._instantiation: inst_base.Instantiation = instantiation self._instantiation: inst_base.Instantiation = instantiation
self._verbose: bool = verbose self._verbose: bool = verbose
self._profile_int: int | None = None self._profile_int: int | None = None
self._out = output.SimulationOutput(self._instantiation.simulation) self._out: output.SimulationOutput = output.SimulationOutput(self._instantiation.simulation)
self._running: list[ self._out_listener: dict[sim_base.Simulator, command_executor.OutputListener] = {}
tuple[sim_base.Simulator, command_executor.SimpleComponent] self._running: list[tuple[sim_base.Simulator, command_executor.SimpleComponent]] = []
] = []
self._sockets: list[inst_base.Socket] = [] self._sockets: list[inst_base.Socket] = []
self._wait_sims: list[command_executor.Component] = [] self._wait_sims: list[command_executor.Component] = []
...@@ -56,6 +54,15 @@ class SimulationBaseRunner(abc.ABC): ...@@ -56,6 +54,15 @@ class SimulationBaseRunner(abc.ABC):
def sim_executor(self, simulator: sim_base.Simulator) -> command_executor.Executor: def sim_executor(self, simulator: sim_base.Simulator) -> command_executor.Executor:
pass pass
def sim_listener(self, sim: sim_base.Simulator) -> command_executor.OutputListener:
    """Return the output listener registered for ``sim``.

    Raises:
        Exception: if no listener was registered for the simulator.
    """
    registered = self._out_listener
    if sim in registered:
        return registered[sim]
    raise Exception(f"no listener specified for simulator {sim.id()}")
def add_listener(self, sim: sim_base.Simulator, listener: command_executor.OutputListener) -> None:
    """Register ``listener`` as the output listener of ``sim``.

    The listener is recorded both in the simulation output (so it ends up in
    the dumped results) and in this runner (so it gets subscribed when the
    simulator is started).
    """
    # Register with the output first: it rejects duplicate simulators, and
    # failing there must not leave this runner pointing at a listener that
    # the output does not know about.
    self._out.add_mapping(sim, listener)
    self._out_listener[sim] = listener
async def start_sim(self, sim: sim_base.Simulator) -> None: async def start_sim(self, sim: sim_base.Simulator) -> None:
"""Start a simulator and wait for it to be ready.""" """Start a simulator and wait for it to be ready."""
...@@ -70,10 +77,11 @@ class SimulationBaseRunner(abc.ABC): ...@@ -70,10 +77,11 @@ class SimulationBaseRunner(abc.ABC):
return return
# run simulator # run simulator
executor = self._instantiation.executor # TODO: this should be a function or something executor = self._instantiation.executor # TODO: this should be a function or something
sc = executor.create_component( cmds = shlex.split(run_cmd)
name, shlex.split(run_cmd), verbose=self._verbose, canfail=True sc = executor.create_component(name, cmds, verbose=self._verbose, canfail=True)
) if listener := self.sim_listener(sim=sim):
sc.subscribe(listener=listener)
await sc.start() await sc.start()
self._running.append((sim, sc)) self._running.append((sim, sc))
...@@ -88,9 +96,7 @@ class SimulationBaseRunner(abc.ABC): ...@@ -88,9 +96,7 @@ class SimulationBaseRunner(abc.ABC):
print(f"{self._instantiation.simulation.name}: waiting for sockets {name}") print(f"{self._instantiation.simulation.name}: waiting for sockets {name}")
await self._instantiation.wait_for_sockets(sockets=wait_socks) await self._instantiation.wait_for_sockets(sockets=wait_socks)
if self._verbose: if self._verbose:
print( print(f"{self._instantiation.simulation.name}: waited successfully for sockets {name}")
f"{self._instantiation.simulation.name}: waited successfully for sockets {name}"
)
# add time delay if required # add time delay if required
delay = sim.start_delay() delay = sim.start_delay()
...@@ -107,11 +113,7 @@ class SimulationBaseRunner(abc.ABC): ...@@ -107,11 +113,7 @@ class SimulationBaseRunner(abc.ABC):
pass pass
async def before_cleanup(self) -> None: async def before_cleanup(self) -> None:
if self._instantiation.create_artifact: pass
art.create_artifact(
artifact_name=self._instantiation.artifact_name,
paths_to_include=self._instantiation.artifact_paths
)
async def after_cleanup(self) -> None: async def after_cleanup(self) -> None:
pass pass
...@@ -147,10 +149,6 @@ class SimulationBaseRunner(abc.ABC): ...@@ -147,10 +149,6 @@ class SimulationBaseRunner(abc.ABC):
for _, sc in self._running: for _, sc in self._running:
await sc.wait() await sc.wait()
# add all simulator components to the output
for sim, sc in self._running:
self._out.add_sim(sim, sc)
await self.after_cleanup() await self.after_cleanup()
return self._out return self._out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment