"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "b902f28dbe742e76b5cdc1d4c826752352928822"
Commit bfdc0f43 authored by Jonas Kaufmann's avatar Jonas Kaufmann
Browse files

orchestration/instantiation/base: refactor classes Socket and SockType to a separate module

parent f8a938ff
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
from __future__ import annotations from __future__ import annotations
import enum
import pathlib import pathlib
import shutil import shutil
import typing import typing
...@@ -35,25 +34,13 @@ from simbricks.orchestration.system import pcie as sys_pcie ...@@ -35,25 +34,13 @@ from simbricks.orchestration.system import pcie as sys_pcie
from simbricks.orchestration.system import mem as sys_mem from simbricks.orchestration.system import mem as sys_mem
from simbricks.orchestration.system import eth as sys_eth from simbricks.orchestration.system import eth as sys_eth
from simbricks.orchestration.system.host import disk_images from simbricks.orchestration.system.host import disk_images
from simbricks.orchestration.instantiation import inst_socket
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from simbricks.orchestration.simulation import base as sim_base from simbricks.orchestration.simulation import base as sim_base
from simbricks.runtime import command_executor from simbricks.runtime import command_executor
class SockType(enum.Enum):
LISTEN = enum.auto()
CONNECT = enum.auto()
class Socket(util_base.IdObj):
def __init__(self, path: str = "", ty: SockType = SockType.LISTEN):
super().__init__()
self._path = path
self._type = ty
class InstantiationEnvironment(util_base.IdObj): class InstantiationEnvironment(util_base.IdObj):
def __init__( def __init__(
...@@ -94,7 +81,7 @@ class Instantiation(): ...@@ -94,7 +81,7 @@ class Instantiation():
self._preserve_checkpoints: bool = True self._preserve_checkpoints: bool = True
self.preserve_tmp_folder: bool = False self.preserve_tmp_folder: bool = False
# NOTE: temporary data structure # NOTE: temporary data structure
self._socket_per_interface: dict[sys_base.Interface, Socket] = {} self._socket_per_interface: dict[sys_base.Interface, inst_socket.Socket] = {}
# NOTE: temporary data structure # NOTE: temporary data structure
self._sim_dependency: ( self._sim_dependency: (
dict[sim_base.Simulator, set[sim_base.Simulator]] | None dict[sim_base.Simulator, set[sim_base.Simulator]] | None
...@@ -137,21 +124,21 @@ class Instantiation(): ...@@ -137,21 +124,21 @@ class Instantiation():
) )
def _updated_tracker_mapping( def _updated_tracker_mapping(
self, interface: sys_base.Interface, socket: Socket self, interface: sys_base.Interface, socket: inst_socket.Socket
) -> None: ) -> None:
# update interface mapping # update interface mapping
if interface in self._socket_per_interface: if interface in self._socket_per_interface:
raise Exception("an interface cannot be associated with two sockets") raise Exception("an interface cannot be associated with two sockets")
self._socket_per_interface[interface] = socket self._socket_per_interface[interface] = socket
def _get_socket_by_interface(self, interface: sys_base.Interface) -> Socket | None: def _get_socket_by_interface(self, interface: sys_base.Interface) -> inst_socket.Socket | None:
if interface not in self._socket_per_interface: if interface not in self._socket_per_interface:
return None return None
return self._socket_per_interface[interface] return self._socket_per_interface[interface]
def _get_opposing_socket_by_interface( def _get_opposing_socket_by_interface(
self, interface: sys_base.Interface self, interface: sys_base.Interface
) -> Socket | None: ) -> inst_socket.Socket | None:
opposing_interface = self._get_opposing_interface(interface=interface) opposing_interface = self._get_opposing_interface(interface=interface)
socket = self._get_socket_by_interface(interface=opposing_interface) socket = self._get_socket_by_interface(interface=opposing_interface)
return socket return socket
...@@ -179,27 +166,27 @@ class Instantiation(): ...@@ -179,27 +166,27 @@ class Instantiation():
enforce_existence=False, enforce_existence=False,
) )
def _create_opposing_socket(self, socket: Socket, socket_type: SockType) -> Socket: def _create_opposing_socket(self, socket: inst_socket.Socket, socket_type: inst_socket.SockType) -> inst_socket.Socket:
new_ty = ( new_ty = (
SockType.LISTEN if socket._type == SockType.CONNECT else SockType.CONNECT inst_socket.SockType.LISTEN if socket._type == inst_socket.SockType.CONNECT else inst_socket.SockType.CONNECT
) )
if new_ty != socket_type: if new_ty != socket_type:
raise Exception( raise Exception(
f"cannot create opposing socket, as required type is not supported: required={new_ty.name}, supported={socket_type.name}" f"cannot create opposing socket, as required type is not supported: required={new_ty.name}, supported={socket_type.name}"
) )
new_path = socket._path new_path = socket._path
new_socket = Socket(path=new_path, ty=new_ty) new_socket = inst_socket.Socket(path=new_path, ty=new_ty)
return new_socket return new_socket
def get_socket(self, interface: sys_base.Interface) -> Socket | None: def get_socket(self, interface: sys_base.Interface) -> inst_socket.Socket | None:
socket = self._get_socket_by_interface(interface=interface) socket = self._get_socket_by_interface(interface=interface)
if socket: if socket:
return socket return socket
return None return None
def _get_socket( def _get_socket(
self, interface: sys_base.Interface, socket_type: SockType self, interface: sys_base.Interface, socket_type: inst_socket.SockType
) -> Socket: ) -> inst_socket.Socket:
if self._opposing_interface_within_same_sim(interface=interface): if self._opposing_interface_within_same_sim(interface=interface):
raise Exception( raise Exception(
...@@ -223,7 +210,7 @@ class Instantiation(): ...@@ -223,7 +210,7 @@ class Instantiation():
# create socket if opposing socket was not created yet # create socket if opposing socket was not created yet
sock_path = self._interface_to_sock_path(interface=interface) sock_path = self._interface_to_sock_path(interface=interface)
new_socket = Socket(path=sock_path, ty=socket_type) new_socket = inst_socket.Socket(path=sock_path, ty=socket_type)
self._updated_tracker_mapping(interface=interface, socket=new_socket) self._updated_tracker_mapping(interface=interface, socket=new_socket)
print(f"created socket: {new_socket._path}") print(f"created socket: {new_socket._path}")
return new_socket return new_socket
...@@ -252,32 +239,32 @@ class Instantiation(): ...@@ -252,32 +239,32 @@ class Instantiation():
def update_a_depends_on_b(inf_a: sys_base.Interface, inf_b: sys_base.Interface): def update_a_depends_on_b(inf_a: sys_base.Interface, inf_b: sys_base.Interface):
sim_a = self.find_sim_by_interface(interface=inf_a) sim_a = self.find_sim_by_interface(interface=inf_a)
sim_b = self.find_sim_by_interface(interface=inf_b) sim_b = self.find_sim_by_interface(interface=inf_b)
a_sock: set[SockType] = sim_a.supported_socket_types(interface=inf_a) a_sock: set[inst_socket.SockType] = sim_a.supported_socket_types(interface=inf_a)
b_sock: set[SockType] = sim_b.supported_socket_types(interface=inf_b) b_sock: set[inst_socket.SockType] = sim_b.supported_socket_types(interface=inf_b)
if a_sock != b_sock: if a_sock != b_sock:
if len(a_sock) == 0 or len(b_sock) == 0: if len(a_sock) == 0 or len(b_sock) == 0:
raise Exception( raise Exception(
"cannot create socket and resolve dependency if no socket type is supported for an interface" "cannot create socket and resolve dependency if no socket type is supported for an interface"
) )
if SockType.CONNECT in a_sock: if inst_socket.SockType.CONNECT in a_sock:
assert SockType.LISTEN in b_sock assert inst_socket.SockType.LISTEN in b_sock
insert_dependency(sim_a, depends_on=sim_b) insert_dependency(sim_a, depends_on=sim_b)
self._get_socket(interface=inf_a, socket_type=SockType.CONNECT) self._get_socket(interface=inf_a, socket_type=inst_socket.SockType.CONNECT)
self._get_socket(interface=inf_b, socket_type=SockType.LISTEN) self._get_socket(interface=inf_b, socket_type=inst_socket.SockType.LISTEN)
else: else:
assert SockType.CONNECT in b_sock assert inst_socket.SockType.CONNECT in b_sock
insert_dependency(sim_b, depends_on=sim_a) insert_dependency(sim_b, depends_on=sim_a)
self._get_socket(interface=inf_b, socket_type=SockType.CONNECT) self._get_socket(interface=inf_b, socket_type=inst_socket.SockType.CONNECT)
self._get_socket(interface=inf_a, socket_type=SockType.LISTEN) self._get_socket(interface=inf_a, socket_type=inst_socket.SockType.LISTEN)
else: else:
# deadlock? # deadlock?
if len(a_sock) != 2 or len(b_sock) != 2: if len(a_sock) != 2 or len(b_sock) != 2:
raise Exception("cannot solve deadlock") raise Exception("cannot solve deadlock")
# both support both we just pick an order # both support both we just pick an order
insert_dependency(sim_a, depends_on=sim_b) insert_dependency(sim_a, depends_on=sim_b)
self._get_socket(interface=sim_a, socket_type=SockType.CONNECT) self._get_socket(interface=sim_a, socket_type=inst_socket.SockType.CONNECT)
self._get_socket(interface=sim_b, socket_type=SockType.LISTEN) self._get_socket(interface=sim_b, socket_type=inst_socket.SockType.LISTEN)
# build dependency graph # build dependency graph
for sim in self.simulation.all_simulators(): for sim in self.simulation.all_simulators():
...@@ -299,7 +286,7 @@ class Instantiation(): ...@@ -299,7 +286,7 @@ class Instantiation():
async def wait_for_sockets( async def wait_for_sockets(
self, self,
sockets: list[Socket] = [], sockets: list[inst_socket.Socket] = [],
) -> None: ) -> None:
wait_socks = list(map(lambda sock: sock._path, sockets)) wait_socks = list(map(lambda sock: sock._path, sockets))
await self.executor.await_files(wait_socks, verbose=True) await self.executor.await_files(wait_socks, verbose=True)
......
# Copyright 2022 Max Planck Institute for Software Systems, and
# National University of Singapore
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import enum
from simbricks.utils import base as util_base
class SockType(enum.Enum):
LISTEN = enum.auto()
CONNECT = enum.auto()
class Socket(util_base.IdObj):
def __init__(self, path: str = "", ty: SockType = SockType.LISTEN):
super().__init__()
self._path = path
self._type = ty
...@@ -27,6 +27,7 @@ import asyncio ...@@ -27,6 +27,7 @@ import asyncio
import typing as tp import typing as tp
import simbricks.orchestration.system as sys_conf import simbricks.orchestration.system as sys_conf
import simbricks.orchestration.instantiation.base as inst_base import simbricks.orchestration.instantiation.base as inst_base
import simbricks.orchestration.instantiation.socket as inst_socket
import simbricks.orchestration.simulation.channel as sim_chan import simbricks.orchestration.simulation.channel as sim_chan
import simbricks.utils.base as utils_base import simbricks.utils.base as utils_base
...@@ -146,21 +147,21 @@ class Simulator(utils_base.IdObj): ...@@ -146,21 +147,21 @@ class Simulator(utils_base.IdObj):
@staticmethod @staticmethod
def filter_sockets( def filter_sockets(
sockets: list[inst_base.Socket], sockets: list[inst_socket.Socket],
filter_type: inst_base.SockType = inst_base.SockType.LISTEN, filter_type: inst_socket.sockType = inst_socket.sockType.LISTEN,
) -> list[inst_base.Socket]: ) -> list[inst_socket.Socket]:
res = list(filter(lambda sock: sock._type == filter_type, sockets)) res = list(filter(lambda sock: sock._type == filter_type, sockets))
return res return res
@staticmethod @staticmethod
def split_sockets_by_type( def split_sockets_by_type(
sockets: list[inst_base.Socket], sockets: list[inst_socket.Socket],
) -> tuple[list[inst_base.Socket], list[inst_base.Socket]]: ) -> tuple[list[inst_socket.Socket], list[inst_socket.Socket]]:
listen = Simulator.filter_sockets( listen = Simulator.filter_sockets(
sockets=sockets, filter_type=inst_base.SockType.LISTEN sockets=sockets, filter_type=inst_socket.sockType.LISTEN
) )
connect = Simulator.filter_sockets( connect = Simulator.filter_sockets(
sockets=sockets, filter_type=inst_base.SockType.CONNECT sockets=sockets, filter_type=inst_socket.sockType.CONNECT
) )
return listen, connect return listen, connect
...@@ -229,7 +230,7 @@ class Simulator(utils_base.IdObj): ...@@ -229,7 +230,7 @@ class Simulator(utils_base.IdObj):
def _get_socks_by_comp( def _get_socks_by_comp(
self, inst: inst_base.Instantiation, comp: sys_conf.Component self, inst: inst_base.Instantiation, comp: sys_conf.Component
) -> list[inst_base.Socket]: ) -> list[inst_socket.Socket]:
if comp not in self._components: if comp not in self._components:
raise Exception("comp must be a simulators component") raise Exception("comp must be a simulators component")
sockets = [] sockets = []
...@@ -241,15 +242,15 @@ class Simulator(utils_base.IdObj): ...@@ -241,15 +242,15 @@ class Simulator(utils_base.IdObj):
def _get_socks_by_all_comp( def _get_socks_by_all_comp(
self, inst: inst_base.Instantiation self, inst: inst_base.Instantiation
) -> list[inst_base.Socket]: ) -> list[inst_socket.Socket]:
sockets = [] sockets = []
for comp in self._components: for comp in self._components:
sockets.extend(self._get_socks_by_comp(inst=inst, comp=comp)) sockets.extend(self._get_socks_by_comp(inst=inst, comp=comp))
return sockets return sockets
def _get_all_sockets_by_type( def _get_all_sockets_by_type(
self, inst: inst_base.Instantiation, sock_type: inst_base.SockType self, inst: inst_base.Instantiation, sock_type: inst_socket.sockType
) -> list[inst_base.Socket]: ) -> list[inst_socket.Socket]:
sockets = self._get_socks_by_all_comp(inst=inst) sockets = self._get_socks_by_all_comp(inst=inst)
sockets = Simulator.filter_sockets(sockets=sockets, filter_type=sock_type) sockets = Simulator.filter_sockets(sockets=sockets, filter_type=sock_type)
return sockets return sockets
...@@ -285,21 +286,21 @@ class Simulator(utils_base.IdObj): ...@@ -285,21 +286,21 @@ class Simulator(utils_base.IdObj):
@abc.abstractmethod @abc.abstractmethod
def supported_socket_types( def supported_socket_types(
self, interface: sys_conf.Interface self, interface: sys_conf.Interface
) -> set[inst_base.SockType]: ) -> set[inst_socket.sockType]:
return [] return []
# Sockets to be cleaned up: always the CONNECTING sockets # Sockets to be cleaned up: always the CONNECTING sockets
# pylint: disable=unused-argument # pylint: disable=unused-argument
def sockets_cleanup(self, inst: inst_base.Instantiation) -> list[inst_base.Socket]: def sockets_cleanup(self, inst: inst_base.Instantiation) -> list[inst_socket.Socket]:
return self._get_all_sockets_by_type( return self._get_all_sockets_by_type(
inst=inst, sock_type=inst_base.SockType.LISTEN inst=inst, sock_type=inst_socket.sockType.LISTEN
) )
# sockets to wait for indicating the simulator is ready # sockets to wait for indicating the simulator is ready
# pylint: disable=unused-argument # pylint: disable=unused-argument
def sockets_wait(self, inst: inst_base.Instantiation) -> list[inst_base.Socket]: def sockets_wait(self, inst: inst_base.Instantiation) -> list[inst_socket.Socket]:
return self._get_all_sockets_by_type( return self._get_all_sockets_by_type(
inst=inst, sock_type=inst_base.SockType.LISTEN inst=inst, sock_type=inst_socket.sockType.LISTEN
) )
def start_delay(self) -> int: def start_delay(self) -> int:
......
...@@ -31,6 +31,7 @@ from simbricks.orchestration.system import host as sys_host ...@@ -31,6 +31,7 @@ from simbricks.orchestration.system import host as sys_host
from simbricks.orchestration.system import pcie as sys_pcie from simbricks.orchestration.system import pcie as sys_pcie
from simbricks.orchestration.system import mem as sys_mem from simbricks.orchestration.system import mem as sys_mem
from simbricks.utils import base as utils_base from simbricks.utils import base as utils_base
from simbricks.orchestration.instantiation import socket as inst_socket
class HostSim(sim_base.Simulator): class HostSim(sim_base.Simulator):
...@@ -60,7 +61,7 @@ class HostSim(sim_base.Simulator): ...@@ -60,7 +61,7 @@ class HostSim(sim_base.Simulator):
def supported_socket_types( def supported_socket_types(
self, interface: system.Interface self, interface: system.Interface
) -> set[inst_base.SockType]: ) -> set[inst_base.SockType]:
return [inst_base.SockType.CONNECT] return [inst_socket.SockType.CONNECT]
class Gem5Sim(HostSim): class Gem5Sim(HostSim):
......
...@@ -26,6 +26,7 @@ from simbricks.orchestration.system import base as sys_base ...@@ -26,6 +26,7 @@ from simbricks.orchestration.system import base as sys_base
from simbricks.orchestration.system import eth as sys_eth from simbricks.orchestration.system import eth as sys_eth
from simbricks.orchestration.simulation import base as sim_base from simbricks.orchestration.simulation import base as sim_base
from simbricks.orchestration.instantiation import base as inst_base from simbricks.orchestration.instantiation import base as inst_base
from simbricks.orchestration.instantiation import socket as inst_socket
from simbricks.utils import base as base_utils from simbricks.utils import base as base_utils
...@@ -48,7 +49,7 @@ class NetSim(sim_base.Simulator): ...@@ -48,7 +49,7 @@ class NetSim(sim_base.Simulator):
def supported_socket_types( def supported_socket_types(
self, interface: sys_base.Interface self, interface: sys_base.Interface
) -> set[inst_base.SockType]: ) -> set[inst_socket.SockType]:
return [inst_base.SockType.CONNECT] return [inst_base.SockType.CONNECT]
def toJSON(self) -> dict: def toJSON(self) -> dict:
...@@ -293,12 +294,12 @@ class NS3DumbbellNet(SimpleNS3Sim): ...@@ -293,12 +294,12 @@ class NS3DumbbellNet(SimpleNS3Sim):
left_socks = self._get_socks_by_comp(inst=inst, comp=self._left) left_socks = self._get_socks_by_comp(inst=inst, comp=self._left)
for sock in left_socks: for sock in left_socks:
assert sock._type == inst_base.SockType.CONNECT assert sock._type == inst_socket.SockType.CONNECT
cmd += f"--SimbricksPortLeft={sock._path} " cmd += f"--SimbricksPortLeft={sock._path} "
right_sockets = self._get_socks_by_comp(inst=inst, comp=self._right) right_sockets = self._get_socks_by_comp(inst=inst, comp=self._right)
for sock in right_sockets: for sock in right_sockets:
assert sock._type == inst_base.SockType.CONNECT assert sock._type == inst_socket.SockType.CONNECT
cmd += f"--SimbricksPortRight={sock._path} " cmd += f"--SimbricksPortRight={sock._path} "
if self.opt is not None: if self.opt is not None:
......
...@@ -27,6 +27,7 @@ from simbricks.orchestration.system import pcie as sys_pcie ...@@ -27,6 +27,7 @@ from simbricks.orchestration.system import pcie as sys_pcie
from simbricks.orchestration.system import eth as sys_eth from simbricks.orchestration.system import eth as sys_eth
from simbricks.orchestration.system import nic as sys_nic from simbricks.orchestration.system import nic as sys_nic
from simbricks.orchestration.instantiation import base as inst_base from simbricks.orchestration.instantiation import base as inst_base
from simbricks.orchestration.instantiation import socket as inst_socket
from simbricks.orchestration.simulation import base as sim_base from simbricks.orchestration.simulation import base as sim_base
...@@ -43,8 +44,8 @@ class PCIDevSim(sim_base.Simulator): ...@@ -43,8 +44,8 @@ class PCIDevSim(sim_base.Simulator):
def supported_socket_types( def supported_socket_types(
self, interface: sys_base.Interface self, interface: sys_base.Interface
) -> set[inst_base.SockType]: ) -> set[inst_socket.SockType]:
return [inst_base.SockType.LISTEN] return [inst_socket.SockType.LISTEN]
class NICSim(PCIDevSim): class NICSim(PCIDevSim):
...@@ -75,11 +76,11 @@ class NICSim(PCIDevSim): ...@@ -75,11 +76,11 @@ class NICSim(PCIDevSim):
nic_device = nic_devices[0] nic_device = nic_devices[0]
socket = inst.get_socket(interface=nic_device._pci_if) socket = inst.get_socket(interface=nic_device._pci_if)
assert socket is not None and socket._type == inst_base.SockType.LISTEN assert socket is not None and socket._type == inst_socket.SockType.LISTEN
cmd += f"{socket._path} " cmd += f"{socket._path} "
socket = inst.get_socket(interface=nic_device._eth_if) socket = inst.get_socket(interface=nic_device._eth_if)
assert socket is not None and socket._type == inst_base.SockType.LISTEN assert socket is not None and socket._type == inst_socket.SockType.LISTEN
cmd += f"{socket._path} " cmd += f"{socket._path} "
cmd += ( cmd += (
......
...@@ -30,6 +30,7 @@ import abc ...@@ -30,6 +30,7 @@ import abc
from simbricks.runtime import output from simbricks.runtime import output
from simbricks.orchestration.simulation import base as sim_base from simbricks.orchestration.simulation import base as sim_base
from simbricks.orchestration.instantiation import base as inst_base from simbricks.orchestration.instantiation import base as inst_base
from simbricks.orchestration.instantiation import socket as inst_socket
from simbricks.runtime import command_executor from simbricks.runtime import command_executor
from simbricks.utils import graphlib from simbricks.utils import graphlib
...@@ -47,7 +48,7 @@ class SimulationBaseRunner(abc.ABC): ...@@ -47,7 +48,7 @@ class SimulationBaseRunner(abc.ABC):
self._out: output.SimulationOutput = output.SimulationOutput(self._instantiation.simulation) self._out: output.SimulationOutput = output.SimulationOutput(self._instantiation.simulation)
self._out_listener: dict[sim_base.Simulator, command_executor.OutputListener] = {} self._out_listener: dict[sim_base.Simulator, command_executor.OutputListener] = {}
self._running: list[tuple[sim_base.Simulator, command_executor.SimpleComponent]] = [] self._running: list[tuple[sim_base.Simulator, command_executor.SimpleComponent]] = []
self._sockets: list[inst_base.Socket] = [] self._sockets: list[inst_socket.Socket] = []
self._wait_sims: list[command_executor.Component] = [] self._wait_sims: list[command_executor.Component] = []
@abc.abstractmethod @abc.abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment