"vscode:/vscode.git/clone" did not exist on "216c5b5c7152cc3ccf37a678782531ec8a331d14"
Commit d5a8ffbe authored by one's avatar one
Browse files

[hytop] Add hytop net

parent 952a9533
...@@ -124,6 +124,39 @@ Designed to be script-friendly: ...@@ -124,6 +124,39 @@ Designed to be script-friendly:
If no `--show*` flags are specified, hytop defaults to:
`--showtemp --showpower --showsclk --showmemuse --showuse`.
### SSH transport tuning
`hytop` keeps the same lightweight SSH pull model and enables SSH connection reuse by default in the core layer (applies to all subcommands using SSH collection):
- `ControlMaster=auto`
- `ControlPersist=30s`
- `ControlPath=~/.ssh/hytop-%C`
- `ServerAliveInterval=5`
- `ServerAliveCountMax=1`
## `hytop net`
Lightweight pull-based network monitor for Ethernet and InfiniBand across one or more hosts.
### Usage
```bash
# Local host, auto-discover eth+ib interfaces
hytop net
# Two hosts, 0.5-second interval
hytop -H node01,node02 -n 0.5 net
# IB-only monitoring
hytop net --kind ib
# Include only selected interfaces
hytop net --ifaces eth0,mlx5_0/p1
# Stop after 60 seconds (returns 124 on timeout)
hytop --timeout 60 net
```
## Development
Clone the repo and run `make setup` to create the virtual environment, install all dependencies (including dev), and configure pre-commit hooks:
......
from __future__ import annotations from __future__ import annotations
import shlex
import subprocess import subprocess
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
...@@ -22,11 +23,59 @@ class CollectResult: ...@@ -22,11 +23,59 @@ class CollectResult:
error: str | None = None error: str | None = None
@dataclass(frozen=True)
class SSHOptions:
    """Immutable bundle of SSH transport tuning knobs.

    Connection multiplexing stays off unless ``use_mux`` is set explicitly.
    """

    # When true, ControlMaster/ControlPersist/ControlPath options are emitted.
    use_mux: bool = False
    # Rendered as ``ControlPersist=<n>s``.
    control_persist_seconds: int = 30
    # Rendered as ``ControlPath=<value>``; omitted when None or empty.
    control_path: str | None = "~/.ssh/hytop-%C"
    # Rendered as ``ServerAliveInterval=<n>``.
    server_alive_interval: int = 5
    # Rendered as ``ServerAliveCountMax=<n>``.
    server_alive_count_max: int = 1
# Transport settings applied when a caller passes no explicit SSHOptions:
# connection multiplexing is switched on, every other value is spelled out.
DEFAULT_SSH_OPTIONS = SSHOptions(
    server_alive_count_max=1,
    server_alive_interval=5,
    control_path="~/.ssh/hytop-%C",
    control_persist_seconds=30,
    use_mux=True,
)
def _build_ssh_option_args(ssh_timeout: float, ssh_options: SSHOptions | None) -> list[str]:
    """Translate a timeout and transport settings into ``-o key=value`` ssh arguments.

    Args:
        ssh_timeout: Connect timeout in seconds; rounded and never below 1.
        ssh_options: Transport tuning; ``DEFAULT_SSH_OPTIONS`` is used when falsy.

    Returns:
        A flat list alternating ``"-o"`` with one option assignment.
    """
    cfg = ssh_options if ssh_options else DEFAULT_SSH_OPTIONS
    settings = [
        "BatchMode=yes",
        f"ConnectTimeout={max(1, round(ssh_timeout))}",
        f"ServerAliveInterval={max(1, cfg.server_alive_interval)}",
        f"ServerAliveCountMax={max(1, cfg.server_alive_count_max)}",
    ]
    if cfg.use_mux:
        settings.append("ControlMaster=auto")
        settings.append(f"ControlPersist={max(1, cfg.control_persist_seconds)}s")
        # The control path is optional even when multiplexing is enabled.
        if cfg.control_path:
            settings.append(f"ControlPath={cfg.control_path}")

    args: list[str] = []
    for setting in settings:
        args += ["-o", setting]
    return args
def _build_remote_python_shell_command(python_code: str) -> str:
    """Return a ``python3 -c`` command line with the script quoted as one shell word."""
    quoted_script = shlex.quote(python_code)
    return " ".join(("python3", "-c", quoted_script))
def collect_from_host( def collect_from_host(
host: str, host: str,
ssh_timeout: float, ssh_timeout: float,
cmd_timeout: float, cmd_timeout: float,
hy_smi_args: Sequence[str], hy_smi_args: Sequence[str],
ssh_options: SSHOptions | None = None,
) -> CollectResult: ) -> CollectResult:
"""Run hy-smi locally or via SSH and return raw output. """Run hy-smi locally or via SSH and return raw output.
...@@ -42,13 +91,9 @@ def collect_from_host( ...@@ -42,13 +91,9 @@ def collect_from_host(
if host in local_names: if host in local_names:
cmd = ["hy-smi", *hy_smi_args] cmd = ["hy-smi", *hy_smi_args]
else: else:
connect_timeout = max(1, round(ssh_timeout))
cmd = [ cmd = [
"ssh", "ssh",
"-o", *_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options),
"BatchMode=yes",
"-o",
f"ConnectTimeout={connect_timeout}",
host, host,
"hy-smi", "hy-smi",
*hy_smi_args, *hy_smi_args,
...@@ -82,3 +127,69 @@ def collect_from_host( ...@@ -82,3 +127,69 @@ def collect_from_host(
) )
return CollectResult(host=host, stdout=proc.stdout, stderr=proc.stderr, error=None) return CollectResult(host=host, stdout=proc.stdout, stderr=proc.stderr, error=None)
def collect_python_from_host(
    host: str,
    ssh_timeout: float,
    cmd_timeout: float,
    python_code: str,
    ssh_options: SSHOptions | None = None,
) -> CollectResult:
    """Execute ``python_code`` with ``python3 -c`` locally or over SSH.

    Args:
        host: Target host; ``localhost``, ``127.0.0.1`` and ``::1`` run without SSH.
        ssh_timeout: Connect timeout handed to the ssh option builder.
        cmd_timeout: Wall-clock limit in seconds for the whole subprocess.
        python_code: Script source passed to ``python3 -c``.
        ssh_options: Optional transport tuning for the remote case.

    Returns:
        A ``CollectResult`` carrying the raw output; ``error`` is set on
        timeout, on failure to launch, or on a non-zero exit status.
    """
    if host in {"localhost", "127.0.0.1", "::1"}:
        argv = ["python3", "-c", python_code]
    else:
        argv = build_remote_python_command(
            host=host,
            ssh_timeout=ssh_timeout,
            python_code=python_code,
            ssh_options=ssh_options,
        )

    try:
        completed = subprocess.run(
            argv,
            check=False,
            capture_output=True,
            text=True,
            timeout=cmd_timeout,
        )
    except subprocess.TimeoutExpired:
        return CollectResult(
            host=host,
            stdout="",
            stderr="",
            error=f"timeout after {cmd_timeout:.1f}s",
        )
    except OSError as exc:
        return CollectResult(host=host, stdout="", stderr="", error=str(exc))

    failure = None
    if completed.returncode != 0:
        detail = completed.stderr.strip() or "unknown error"
        failure = f"exit {completed.returncode}: {detail}"
    return CollectResult(
        host=host,
        stdout=completed.stdout,
        stderr=completed.stderr,
        error=failure,
    )
def build_remote_python_command(
    host: str,
    ssh_timeout: float,
    python_code: str,
    ssh_options: SSHOptions | None = None,
) -> list[str]:
    """Assemble the ssh argv that runs ``python_code`` on ``host``.

    The script travels as the single trailing argument, already shell-quoted.
    """
    argv = ["ssh"]
    argv.extend(_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options))
    argv.append(host)
    argv.append(_build_remote_python_shell_command(python_code))
    return argv
...@@ -6,13 +6,15 @@ from hytop import __version__ ...@@ -6,13 +6,15 @@ from hytop import __version__
from hytop.core.validators import parse_csv_strings, parse_positive_float from hytop.core.validators import parse_csv_strings, parse_positive_float
from hytop.cpu.cli import app as cpu_app from hytop.cpu.cli import app as cpu_app
from hytop.gpu.cli import app as gpu_app from hytop.gpu.cli import app as gpu_app
from hytop.net.cli import app as net_app
app = typer.Typer( app = typer.Typer(
help="hytop toolkit command line", help="hytop toolkit command line",
context_settings={"help_option_names": ["-h", "--help"]}, context_settings={"help_option_names": ["-h", "--help"]},
) )
app.add_typer(cpu_app, name="cpu")
app.add_typer(gpu_app, name="gpu") app.add_typer(gpu_app, name="gpu")
app.add_typer(cpu_app, name="cpu")
app.add_typer(net_app, name="net")
def version_callback(value: bool) -> None: def version_callback(value: bool) -> None:
......
"""Network monitoring commands for hytop."""
from __future__ import annotations
import typer
from hytop.net.collector import parse_kind_filter, parse_name_filter
from hytop.net.service import run_monitor
# Typer sub-application holding the network monitoring command.
app = typer.Typer(
    help="Network monitoring commands.",
    add_completion=False,
    context_settings={"help_option_names": ["-h", "--help"]},
)
@app.callback(invoke_without_command=True)
def net(
    ctx: typer.Context,
    kind: str = typer.Option(
        "all",
        "--kind",
        help="Network kind to monitor: all, eth, ib.",
    ),
    ifaces: str = typer.Option(
        "",
        "--ifaces",
        help="Comma-separated interface names to include.",
    ),
) -> None:
    """Network monitoring commands."""
    # Global options (hosts, interval, window, timeout) are read from the
    # context object; without it nothing can be monitored.
    settings = ctx.obj
    if settings is None:
        typer.echo("argument error: global options not available", err=True)
        raise typer.Exit(code=2)
    if kind not in ("all", "eth", "ib"):
        typer.echo("argument error: --kind must be one of all/eth/ib", err=True)
        raise typer.Exit(code=2)

    hosts, interval, window_value = settings["hosts"], settings["interval"], settings["window"]
    timeout_value = settings.get("timeout")
    include_filter = parse_name_filter(ifaces)

    # The monitor's return value becomes the process exit status.
    exit_code = run_monitor(
        hosts=hosts,
        kind_filter=parse_kind_filter(kind),
        include=include_filter,
        window=window_value,
        interval=interval,
        timeout=timeout_value,
    )
    raise typer.Exit(code=exit_code)
from __future__ import annotations
import json
import re
import subprocess
from pathlib import Path
from typing import Final
from hytop.net.models import NetCounter, NetKind
# Interface-name prefixes that are always filtered out of the results.
VIRTUAL_PREFIXES: Final[tuple[str, ...]] = ("veth", "docker", "cni", "flannel", "virbr", "br-", "tunl")
# Self-contained collector script executed on remote hosts via ``python3 -c``.
# The caller must prepend an ``INPUT_JSON = '<json>'`` assignment; the script
# prints one compact JSON object ``{"counters": [...]}`` on stdout.
#
# Fix: the InfiniBand include filter now accepts a port when EITHER its device
# name (``mlx5_0/p1``) or its mapped netdev is listed. Previously ``want`` was
# applied to the mapped netdev first, so filtering by device name dropped
# every mapped port and the later ``pname in include`` branch was unreachable.
REMOTE_COLLECTOR_PY: Final[str] = r"""
import json
import re
import subprocess
from pathlib import Path

vprefix = ("veth", "docker", "cni", "flannel", "virbr", "br-", "tunl")


def read_int(path):
    try:
        return int(path.read_text().strip())
    except (OSError, ValueError):
        return None


def read_text(path):
    try:
        return path.read_text().strip()
    except OSError:
        return None


def normalize_link_layer(value):
    if not value:
        return "ib"
    text = value.strip().lower()
    return "eth" if text.startswith("ethernet") else "ib"


def parse_ibdev2netdev_output(stdout):
    mapping = {}
    pattern = re.compile(
        r"^(?P<dev>\S+)\s+port\s+(?P<port>\d+)\s+==>\s+(?P<netdev>\S+)\s+\((?P<state>[^)]+)\)$"
    )
    for line in stdout.splitlines():
        m = pattern.match(line.strip())
        if not m:
            continue
        dev = m.group("dev")
        port = m.group("port")
        netdev = m.group("netdev")
        state = m.group("state").strip().lower()
        mapping[(dev, port)] = {"netdev": netdev, "state": state}
    return mapping


def get_ibdev2netdev_mapping():
    try:
        proc = subprocess.run(
            ["ibdev2netdev"],
            check=False,
            capture_output=True,
            text=True,
            timeout=2.0,
        )
    except (OSError, subprocess.TimeoutExpired):
        return {}
    if proc.returncode != 0:
        return {}
    return parse_ibdev2netdev_output(proc.stdout)


def is_ib_netdev(iface_path):
    type_text = read_text(iface_path / "type")
    return type_text == "32"


def want(name, include):
    if include and name not in include:
        return False
    if name == "lo" or name.startswith(vprefix):
        return False
    return True


def keep_ib(pname, mapped_netdev, include):
    # A port is selectable by device name or by mapped netdev name.
    names = {pname}
    if mapped_netdev:
        names.add(mapped_netdev)
    if include and not (names & include):
        return False
    if mapped_netdev == "lo":
        return False
    return not (mapped_netdev and mapped_netdev.startswith(vprefix))


def collect(kind_filter, include):
    out = []
    ib_map = get_ibdev2netdev_mapping()
    ib_netdevs = {item["netdev"] for item in ib_map.values()}
    if "eth" in kind_filter:
        for p in Path("/sys/class/net").iterdir():
            if not p.is_dir():
                continue
            name = p.name
            if name in ib_netdevs:
                continue
            if is_ib_netdev(p):
                continue
            if not want(name, include):
                continue
            rx = read_int(p / "statistics" / "rx_bytes")
            tx = read_int(p / "statistics" / "tx_bytes")
            if rx is None or tx is None:
                continue
            out.append(
                {
                    "kind": "eth",
                    "name": name,
                    "rx_bytes": rx,
                    "tx_bytes": tx,
                    "link_state": (read_text(p / "operstate") or "unknown").lower(),
                }
            )
    if ("ib" in kind_filter) or ("eth" in kind_filter):
        base = Path("/sys/class/infiniband")
        if base.exists():
            for dev in base.iterdir():
                ports = dev / "ports"
                if not ports.is_dir():
                    continue
                for port in ports.iterdir():
                    pname = f"{dev.name}/p{port.name}"
                    mapped = ib_map.get((dev.name, port.name))
                    mapped_netdev = mapped["netdev"] if mapped is not None else None
                    mode = normalize_link_layer(read_text(port / "link_layer"))
                    if mode not in kind_filter:
                        continue
                    if not keep_ib(pname, mapped_netdev, include):
                        continue
                    candidate_name = mapped_netdev or pname
                    rxw = read_int(port / "counters" / "port_rcv_data")
                    txw = read_int(port / "counters" / "port_xmit_data")
                    if rxw is None or txw is None:
                        continue
                    out.append(
                        {
                            "kind": mode,
                            "name": candidate_name,
                            "rx_bytes": rxw * 4,
                            "tx_bytes": txw * 4,
                            "link_state": (
                                mapped["state"]
                                if mapped is not None
                                else (read_text(port / "state") or "unknown")
                            ),
                            "device_name": pname,
                        }
                    )
    return out


payload = json.loads(INPUT_JSON)
res = {
    "counters": collect(
        kind_filter=set(payload.get("kind_filter", ["eth", "ib"])),
        include=set(payload.get("include", [])),
    )
}
print(json.dumps(res, separators=(",", ":")))
""".strip()
def parse_kind_filter(kind: str) -> set[NetKind]:
    """Expand the ``--kind`` value into the set of kinds to collect.

    ``"all"`` selects both ``eth`` and ``ib``; any other value is returned
    as a single-element set without validation.
    """
    selected = {"eth", "ib"} if kind == "all" else {kind}
    return selected  # type: ignore[return-value]
def parse_name_filter(raw: str) -> set[str]:
    """Split a comma-separated list into a set of trimmed, non-empty names."""
    names: set[str] = set()
    for piece in raw.split(","):
        cleaned = piece.strip()
        if cleaned:
            names.add(cleaned)
    return names
def should_keep_iface(
    name: str,
    include: set[str],
) -> bool:
    """Decide whether a network interface should be reported.

    A non-empty ``include`` acts as an allow-list. The loopback interface and
    names starting with one of ``VIRTUAL_PREFIXES`` are always rejected.
    """
    rejected_by_filter = bool(include) and name not in include
    if rejected_by_filter or name == "lo":
        return False
    return not name.startswith(VIRTUAL_PREFIXES)
def _safe_read_int(path: Path) -> int | None:
    """Read ``path`` and parse its stripped text as an int; ``None`` on any failure."""
    try:
        text = path.read_text()
        return int(text.strip())
    except (OSError, ValueError):
        return None
def _safe_read_text(path: Path) -> str | None:
    """Return the stripped contents of ``path``, or ``None`` when it cannot be read."""
    try:
        contents = path.read_text()
    except OSError:
        return None
    return contents.strip()
def _is_infiniband_netdev(iface: Path) -> bool:
    """Tell whether the interface's ``type`` file reads exactly ``32``."""
    link_type = _safe_read_text(iface / "type")
    return link_type == "32"
def _normalize_ib_port_state(raw: str | None) -> str:
    """Reduce a port state string to a lowercase label.

    A ``"<number>: <LABEL>"`` form (for example ``"4: ACTIVE"``) keeps only the
    part after the first colon. Missing or empty input yields ``"unknown"``.
    """
    if not raw:
        return "unknown"
    label = raw.partition(":")[2] if ":" in raw else raw
    return label.strip().lower()
def _normalize_link_layer(raw: str | None) -> NetKind:
    """Classify a ``link_layer`` value as ``"eth"`` or ``"ib"``.

    Values starting with ``ethernet`` (case-insensitive, after trimming) map
    to ``"eth"``; everything else, including missing input, maps to ``"ib"``.
    """
    if not raw:
        return "ib"
    if raw.strip().lower().startswith("ethernet"):
        return "eth"
    return "ib"
def _should_keep_ib_iface(
    port_name: str,
    mapped_netdev: str | None,
    include: set[str],
) -> bool:
    """Decide whether an InfiniBand port should be reported.

    With a non-empty ``include``, the port passes when either its device name
    or its mapped netdev is listed. Ports mapped to the loopback interface or
    to a netdev starting with one of ``VIRTUAL_PREFIXES`` are rejected.
    """
    if include:
        candidates = {port_name, mapped_netdev} if mapped_netdev else {port_name}
        if not (candidates & include):
            return False
    if mapped_netdev == "lo":
        return False
    if mapped_netdev and mapped_netdev.startswith(VIRTUAL_PREFIXES):
        return False
    return True
def _parse_ibdev2netdev_output(stdout: str) -> dict[tuple[str, str], tuple[str, str]]:
    """Parse `ibdev2netdev` output into (dev,port)->(netdev,state).

    Lines that do not match the ``<dev> port <n> ==> <netdev> (<state>)``
    shape are skipped; the state is lower-cased.
    """
    line_re = re.compile(
        r"^(?P<dev>\S+)\s+port\s+(?P<port>\d+)\s+==>\s+(?P<netdev>\S+)\s+\((?P<state>[^)]+)\)$"
    )
    matches = (line_re.match(line.strip()) for line in stdout.splitlines())
    return {
        (found.group("dev"), found.group("port")): (
            found.group("netdev"),
            found.group("state").strip().lower(),
        )
        for found in matches
        if found
    }
def _get_ibdev2netdev_mapping() -> dict[tuple[str, str], tuple[str, str]]:
    """Best-effort ibdev2netdev mapping for classification/state labeling.

    Returns an empty mapping when the tool cannot be started, does not finish
    within two seconds, or exits with a non-zero status.
    """
    try:
        completed = subprocess.run(
            ["ibdev2netdev"],
            check=False,
            capture_output=True,
            text=True,
            timeout=2.0,
        )
    except (OSError, subprocess.TimeoutExpired):
        return {}
    return _parse_ibdev2netdev_output(completed.stdout) if completed.returncode == 0 else {}
def collect_local_counters(
    kind_filter: set[NetKind],
    include: set[str],
) -> dict[str, NetCounter]:
    """Read network byte counters for the local host from sysfs.

    Args:
        kind_filter: Kinds to report (``"eth"`` and/or ``"ib"``).
        include: Optional allow-list of interface or device names; empty
            means no name filtering.

    Returns:
        Counters keyed by ``NetCounter.key``. Interfaces whose counters cannot
        be read are left out. A missing sysfs tree yields no entries instead
        of raising.
    """
    counters: dict[str, NetCounter] = {}
    ib_mapping = _get_ibdev2netdev_mapping()
    ib_netdevs = {netdev for netdev, _state in ib_mapping.values()}

    net_root = Path("/sys/class/net")
    # Guarded like the InfiniBand tree below: iterating a missing directory
    # would raise and take the whole collection down.
    if "eth" in kind_filter and net_root.is_dir():
        for iface in net_root.iterdir():
            if not iface.is_dir():
                continue
            name = iface.name
            # Netdevs backed by an IB device are reported from the port loop.
            if name in ib_netdevs or _is_infiniband_netdev(iface):
                continue
            if not should_keep_iface(name, include):
                continue
            rx = _safe_read_int(iface / "statistics" / "rx_bytes")
            tx = _safe_read_int(iface / "statistics" / "tx_bytes")
            if rx is None or tx is None:
                continue
            counter = NetCounter(
                kind="eth",
                name=name,
                rx_bytes=rx,
                tx_bytes=tx,
                link_state=(_safe_read_text(iface / "operstate") or "unknown").lower(),
            )
            counters[counter.key] = counter

    ib_root = Path("/sys/class/infiniband")
    wants_ports = ("ib" in kind_filter) or ("eth" in kind_filter)
    if not wants_ports or not ib_root.is_dir():
        return counters

    for hca in ib_root.iterdir():
        ports_dir = hca / "ports"
        if not ports_dir.is_dir():
            continue
        for port in ports_dir.iterdir():
            name = f"{hca.name}/p{port.name}"
            mapped = ib_mapping.get((hca.name, port.name))
            mapped_netdev = mapped[0] if mapped is not None else None
            # The port's link layer decides whether it counts as eth or ib.
            mode = _normalize_link_layer(_safe_read_text(port / "link_layer"))
            if mode not in kind_filter:
                continue
            if not _should_keep_ib_iface(
                port_name=name,
                mapped_netdev=mapped_netdev,
                include=include,
            ):
                continue
            rx_words = _safe_read_int(port / "counters" / "port_rcv_data")
            tx_words = _safe_read_int(port / "counters" / "port_xmit_data")
            if rx_words is None or tx_words is None:
                continue
            counter = NetCounter(
                kind=mode,
                name=mapped_netdev or name,
                # Port data counters are scaled by 4 to obtain bytes.
                rx_bytes=rx_words * 4,
                tx_bytes=tx_words * 4,
                link_state=(
                    mapped[1]
                    if mapped is not None
                    else _normalize_ib_port_state(_safe_read_text(port / "state"))
                ),
                device_name=name,
            )
            counters[counter.key] = counter
    return counters
def parse_counter_payload(payload: str) -> dict[str, NetCounter]:
    """Decode the JSON emitted by the remote collector into counters.

    Args:
        payload: Raw stdout of the collector script.

    Returns:
        Counters keyed by ``NetCounter.key``. Malformed entries (non-dict
        items, unknown kind, non-string name, non-numeric byte counts) are
        skipped silently.

    Raises:
        ValueError: If the payload is not valid JSON, is not a JSON object,
            or has no ``counters`` list.
    """
    try:
        data = json.loads(payload)
    except json.JSONDecodeError as exc:
        raise ValueError(f"invalid collector payload: {exc}") from exc

    # Valid JSON that is not an object (``[]``, ``null``, ``3``) has no
    # ``.get``; treat it as a missing counters list instead of letting an
    # AttributeError escape.
    raw_counters = data.get("counters") if isinstance(data, dict) else None
    if not isinstance(raw_counters, list):
        raise ValueError("invalid collector payload: missing counters list")

    counters: dict[str, NetCounter] = {}
    for item in raw_counters:
        if not isinstance(item, dict):
            continue
        kind = item.get("kind")
        name = item.get("name")
        link_state = item.get("link_state")
        device_name = item.get("device_name")
        if kind not in {"eth", "ib"}:
            continue
        if not isinstance(name, str):
            continue
        try:
            rx_value = int(item.get("rx_bytes"))
            tx_value = int(item.get("tx_bytes"))
        except (TypeError, ValueError):
            continue

        normalized_state: str | None
        if isinstance(link_state, str):
            # IB states may arrive as "4: ACTIVE"; eth states are plain words.
            normalized_state = (
                _normalize_ib_port_state(link_state) if kind == "ib" else link_state.strip().lower()
            )
        else:
            normalized_state = None

        counter = NetCounter(
            kind=kind,
            name=name,
            rx_bytes=rx_value,
            tx_bytes=tx_value,
            link_state=normalized_state,
            device_name=device_name if isinstance(device_name, str) else None,
        )
        counters[counter.key] = counter
    return counters
from __future__ import annotations
import threading
from dataclasses import dataclass
from typing import Literal
from hytop.core.history import SlidingHistory
# Closed set of interface kinds used throughout the net package.
NetKind = Literal["eth", "ib"]
@dataclass(frozen=True)
class NetCounter:
    """Immutable byte-counter reading for a single interface."""

    kind: NetKind
    name: str
    # Cumulative byte totals at the time of the reading.
    rx_bytes: int
    tx_bytes: int
    link_state: str | None = None
    device_name: str | None = None

    @property
    def key(self) -> str:
        """Identifier of the form ``<kind>:<name>``."""
        return ":".join((self.kind, self.name))
@dataclass
class NodeCounterSnapshot:
    """Outcome of one collection pass against a single host."""

    host: str
    # Readings keyed by ``NetCounter.key``.
    counters: dict[str, NetCounter]
    # Timestamp attached to the readings; stays 0.0 when not provided.
    sample_ts: float = 0.0
    # Failure description; None when not set.
    error: str | None = None
@dataclass
class RateSample:
    """Single throughput data point for one host+iface key."""

    # Timestamp of the sample.
    ts: float
    # Receive and transmit rates (named ``*_bps``).
    rx_bps: float
    tx_bps: float
@dataclass
class HostSnapshot:
    """Per-host slot holding a sequence number and the latest result."""

    # Starts at 0.
    seq: int = 0
    # None until a result has been stored.
    result: NodeCounterSnapshot | None = None
@dataclass
class MonitorState:
    """State container for one monitoring session.

    Dictionaries keyed by a ``(str, str)`` tuple hold per-interface data.
    """

    max_window: float
    histories: dict[tuple[str, str], SlidingHistory]
    monitored_keys: set[tuple[str, str]]
    latest_counter_by_key: dict[tuple[str, str], NetCounter]
    previous_counter_by_key: dict[tuple[str, str], NetCounter]
    # Error message per host.
    errors: dict[str, str]
    # Per-host result slots.
    host_state: dict[str, HostSnapshot]
    # Sequence number per host.
    processed_seq: dict[str, int]
    state_lock: threading.Lock
    stop_event: threading.Event
from __future__ import annotations
import time
from rich import box
from rich.console import Group
from rich.table import Table
from hytop.core.history import SlidingHistory
from hytop.gpu.render import fmt_elapsed, fmt_window
from hytop.net.models import NetCounter, RateSample
def format_rate(value: float) -> str:
    """Render a byte rate with a decimal (1000-based) unit, capped at ``TB/s``.

    The number is formatted as a 7-wide field with two decimals.
    """
    scaled = float(value)
    unit = "B/s"
    for bigger_unit in ("KB/s", "MB/s", "GB/s", "TB/s"):
        if not scaled >= 1000.0:
            break
        scaled /= 1000.0
        unit = bigger_unit
    return f"{scaled:7.2f} {unit}"
def split_iface_key(iface_key: str) -> tuple[str, str]:
    """Split ``"<kind>:<name>"`` at the first colon into ``(kind, name)``.

    Without a colon the whole key is returned as ``kind`` with an empty name.
    """
    pieces = iface_key.split(":", 1)
    if len(pieces) == 1:
        return pieces[0], ""
    return pieces[0], pieces[1]
def format_iface_name(name: str, link_state: str | None) -> str:
    """Append ``" (down)"`` to ``name`` when the link state indicates a down link.

    States treated as down (case-insensitive, trimmed): anything starting
    with ``down``, plus ``disabled``, ``init`` and ``inactive``.
    """
    if not link_state:
        return name
    state = link_state.strip().lower()
    # startswith("down") also covers the exact value "down".
    if state.startswith("down") or state in {"disabled", "init", "inactive"}:
        return f"{name} (down)"
    return name
def build_renderable(
    window: float,
    hosts: list[str],
    histories: dict[tuple[str, str], SlidingHistory],
    monitored_keys: set[tuple[str, str]],
    latest_counter_by_key: dict[tuple[str, str], NetCounter],
    errors: dict[str, str],
    poll_interval: float,
    elapsed_since_start: float,
) -> Group:
    """Build the renderable for the net monitor: a rate table plus optional errors.

    Rows follow the order of ``hosts`` and then the interface key. A row whose
    newest sample is older than ``window`` shows dashes instead of rates.
    When ``errors`` is non-empty a second table lists the per-host messages.
    """
    now = time.monotonic()
    rank_of = {host: position for position, host in enumerate(hosts)}
    unranked = len(hosts)

    def sort_key(entry: tuple[str, str]) -> tuple[int, str]:
        return rank_of.get(entry[0], unranked), entry[1]

    title = (
        f"hytop net | interval={poll_interval:.2f}s | elapsed={fmt_elapsed(elapsed_since_start)}"
    )
    table = Table(title=title, box=box.MINIMAL_HEAVY_HEAD, expand=True)
    table.add_column("Host", justify="left", no_wrap=True)
    for heading in ("Mode", "Device", "NIC"):
        table.add_column(heading, justify="left")
    for heading in ("RX", "TX", f"RX@{fmt_window(window)}", f"TX@{fmt_window(window)}"):
        table.add_column(heading, justify="right")

    for key in sorted(monitored_keys, key=sort_key):
        history = histories.get(key)
        counter = latest_counter_by_key.get(key)
        if history is None or counter is None:
            continue
        sample = history.latest()
        # Also rejects a missing (None) sample.
        if not isinstance(sample, RateSample):
            continue

        host, iface_key = key
        mode, name = split_iface_key(iface_key)
        identity = (
            host,
            mode,
            counter.device_name or name,
            format_iface_name(name, counter.link_state),
        )
        if (now - sample.ts) > window:
            table.add_row(*identity, "-", "-", "-", "-")
            continue
        rx_avg = history.avg("rx_bps", window, now)
        tx_avg = history.avg("tx_bps", window, now)
        table.add_row(
            *identity,
            format_rate(sample.rx_bps),
            format_rate(sample.tx_bps),
            format_rate(rx_avg),
            format_rate(tx_avg),
        )

    if table.row_count == 0:
        table.add_row("No data yet.", "", "", "", "", "", "", "")
    if not errors:
        return Group(table)

    err_table = Table(title="Host errors", box=box.MINIMAL_HEAVY_HEAD, expand=True)
    err_table.add_column("Host", justify="left", no_wrap=True)
    err_table.add_column("Error", justify="left")
    for host in hosts:
        message = errors.get(host)
        if message:
            err_table.add_row(host, message)
    return Group(table, err_table)
from __future__ import annotations
import json
import sys
import threading
import time
from rich.console import Console
from rich.live import Live
from hytop.core.history import SlidingHistory
from hytop.core.ssh import SSHOptions, collect_python_from_host
from hytop.net.collector import REMOTE_COLLECTOR_PY, collect_local_counters, parse_counter_payload
from hytop.net.models import HostSnapshot, MonitorState, NetKind, NodeCounterSnapshot, RateSample
from hytop.net.render import build_renderable
# Host names that are collected in-process instead of over SSH.
LOCAL_HOSTS = {"::1", "127.0.0.1", "localhost"}
def _build_remote_python_script(payload: dict[str, object]) -> str:
    """Prefix the collector script with an ``INPUT_JSON`` assignment holding ``payload``.

    The JSON text is embedded through ``repr`` so it forms a valid Python
    string literal on the first line of the script.
    """
    encoded = json.dumps(payload)
    header = f"INPUT_JSON={encoded!r}"
    return "\n".join((header, REMOTE_COLLECTOR_PY))
def _collect_remote_counters(
    host: str,
    ssh_timeout: float,
    cmd_timeout: float,
    kind_filter: set[NetKind],
    include: set[str],
    ssh_options: SSHOptions | None = None,
) -> NodeCounterSnapshot:
    """Run the collector script on ``host`` and parse what it prints.

    Transport failures and unparsable output are returned as a snapshot with
    ``error`` set and no counters, never raised.
    """
    request = {
        "kind_filter": sorted(kind_filter),
        "include": sorted(include),
    }
    raw = collect_python_from_host(
        host=host,
        ssh_timeout=ssh_timeout,
        cmd_timeout=cmd_timeout,
        python_code=_build_remote_python_script(request),
        ssh_options=ssh_options,
    )
    if raw.error:
        return NodeCounterSnapshot(host=host, counters={}, error=raw.error)

    try:
        counters = parse_counter_payload(raw.stdout)
    except ValueError as exc:
        return NodeCounterSnapshot(host=host, counters={}, error=str(exc))
    else:
        return NodeCounterSnapshot(host=host, counters=counters, sample_ts=time.monotonic())
def collect_node(
    host: str,
    ssh_timeout: float,
    cmd_timeout: float,
    kind_filter: set[NetKind],
    include: set[str],
    ssh_options: SSHOptions | None = None,
) -> NodeCounterSnapshot:
    """Collect counters for one host, in-process when it is a local name.

    Args:
        host: Host to sample; names in ``LOCAL_HOSTS`` skip SSH entirely.
        ssh_timeout: Connect timeout for the remote path.
        cmd_timeout: Command timeout for the remote path.
        kind_filter: Kinds to report.
        include: Optional allow-list of interface names.
        ssh_options: Optional transport tuning for the remote path.

    Returns:
        A snapshot with counters and a monotonic timestamp, or with ``error``
        set when collection failed.
    """
    if host not in LOCAL_HOSTS:
        return _collect_remote_counters(
            host=host,
            ssh_timeout=ssh_timeout,
            cmd_timeout=cmd_timeout,
            kind_filter=kind_filter,
            include=include,
            ssh_options=ssh_options,
        )

    try:
        counters = collect_local_counters(
            kind_filter=kind_filter,
            include=include,
        )
    except OSError as exc:
        # Mirror the remote path: surface filesystem failures as a host error
        # instead of letting them terminate the calling collector thread.
        return NodeCounterSnapshot(host=host, counters={}, error=str(exc))
    return NodeCounterSnapshot(host=host, counters=counters, sample_ts=time.monotonic())
def host_collector_loop(
    host: str,
    ssh_timeout: float,
    cmd_timeout: float,
    interval: float,
    kind_filter: set[NetKind],
    include: set[str],
    ssh_options: SSHOptions | None,
    state: dict[str, HostSnapshot],
    state_lock: threading.Lock,
    stop_event: threading.Event,
) -> None:
    """Sample ``host`` repeatedly until ``stop_event`` is set.

    Each pass stores its result in ``state[host]`` under ``state_lock`` and
    bumps the slot's sequence number. The wait between passes is the interval
    minus the time the pass took, never negative.
    """
    while not stop_event.is_set():
        cycle_start = time.monotonic()
        outcome = collect_node(
            host=host,
            ssh_timeout=ssh_timeout,
            cmd_timeout=cmd_timeout,
            kind_filter=kind_filter,
            include=include,
            ssh_options=ssh_options,
        )
        with state_lock:
            slot = state[host]
            slot.seq += 1
            slot.result = outcome
        remaining = interval - (time.monotonic() - cycle_start)
        # wait() returns True as soon as a stop is requested.
        if stop_event.wait(max(0.0, remaining)):
            break
def init_monitor_state(hosts: list[str], max_window: float) -> MonitorState:
    """Create an empty ``MonitorState`` with one slot and a zero sequence per host."""
    slots = {host: HostSnapshot() for host in hosts}
    seen = dict.fromkeys(hosts, 0)
    return MonitorState(
        max_window=max_window,
        histories={},
        monitored_keys=set(),
        latest_counter_by_key={},
        previous_counter_by_key={},
        errors={},
        host_state=slots,
        processed_seq=seen,
        state_lock=threading.Lock(),
        stop_event=threading.Event(),
    )
def start_collectors(
    hosts: list[str],
    ssh_timeout: float,
    cmd_timeout: float,
    interval: float,
    kind_filter: set[NetKind],
    include: set[str],
    ssh_options: SSHOptions | None,
    state: MonitorState,
) -> list[threading.Thread]:
    """Start one daemon thread per host running ``host_collector_loop``.

    Returns:
        The started threads, in the order of ``hosts``.
    """

    def launch(host: str) -> threading.Thread:
        thread = threading.Thread(
            target=host_collector_loop,
            args=(
                host,
                ssh_timeout,
                cmd_timeout,
                interval,
                kind_filter,
                include,
                ssh_options,
                state.host_state,
                state.state_lock,
                state.stop_event,
            ),
            daemon=True,
            name=f"net-collector-{host}",
        )
        thread.start()
        return thread

    return [launch(host) for host in hosts]
def drain_pending_nodes(hosts: list[str], state: MonitorState) -> list[NodeCounterSnapshot]:
    """Return the results whose sequence number advanced since the last drain.

    The processed sequence of every advanced host is updated under the state
    lock; slots without a stored result contribute nothing to the list.
    """
    ready: list[NodeCounterSnapshot] = []
    with state.state_lock:
        for host in hosts:
            slot = state.host_state[host]
            if slot.seq > state.processed_seq[host]:
                state.processed_seq[host] = slot.seq
                if slot.result is not None:
                    ready.append(slot.result)
    return ready
def apply_node_results(
    nodes: list[NodeCounterSnapshot], interval: float, state: MonitorState
) -> None:
    """Fold fresh per-host snapshots into the monitor state.

    For every counter with a previous reading, the byte delta divided by
    ``interval`` is appended to that interface's history. The first reading of
    an interface and readings with a negative delta (counter reset) only
    re-establish the baseline.

    A failed snapshot records the host error and discards that host's
    baselines. Otherwise the next successful delta would span the whole outage
    while still being divided by a single interval, inflating the rate.

    Args:
        nodes: Snapshots drained since the last call.
        interval: Nominal polling interval in seconds, used as the divisor.
        state: Monitor state, mutated in place.
    """
    for node in nodes:
        if node.error:
            state.errors[node.host] = node.error
            stale = [key for key in state.previous_counter_by_key if key[0] == node.host]
            for key in stale:
                del state.previous_counter_by_key[key]
            continue

        state.errors.pop(node.host, None)
        for iface_key, counter in node.counters.items():
            key = (node.host, iface_key)
            state.monitored_keys.add(key)
            state.latest_counter_by_key[key] = counter
            prev = state.previous_counter_by_key.get(key)
            state.previous_counter_by_key[key] = counter
            if prev is None:
                continue
            delta_rx = counter.rx_bytes - prev.rx_bytes
            delta_tx = counter.tx_bytes - prev.tx_bytes
            # A negative delta means the counter was reset; skip this sample.
            if delta_rx < 0 or delta_tx < 0:
                continue
            history = state.histories.get(key)
            if history is None:
                history = SlidingHistory(max_window_s=state.max_window)
                state.histories[key] = history
            # NOTE(review): the divisor is the nominal interval, not the time
            # actually elapsed between the two samples — confirm this is intended.
            history.add(
                RateSample(
                    ts=node.sample_ts,
                    rx_bps=delta_rx / interval,
                    tx_bps=delta_tx / interval,
                )
            )
def run_monitor(
    hosts: list[str],
    kind_filter: set[NetKind],
    include: set[str],
    window: float,
    interval: float,
    timeout: float | None,
    ssh_options: SSHOptions | None = None,
) -> int:
    """Run the live network monitor until timeout or interruption.

    Returns:
        ``2`` for invalid interval/window arguments, ``124`` once ``timeout``
        seconds have elapsed, ``130`` after a keyboard interrupt.
    """
    if interval <= 0:
        print("argument error: --interval must be > 0", file=sys.stderr)
        return 2
    if interval > window:
        print("argument error: --interval must be <= --window value", file=sys.stderr)
        return 2

    # Timeouts scale with the interval but stay within fixed bounds.
    ssh_timeout = min(max(5 * interval, 2.0), 5.0)
    cmd_timeout = min(max(10 * interval, 5.0), 10.0)
    console = Console()
    err_console = Console(stderr=True)
    started = time.monotonic()
    # Redraw at least twice a second, faster when polling faster.
    render_interval = min(interval, 0.5)
    state = init_monitor_state(hosts=hosts, max_window=window)

    def current_view():
        return build_renderable(
            window=window,
            hosts=hosts,
            histories=state.histories,
            monitored_keys=state.monitored_keys,
            latest_counter_by_key=state.latest_counter_by_key,
            errors=state.errors,
            poll_interval=interval,
            elapsed_since_start=time.monotonic() - started,
        )

    try:
        with Live(console=console, auto_refresh=False, screen=True) as live:
            workers = start_collectors(
                hosts=hosts,
                ssh_timeout=ssh_timeout,
                cmd_timeout=cmd_timeout,
                interval=interval,
                kind_filter=kind_filter,
                include=include,
                ssh_options=ssh_options,
                state=state,
            )
            try:
                while True:
                    tick = time.monotonic()
                    pending = drain_pending_nodes(hosts=hosts, state=state)
                    apply_node_results(nodes=pending, interval=interval, state=state)
                    live.update(current_view(), refresh=True)
                    if timeout is not None and (time.monotonic() - started) >= timeout:
                        return 124
                    spent = time.monotonic() - tick
                    time.sleep(max(0.0, render_interval - spent))
            finally:
                # Always ask the collectors to stop and give each a short
                # moment to finish.
                state.stop_event.set()
                for worker in workers:
                    worker.join(timeout=min(0.2, interval))
    except KeyboardInterrupt:
        err_console.print("status: interrupted by user", style="yellow")
        return 130
"""Tests for hytop.net.collector helpers."""
from __future__ import annotations
from hytop.net.collector import (
_normalize_ib_port_state,
_parse_ibdev2netdev_output,
parse_counter_payload,
parse_kind_filter,
parse_name_filter,
should_keep_iface,
)
class TestParseKindFilter:
    """``parse_kind_filter`` maps the CLI value onto a set of kinds."""

    def test_all(self):
        kinds = parse_kind_filter("all")
        assert kinds == {"eth", "ib"}

    def test_eth_only(self):
        kinds = parse_kind_filter("eth")
        assert kinds == {"eth"}

    def test_ib_only(self):
        kinds = parse_kind_filter("ib")
        assert kinds == {"ib"}
class TestParseNameFilter:
    """``parse_name_filter`` splits and trims comma-separated names."""

    def test_empty(self):
        names = parse_name_filter("")
        assert names == set()

    def test_trim_and_split(self):
        names = parse_name_filter(" eth0, ib0 ,eth1 ")
        assert names == {"eth0", "ib0", "eth1"}
class TestShouldKeepIface:
    """``should_keep_iface`` honours the allow-list and drops virtual devices."""

    def test_include_filter(self):
        allowed = {"eth0"}
        assert should_keep_iface("eth0", include=allowed)
        assert not should_keep_iface("eth1", include=allowed)

    def test_virtual_excluded_by_default(self):
        no_filter = set()
        assert not should_keep_iface("lo", include=no_filter)
        assert not should_keep_iface("docker0", include=no_filter)
class TestParseCounterPayload:
    """``parse_counter_payload`` decodes collector JSON and rejects bad input."""

    def test_parse_valid_payload(self):
        payload = "".join(
            [
                '{"counters":[',
                '{"kind":"eth","name":"eth0","rx_bytes":10,"tx_bytes":20,"link_state":"up"},',
                '{"kind":"ib","name":"mlx5_0/p1","rx_bytes":30,"tx_bytes":40,"link_state":"4: ACTIVE"}',
                "]}",
            ]
        )
        counters = parse_counter_payload(payload)
        eth, ib = counters["eth:eth0"], counters["ib:mlx5_0/p1"]
        assert eth.rx_bytes == 10
        assert ib.tx_bytes == 40
        assert eth.link_state == "up"
        # The numeric prefix of the IB state is stripped during parsing.
        assert ib.link_state == "active"

    def test_invalid_json_raises(self):
        message = None
        try:
            parse_counter_payload("{")
        except ValueError as exc:
            message = str(exc)
        if message is None:
            raise AssertionError("expected ValueError")
        assert "invalid collector payload" in message
class TestNormalizeIbState:
    """``_normalize_ib_port_state`` lower-cases and strips the numeric prefix."""

    def test_state_with_numeric_prefix(self):
        normalized = _normalize_ib_port_state("1: DOWN")
        assert normalized == "down"

    def test_state_without_prefix(self):
        normalized = _normalize_ib_port_state("ACTIVE")
        assert normalized == "active"
class TestParseIbdev2netdevOutput:
    """``_parse_ibdev2netdev_output`` maps (device, port) to (netdev, state)."""

    def test_parse_mapping_lines(self):
        lines = [
            "mlx5_0 port 1 ==> p6p1 (Down)",
            "mlx5_3 port 1 ==> ibs61f0 (Up)",
        ]
        parsed = _parse_ibdev2netdev_output("\n".join(lines) + "\n")
        assert parsed[("mlx5_0", "1")] == ("p6p1", "down")
        assert parsed[("mlx5_3", "1")] == ("ibs61f0", "up")
"""Tests for hytop.net.render helpers."""
from __future__ import annotations
from hytop.net.render import format_iface_name
class TestFormatIfaceName:
    """``format_iface_name`` tags interfaces whose link is not up."""

    def test_keep_name_when_up(self):
        label = format_iface_name("eth0", "up")
        assert label == "eth0"

    def test_mark_down_state(self):
        label = format_iface_name("eth0", "down")
        assert label == "eth0 (down)"

    def test_mark_init_state(self):
        label = format_iface_name("mlx5_0/p1", "init")
        assert label == "mlx5_0/p1 (down)"
"""Tests for hytop.net.service core logic."""
from __future__ import annotations
from hytop.net.models import NetCounter, NodeCounterSnapshot
from hytop.net.service import apply_node_results, init_monitor_state
class TestApplyNodeResults:
    """``apply_node_results`` turns consecutive counter readings into rates."""

    KEY = ("localhost", "eth:eth0")

    @staticmethod
    def _snapshot(rx: int, tx: int, ts: float) -> NodeCounterSnapshot:
        # One-interface snapshot for localhost/eth0.
        reading = NetCounter(kind="eth", name="eth0", rx_bytes=rx, tx_bytes=tx)
        return NodeCounterSnapshot(
            host="localhost",
            counters={"eth:eth0": reading},
            sample_ts=ts,
        )

    def test_second_sample_produces_rate(self):
        state = init_monitor_state(hosts=["localhost"], max_window=10.0)
        for snapshot in (self._snapshot(100, 200, 1.0), self._snapshot(300, 500, 2.0)):
            apply_node_results([snapshot], interval=1.0, state=state)
        latest = state.histories[self.KEY].latest()
        assert latest is not None
        assert latest.rx_bps == 200.0
        assert latest.tx_bps == 300.0

    def test_counter_reset_is_skipped(self):
        state = init_monitor_state(hosts=["localhost"], max_window=10.0)
        for snapshot in (self._snapshot(100, 100, 1.0), self._snapshot(50, 50, 2.0)):
            apply_node_results([snapshot], interval=1.0, state=state)
        assert self.KEY not in state.histories
...@@ -5,7 +5,7 @@ from __future__ import annotations ...@@ -5,7 +5,7 @@ from __future__ import annotations
import subprocess import subprocess
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from hytop.core.ssh import collect_from_host from hytop.core.ssh import SSHOptions, collect_from_host, collect_python_from_host
def _make_proc(returncode=0, stdout="", stderr=""): def _make_proc(returncode=0, stdout="", stderr=""):
...@@ -104,3 +104,42 @@ class TestCollectFromHostRemote: ...@@ -104,3 +104,42 @@ class TestCollectFromHostRemote:
cmd = mock_run.call_args[0][0] cmd = mock_run.call_args[0][0]
assert "--json" in cmd assert "--json" in cmd
assert "--showtemp" in cmd assert "--showtemp" in cmd
class TestCollectPythonFromHost:
    """``collect_python_from_host`` builds the expected ssh command line."""

    @patch("hytop.core.ssh.subprocess.run")
    def test_remote_uses_ssh(self, mock_run):
        mock_run.return_value = _make_proc(stdout='{"ok":1}')
        collect_python_from_host("node01", ssh_timeout=5, cmd_timeout=10, python_code="print('ok')")
        argv = mock_run.call_args[0][0]
        assert argv[0] == "ssh"
        assert any(str(part).startswith("python3 -c ") for part in argv)

    @patch("hytop.core.ssh.subprocess.run")
    def test_mux_options_applied(self, mock_run):
        mock_run.return_value = _make_proc(stdout='{"ok":1}')
        collect_python_from_host(
            "node01",
            ssh_timeout=5,
            cmd_timeout=10,
            python_code="print('ok')",
            ssh_options=SSHOptions(use_mux=True),
        )
        argv = mock_run.call_args[0][0]
        assert "ControlMaster=auto" in argv
        assert any(str(item).startswith("ControlPersist=") for item in argv)

    @patch("hytop.core.ssh.subprocess.run")
    def test_multiline_script_is_single_remote_arg(self, mock_run):
        mock_run.return_value = _make_proc(stdout='{"ok":1}')
        collect_python_from_host(
            "node01",
            ssh_timeout=5,
            cmd_timeout=10,
            python_code="import json\nprint(json.dumps({'ok': True}))",
        )
        argv = mock_run.call_args[0][0]
        # The whole script must travel as exactly one argv entry.
        script_args = [
            part for part in argv if isinstance(part, str) and part.startswith("python3 -c ")
        ]
        assert len(script_args) == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment