Unverified Commit 2f2e030d authored by Jakob Görgen's avatar Jakob Görgen
Browse files

symphony/orchestration: moved helper methods for synchronization and latency...

symphony/orchestration: moved helper methods for synchronization and latency into system and simulation classes
parent 7294c065
...@@ -72,9 +72,7 @@ class Simulator(utils_base.IdObj, abc.ABC): ...@@ -72,9 +72,7 @@ class Simulator(utils_base.IdObj, abc.ABC):
return list(filter(pred, self._components)) return list(filter(pred, self._components))
def filter_components_by_type(self, ty: type[T]) -> list[T]: def filter_components_by_type(self, ty: type[T]) -> list[T]:
return self.filter_components_by_pred( return self.filter_components_by_pred(pred=lambda comp: isinstance(comp, ty), ty=ty)
pred=lambda comp: isinstance(comp, ty), ty=ty
)
@property @property
def extra_args(self) -> str: def extra_args(self) -> str:
...@@ -136,9 +134,7 @@ class Simulator(utils_base.IdObj, abc.ABC): ...@@ -136,9 +134,7 @@ class Simulator(utils_base.IdObj, abc.ABC):
instance._wait = bool(utils_base.get_json_attr_top(json_obj, "wait")) instance._wait = bool(utils_base.get_json_attr_top(json_obj, "wait"))
instance._start_tick = int(utils_base.get_json_attr_top(json_obj, "start_tick")) instance._start_tick = int(utils_base.get_json_attr_top(json_obj, "start_tick"))
instance._extra_args = utils_base.get_json_attr_top_or_none( instance._extra_args = utils_base.get_json_attr_top_or_none(json_obj, "extra_args")
json_obj, "extra_args"
)
return instance return instance
...@@ -154,9 +150,7 @@ class Simulator(utils_base.IdObj, abc.ABC): ...@@ -154,9 +150,7 @@ class Simulator(utils_base.IdObj, abc.ABC):
def split_sockets_by_type( def split_sockets_by_type(
sockets: list[inst_socket.Socket], sockets: list[inst_socket.Socket],
) -> tuple[list[inst_socket.Socket], list[inst_socket.Socket]]: ) -> tuple[list[inst_socket.Socket], list[inst_socket.Socket]]:
listen = Simulator.filter_sockets( listen = Simulator.filter_sockets(sockets=sockets, filter_type=inst_socket.SockType.LISTEN)
sockets=sockets, filter_type=inst_socket.SockType.LISTEN
)
connect = Simulator.filter_sockets( connect = Simulator.filter_sockets(
sockets=sockets, filter_type=inst_socket.SockType.CONNECT sockets=sockets, filter_type=inst_socket.SockType.CONNECT
) )
...@@ -174,9 +168,7 @@ class Simulator(utils_base.IdObj, abc.ABC): ...@@ -174,9 +168,7 @@ class Simulator(utils_base.IdObj, abc.ABC):
run_sync = False run_sync = False
for channel in channels: for channel in channels:
sync_period = ( sync_period = (
min(sync_period, channel.sync_period) min(sync_period, channel.sync_period) if sync_period else channel.sync_period
if sync_period
else channel.sync_period
) )
run_sync = run_sync or channel._synchronized run_sync = run_sync or channel._synchronized
latency = ( latency = (
...@@ -218,10 +210,7 @@ class Simulator(utils_base.IdObj, abc.ABC): ...@@ -218,10 +210,7 @@ class Simulator(utils_base.IdObj, abc.ABC):
self._simulation.add_spec_sim_map(comp, self) self._simulation.add_spec_sim_map(comp, self)
def _chan_needs_instance(self, chan: sys_conf.Channel) -> bool: def _chan_needs_instance(self, chan: sys_conf.Channel) -> bool:
if ( if chan.a.component in self._components and chan.b.component in self._components:
chan.a.component in self._components
and chan.b.component in self._components
):
return False return False
return True return True
...@@ -237,9 +226,7 @@ class Simulator(utils_base.IdObj, abc.ABC): ...@@ -237,9 +226,7 @@ class Simulator(utils_base.IdObj, abc.ABC):
sockets.append(socket) sockets.append(socket)
return sockets return sockets
def _get_socks_by_all_comp( def _get_socks_by_all_comp(self, inst: inst_base.Instantiation) -> list[inst_socket.Socket]:
self, inst: inst_base.Instantiation
) -> list[inst_socket.Socket]:
sockets = [] sockets = []
for comp in self._components: for comp in self._components:
sockets.extend(self._get_socks_by_comp(inst=inst, comp=comp)) sockets.extend(self._get_socks_by_comp(inst=inst, comp=comp))
...@@ -281,24 +268,18 @@ class Simulator(utils_base.IdObj, abc.ABC): ...@@ -281,24 +268,18 @@ class Simulator(utils_base.IdObj, abc.ABC):
return [] return []
@abc.abstractmethod @abc.abstractmethod
def supported_socket_types( def supported_socket_types(self, interface: sys_conf.Interface) -> set[inst_socket.SockType]:
self, interface: sys_conf.Interface
) -> set[inst_socket.SockType]:
return set() return set()
# Sockets to be cleaned up: always the CONNECTING sockets # Sockets to be cleaned up: always the CONNECTING sockets
# pylint: disable=unused-argument # pylint: disable=unused-argument
def sockets_cleanup(self, inst: inst_base.Instantiation) -> list[inst_socket.Socket]: def sockets_cleanup(self, inst: inst_base.Instantiation) -> list[inst_socket.Socket]:
return self._get_all_sockets_by_type( return self._get_all_sockets_by_type(inst=inst, sock_type=inst_socket.SockType.LISTEN)
inst=inst, sock_type=inst_socket.SockType.LISTEN
)
# sockets to wait for indicating the simulator is ready # sockets to wait for indicating the simulator is ready
# pylint: disable=unused-argument # pylint: disable=unused-argument
def sockets_wait(self, inst: inst_base.Instantiation) -> list[inst_socket.Socket]: def sockets_wait(self, inst: inst_base.Instantiation) -> list[inst_socket.Socket]:
return self._get_all_sockets_by_type( return self._get_all_sockets_by_type(inst=inst, sock_type=inst_socket.SockType.LISTEN)
inst=inst, sock_type=inst_socket.SockType.LISTEN
)
def start_delay(self) -> int: def start_delay(self) -> int:
return 5 return 5
...@@ -467,6 +448,14 @@ class Simulation(utils_base.IdObj): ...@@ -467,6 +448,14 @@ class Simulation(utils_base.IdObj):
all_channels.extend(channels) all_channels.extend(channels)
return all_channels return all_channels
def enable_synchronization(
self, amount: int | None = None, ratio: utils_base.Time | None = None
) -> None:
for chan in self.get_all_channels():
chan._synchronized = True
if amount and ratio:
chan.set_sync_period(amount=amount, ratio=ratio)
def resreq_mem(self) -> int: def resreq_mem(self) -> int:
"""Memory required to run all simulators in this experiment.""" """Memory required to run all simulators in this experiment."""
mem = 0 mem = 0
...@@ -495,9 +484,6 @@ class Simulation(utils_base.IdObj): ...@@ -495,9 +484,6 @@ class Simulation(utils_base.IdObj):
await asyncio.gather(*promises) await asyncio.gather(*promises)
def any_supports_checkpointing(self) -> bool: def any_supports_checkpointing(self) -> bool:
if ( if len(list(filter(lambda sim: sim.supports_checkpointing(), self._sim_list))) > 0:
len(list(filter(lambda sim: sim.supports_checkpointing(), self._sim_list)))
> 0
):
return True return True
return False return False
...@@ -22,27 +22,24 @@ ...@@ -22,27 +22,24 @@
from __future__ import annotations from __future__ import annotations
import enum
from simbricks.orchestration.simulation import base as sim_base from simbricks.orchestration.simulation import base as sim_base
from simbricks.orchestration.system import base as system_base from simbricks.orchestration.system import base as system_base
from simbricks.utils import base as utils_base from simbricks.utils import base as utils_base
class Time(enum.IntEnum):
Picoseconds = 10 ** (-3)
Nanoseconds = 1
Microseconds = 10 ** (3)
Milliseconds = 10 ** (6)
Seconds = 10 ** (9)
class Channel(utils_base.IdObj): class Channel(utils_base.IdObj):
def __init__(self, chan: system_base.Channel): def __init__(self, chan: system_base.Channel):
super().__init__() super().__init__()
self._synchronized: bool = False self._synchronized: bool = False
self.sync_period: int = 500 # nano seconds self.sync_period: int = 500 # nanoseconds
"""
The synchronization period in nanoseconds. For SimBricks to function
properly in sync mode, the sync period must not be larger than a channels
latency.
"""
assert self.sync_period <= chan.latency
self.sys_channel: system_base.Channel = chan self.sys_channel: system_base.Channel = chan
def toJSON(self): def toJSON(self):
...@@ -57,12 +54,8 @@ class Channel(utils_base.IdObj): ...@@ -57,12 +54,8 @@ class Channel(utils_base.IdObj):
@classmethod @classmethod
def fromJSON(cls, simulation: sim_base.Simulation, json_obj: dict) -> Channel: def fromJSON(cls, simulation: sim_base.Simulation, json_obj: dict) -> Channel:
instance = super().fromJSON(json_obj) instance = super().fromJSON(json_obj)
instance._synchronized = bool( instance._synchronized = bool(utils_base.get_json_attr_top(json_obj, "synchronized"))
utils_base.get_json_attr_top(json_obj, "synchronized") instance.sync_period = int(utils_base.get_json_attr_top(json_obj, "sync_period"))
)
instance.sync_period = int(
utils_base.get_json_attr_top(json_obj, "sync_period")
)
chan_id = int(utils_base.get_json_attr_top(json_obj, "sys_channel")) chan_id = int(utils_base.get_json_attr_top(json_obj, "sys_channel"))
instance.sys_channel = simulation.system.get_chan(chan_id) instance.sys_channel = simulation.system.get_chan(chan_id)
return instance return instance
...@@ -70,6 +63,9 @@ class Channel(utils_base.IdObj): ...@@ -70,6 +63,9 @@ class Channel(utils_base.IdObj):
def full_name(self) -> str: def full_name(self) -> str:
return "channel." + self.name return "channel." + self.name
def set_sync_period(self, amount: int, ratio: Time = Time.Nanoseconds) -> None: def set_sync_period(
utils_base.has_expected_type(obj=ratio, expected_type=Time) self, amount: int, ratio: utils_base.Time = utils_base.Time.Nanoseconds
) -> None:
utils_base.has_expected_type(obj=ratio, expected_type=utils_base.Time)
self.sync_period = amount * ratio self.sync_period = amount * ratio
assert self.sync_period <= self.sys_channel.latency
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
from __future__ import annotations from __future__ import annotations
import abc
import typing as tp import typing as tp
from simbricks.utils import base as util_base from simbricks.utils import base as util_base
...@@ -72,6 +71,20 @@ class System(util_base.IdObj): ...@@ -72,6 +71,20 @@ class System(util_base.IdObj):
return self._all_channels[ident] return self._all_channels[ident]
@staticmethod
def set_latencies(channels: list[Channel], amount: int, ratio: util_base.Time) -> None:
for chan in channels:
chan.set_latency(amount, ratio)
def latencies(self, amount: int, ratio: util_base.Time, channel_type: tp.Any) -> None:
relevant_channels = list(
filter(
lambda chan: util_base.check_type(chan, channel_type),
self._all_channels.values(),
)
)
System.set_latencies(relevant_channels, amount, ratio)
def toJSON(self) -> dict: def toJSON(self) -> dict:
json_obj = super().toJSON() json_obj = super().toJSON()
json_obj["type"] = self.__class__.__name__ json_obj["type"] = self.__class__.__name__
...@@ -242,7 +255,7 @@ class Interface(util_base.IdObj): ...@@ -242,7 +255,7 @@ class Interface(util_base.IdObj):
class Channel(util_base.IdObj): class Channel(util_base.IdObj):
def __init__(self, a: Interface, b: Interface) -> None: def __init__(self, a: Interface, b: Interface) -> None:
super().__init__() super().__init__()
self.latency = 500 self.latency = 500 # nanoseconds
self.a: Interface = a self.a: Interface = a
self.a.connect(self) self.a.connect(self)
self.b: Interface = b self.b: Interface = b
...@@ -268,6 +281,10 @@ class Channel(util_base.IdObj): ...@@ -268,6 +281,10 @@ class Channel(util_base.IdObj):
assert opposing != interface assert opposing != interface
return opposing return opposing
def set_latency(self, amount: int, ratio: util_base.Time = util_base.Time.Nanoseconds) -> None:
util_base.has_expected_type(obj=ratio, expected_type=util_base.Time)
self.latency = amount * ratio
def toJSON(self) -> dict: def toJSON(self) -> dict:
json_obj = super().toJSON() json_obj = super().toJSON()
json_obj["type"] = self.__class__.__name__ json_obj["type"] = self.__class__.__name__
......
...@@ -24,6 +24,7 @@ import abc ...@@ -24,6 +24,7 @@ import abc
import itertools import itertools
import importlib import importlib
import typing as tp import typing as tp
import enum
class IdObj(abc.ABC): class IdObj(abc.ABC):
...@@ -49,6 +50,14 @@ class IdObj(abc.ABC): ...@@ -49,6 +50,14 @@ class IdObj(abc.ABC):
return instance return instance
class Time(enum.IntEnum):
Picoseconds = 10 ** (-3)
Nanoseconds = 1
Microseconds = 10 ** (3)
Milliseconds = 10 ** (6)
Seconds = 10 ** (9)
def filter_None_dict(to_filter: dict) -> dict: def filter_None_dict(to_filter: dict) -> dict:
res = {k: v for k, v in to_filter.items() if v is not None} res = {k: v for k, v in to_filter.items() if v is not None}
return res return res
...@@ -78,7 +87,9 @@ def check_types(obj, *expected_types) -> bool: ...@@ -78,7 +87,9 @@ def check_types(obj, *expected_types) -> bool:
def has_expected_type(obj, expected_type) -> None: def has_expected_type(obj, expected_type) -> None:
if not check_type(obj=obj, expected_type=expected_type): if not check_type(obj=obj, expected_type=expected_type):
raise Exception(f"obj of type {type(obj)} has not the type or is not a subtype of {expected_type}") raise Exception(
f"obj of type {type(obj)} has not the type or is not a subtype of {expected_type}"
)
def has_attribute(obj, attr: str) -> None: def has_attribute(obj, attr: str) -> None:
...@@ -118,9 +129,11 @@ def get_cls_by_json(json_obj: dict): ...@@ -118,9 +129,11 @@ def get_cls_by_json(json_obj: dict):
module_name = get_json_attr_top(json_obj, "module") module_name = get_json_attr_top(json_obj, "module")
return get_cls_from_type_module(type_name, module_name) return get_cls_from_type_module(type_name, module_name)
def _has_base_type(obj: tp.Any) -> bool: def _has_base_type(obj: tp.Any) -> bool:
return isinstance(obj, (str, int, float, bool, type(None))) return isinstance(obj, (str, int, float, bool, type(None)))
def _obj_to_json(obj: tp.Any) -> tp.Any: def _obj_to_json(obj: tp.Any) -> tp.Any:
if _has_base_type(obj): if _has_base_type(obj):
return obj return obj
...@@ -134,6 +147,7 @@ def _obj_to_json(obj: tp.Any) -> tp.Any: ...@@ -134,6 +147,7 @@ def _obj_to_json(obj: tp.Any) -> tp.Any:
has_attribute(obj, "toJSON") has_attribute(obj, "toJSON")
return obj.toJSON() return obj.toJSON()
def list_tuple_to_json(list: list | tuple) -> list: def list_tuple_to_json(list: list | tuple) -> list:
json_list = [] json_list = []
for element in list: for element in list:
...@@ -141,16 +155,18 @@ def list_tuple_to_json(list: list | tuple) -> list: ...@@ -141,16 +155,18 @@ def list_tuple_to_json(list: list | tuple) -> list:
return json_list return json_list
def dict_to_json(data: dict) -> dict: def dict_to_json(data: dict) -> dict:
json_obj = {} json_obj = {}
for key, value in data.items(): for key, value in data.items():
key_json = _obj_to_json(key) key_json = _obj_to_json(key)
value_json = _obj_to_json(value) value_json = _obj_to_json(value)
assert(key_json not in json_obj) assert key_json not in json_obj
json_obj[key_json] = value_json json_obj[key_json] = value_json
return json_obj return json_obj
def _json_obj_to_dict(obj: tp.Any) -> tp.Any: def _json_obj_to_dict(obj: tp.Any) -> tp.Any:
if _has_base_type(obj): if _has_base_type(obj):
return obj return obj
...@@ -161,6 +177,7 @@ def _json_obj_to_dict(obj: tp.Any) -> tp.Any: ...@@ -161,6 +177,7 @@ def _json_obj_to_dict(obj: tp.Any) -> tp.Any:
else: else:
raise ValueError(f"cannot parse object with type {type(obj)} from json") raise ValueError(f"cannot parse object with type {type(obj)} from json")
def _json_dict_to_obj(json_obj: dict) -> tp.Any: def _json_dict_to_obj(json_obj: dict) -> tp.Any:
if "type" in json_obj and "module" in json_obj: if "type" in json_obj and "module" in json_obj:
# this seems to be a Python object that was converted to JSON # this seems to be a Python object that was converted to JSON
...@@ -171,6 +188,7 @@ def _json_dict_to_obj(json_obj: dict) -> tp.Any: ...@@ -171,6 +188,7 @@ def _json_dict_to_obj(json_obj: dict) -> tp.Any:
# this seems to be a plain dict # this seems to be a plain dict
return json_to_dict(json_obj) return json_to_dict(json_obj)
def json_array_to_list(array: list) -> list: def json_array_to_list(array: list) -> list:
data = [] data = []
for element in array: for element in array:
...@@ -178,12 +196,13 @@ def json_array_to_list(array: list) -> list: ...@@ -178,12 +196,13 @@ def json_array_to_list(array: list) -> list:
return data return data
def json_to_dict(json_obj: dict) -> dict: def json_to_dict(json_obj: dict) -> dict:
data = {} data = {}
for key, value in json_obj.items(): for key, value in json_obj.items():
key_dict = _json_obj_to_dict(key) key_dict = _json_obj_to_dict(key)
value_dict = _json_obj_to_dict(value) value_dict = _json_obj_to_dict(value)
assert(key_dict not in data) assert key_dict not in data
data[key_dict] = value_dict data[key_dict] = value_dict
return data return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment