Unverified Commit 306650af authored by one's avatar one Committed by GitHub
Browse files

[hytop] Migrate rich to textual (#1)

* [hytop] Migrate from rich to textual

* [hytop] Fix ssh contention, add sorting bindings

* [hytop] Prevent clearing tables
parent 153a2737
...@@ -10,7 +10,7 @@ readme = "README.md" ...@@ -10,7 +10,7 @@ readme = "README.md"
license = { text = "MIT" } license = { text = "MIT" }
authors = [{ name = "alephpiece", email = "wangan.cs@gmail.com" }] authors = [{ name = "alephpiece", email = "wangan.cs@gmail.com" }]
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = ["rich>=14", "typer>=0.23"] dependencies = ["textual>=8.0.0", "typer>=0.23"]
keywords = ["monitoring", "gpu", "dcu", "hygon", "hytop"] keywords = ["monitoring", "gpu", "dcu", "hygon", "hytop"]
classifiers = [ classifiers = [
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
......
from __future__ import annotations
from collections.abc import Callable
from typing import TypeVar
T = TypeVar("T")
K = TypeVar("K")
def sort_with_missing_last(
    base_keys: list[K],
    value_getter: Callable[[K], T | None],
    desc: bool,
) -> list[K]:
    """Return *base_keys* ordered by their values, with missing values last.

    Keys whose getter yields ``None`` are appended after the sorted keys,
    preserving their original relative order (the sort is stable).
    """
    valued: list[tuple[T, K]] = []
    absent: list[K] = []
    for key in base_keys:
        got = value_getter(key)
        if got is None:
            absent.append(key)
        else:
            valued.append((got, key))
    valued.sort(key=lambda pair: pair[0], reverse=desc)
    return [key for _, key in valued] + absent
def next_sort_field_index(current: int, total_fields: int) -> int:
    """Advance to the next sort-field index, wrapping within [0, total_fields).

    A non-positive *total_fields* always yields 0.
    """
    return (current + 1) % total_fields if total_fields > 0 else 0
...@@ -47,6 +47,8 @@ def _build_ssh_option_args(ssh_timeout: float, ssh_options: SSHOptions | None) - ...@@ -47,6 +47,8 @@ def _build_ssh_option_args(ssh_timeout: float, ssh_options: SSHOptions | None) -
effective = ssh_options or DEFAULT_SSH_OPTIONS effective = ssh_options or DEFAULT_SSH_OPTIONS
connect_timeout = max(1, round(ssh_timeout)) connect_timeout = max(1, round(ssh_timeout))
options = [ options = [
"-n",
"-T",
"-o", "-o",
"BatchMode=yes", "BatchMode=yes",
"-o", "-o",
...@@ -105,6 +107,7 @@ def collect_from_host( ...@@ -105,6 +107,7 @@ def collect_from_host(
check=False, check=False,
capture_output=True, capture_output=True,
text=True, text=True,
stdin=subprocess.DEVNULL,
timeout=cmd_timeout, timeout=cmd_timeout,
) )
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
...@@ -155,6 +158,7 @@ def collect_python_from_host( ...@@ -155,6 +158,7 @@ def collect_python_from_host(
check=False, check=False,
capture_output=True, capture_output=True,
text=True, text=True,
stdin=subprocess.DEVNULL,
timeout=cmd_timeout, timeout=cmd_timeout,
) )
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
......
from __future__ import annotations
import time
from collections.abc import Sequence
from typing import ClassVar
from textual.app import App, ComposeResult
from textual.containers import Vertical
from textual.widgets import DataTable, Footer, Header, Label
from hytop.core.format import fmt_elapsed, fmt_window
from hytop.core.sorting import next_sort_field_index, sort_with_missing_last
from hytop.gpu.metrics import RenderColumn, render_columns_for_show_flags
from hytop.gpu.models import MonitorState
from hytop.gpu.service import (
apply_node_results,
availability_ready,
drain_pending_nodes,
)
from hytop.gpu.sort import sort_gpu_keys_grouped
def _format_metric(metric: str, value: object) -> str:
"""Format a GPU metric value for display."""
if value is None:
return "-"
if metric == "temp_c":
return f"{float(value):.1f}C"
if metric == "avg_pwr_w":
return f"{float(value):.1f}W"
if metric in {"vram_pct", "gpu_pct"}:
return f"{float(value):.2f}%"
if metric == "sclk_mhz":
return f"{float(value):.0f}MHz"
return str(value)
class GpuMonitorApp(App[int]):
    """Textual TUI application for real-time GPU monitoring via hy-smi.

    This replaces the previous Rich Live + Table rendering loop. The GPU data
    is collected by background threads managed by the caller; this App simply
    polls the shared MonitorState on a timer and updates the DataTable in place.

    The app returns an integer exit code via ``App.exit(return_code)`` so that
    the calling ``run_monitor`` function can propagate the correct process exit
    code to the shell.
    """

    # Minimal layout: the table fills the remaining space; the error label
    # stays hidden (via the .hidden class) until a host error is reported.
    CSS = """
    Screen {
        layout: vertical;
    }
    #gpu-table {
        height: 1fr;
        border: none;
    }
    #error-label {
        height: auto;
        padding: 0 1;
        color: $warning;
    }
    #error-label.hidden {
        display: none;
    }
    """

    BINDINGS: ClassVar[list] = [
        ("q", "quit_with_code", "Quit"),
        ("ctrl+c", "quit_with_code", "Quit"),
        ("s", "next_sort_field", "Next Sort Field"),
        ("S", "toggle_sort_order", "Toggle Sort Order"),
        ("g", "group_sort", "Grouped Sort"),
    ]

    # (field name, display label) pairs cycled through by the "s" binding.
    SORT_FIELDS: ClassVar[list[tuple[str, str]]] = [
        ("host", "Host"),
        ("gpu", "GPU"),
        ("temp_c", "Temp"),
        ("avg_pwr_w", "Power"),
        ("vram_pct", "VRAM"),
        ("gpu_pct", "GPU%"),
        ("sclk_mhz", "SCLK"),
    ]

    def __init__(
        self,
        hosts: list[str],
        show_flags: Sequence[str],
        window: float,
        interval: float,
        wait_idle: bool,
        wait_idle_duration: float,
        timeout: float | None,
        state: MonitorState,
        device_filter: set[int] | None,
        started: float,
        **kwargs: object,
    ) -> None:
        """Store the monitor configuration and derive the column layout.

        Args:
            hosts: Host list to monitor, in display order.
            show_flags: Ordered show flags controlling displayed metric columns.
            window: Rolling-average window length in seconds.
            interval: Sampling interval in seconds.
            wait_idle: Exit with code 0 once all monitored GPUs look idle.
            wait_idle_duration: Warm-up time before the idle check may fire.
            timeout: Optional global timeout in seconds (exits with code 124).
            state: Shared state mutated by the collector threads.
            device_filter: Optional GPU id filter; None monitors all devices.
            started: Monotonic timestamp taken when monitoring started.
            **kwargs: Forwarded to ``textual.app.App``.
        """
        super().__init__(**kwargs)
        self._hosts = hosts
        self._show_flags = list(show_flags)
        self._window = window
        self._interval = interval
        self._wait_idle = wait_idle
        self._wait_idle_duration = wait_idle_duration
        self._timeout = timeout
        self._state = state
        self._device_filter = device_filter
        self._started = started
        self._columns: list[RenderColumn] = render_columns_for_show_flags(show_flags)
        # Map (host, gpu) -> DataTable row key for efficient cell updates.
        self._row_keys: dict[tuple[str, int], str] = {}
        # Last displayed key order; when unchanged we use update_cell to preserve scroll/cursor.
        self._ordered_keys: list[tuple[str, int]] = []
        # Column keys for metric cells (host, gpu, then metric columns).
        self._cell_column_keys: list[str] = ["host", "gpu"]
        for col in self._columns:
            self._cell_column_keys.append(col.metric)
            if col.avg_label is not None:
                self._cell_column_keys.append(f"{col.metric}_avg")
        # Sorting state: "grouped" (host -> gpu) or "metric" (a SORT_FIELDS entry).
        self._sort_mode: str = "grouped"
        self._sort_desc: bool = False
        self._sort_field_index: int = 0

    # ------------------------------------------------------------------
    # Composition
    # ------------------------------------------------------------------
    def compose(self) -> ComposeResult:
        """Yield the widget tree: header, data table + error label, footer."""
        yield Header()
        with Vertical():
            yield DataTable(id="gpu-table", cursor_type="row")
            yield Label("", id="error-label")
        yield Footer()

    def on_mount(self) -> None:
        """Initialise columns and start the polling timer."""
        table: DataTable = self.query_one("#gpu-table", DataTable)
        table.add_column("Host", key="host")
        table.add_column("GPU", key="gpu")
        for col in self._columns:
            table.add_column(col.label, key=col.metric)
            if col.avg_label is not None:
                avg_key = f"{col.metric}_avg"
                # Rolling-average companion column labelled with the window size.
                table.add_column(
                    f"{col.avg_label}@{fmt_window(self._window)}",
                    key=avg_key,
                )
        # Refresh at most twice per second even for long sampling intervals.
        render_interval = min(self._interval, 0.5)
        self.set_interval(render_interval, self._tick)

    # ------------------------------------------------------------------
    # Actions
    # ------------------------------------------------------------------
    def action_quit_with_code(self) -> None:
        """Quit the monitor with a user-interrupt exit code."""
        self.exit(130)

    def action_next_sort_field(self) -> None:
        """Cycle sortable fields and switch to metric sorting mode."""
        if self._sort_mode == "grouped":
            self._sort_mode = "metric"
        # Note: the index also advances when leaving grouped mode, so the
        # first press lands on the field after the previously selected one.
        self._sort_field_index = next_sort_field_index(
            self._sort_field_index, len(self.SORT_FIELDS)
        )

    def action_toggle_sort_order(self) -> None:
        """Toggle ascending / descending metric sort order."""
        self._sort_desc = not self._sort_desc

    def action_group_sort(self) -> None:
        """Reset to grouped default ordering."""
        self._sort_mode = "grouped"
        self._sort_desc = False

    def _sort_status_text(self) -> str:
        """Describe the active sort mode for the title bar."""
        if self._sort_mode == "grouped":
            return "sort=grouped(host->gpu)"
        _, label = self.SORT_FIELDS[self._sort_field_index]
        order = "desc" if self._sort_desc else "asc"
        return f"sort={label.lower()} {order}"

    # ------------------------------------------------------------------
    # Internal timer callback
    # ------------------------------------------------------------------
    def _tick(self) -> None:
        """Drain pending collector results and refresh the DataTable."""
        state = self._state
        # Process any new data from collector threads.
        apply_node_results(
            nodes=drain_pending_nodes(hosts=self._hosts, state=state),
            device_filter=self._device_filter,
            state=state,
        )
        # Update the effective monitored key set when no filter is applied.
        if self._device_filter is None:
            state.monitored_keys = state.discovered_keys.copy()
        now = time.monotonic()
        elapsed = now - self._started
        # ---- Update header subtitle with current runtime ----
        self.title = (
            f"hytop gpu | interval={self._interval:.2f}s"
            f" | window={fmt_window(self._window)}"
            f" | elapsed={fmt_elapsed(elapsed)}"
            f" | {self._sort_status_text()}"
        )
        # ---- Rebuild table rows ----
        key_list = sort_gpu_keys_grouped(state.monitored_keys, self._hosts)
        if self._sort_mode == "metric":
            field, _ = self.SORT_FIELDS[self._sort_field_index]

            # Resolve the selected sort field for one key; None means "no
            # data yet" and sorts last via sort_with_missing_last.
            def _metric_value(key: tuple[str, int]) -> float | int | str | None:
                if field == "host":
                    return key[0]
                if field == "gpu":
                    return key[1]
                history = state.histories.get(key)
                if history is None:
                    return None
                latest = history.latest()
                if latest is None:
                    return None
                if field in {"temp_c", "avg_pwr_w", "vram_pct", "gpu_pct", "sclk_mhz"}:
                    value = getattr(latest, field, None)
                    return float(value) if value is not None else None
                return None

            key_list = sort_with_missing_last(key_list, _metric_value, self._sort_desc)
        table: DataTable = self.query_one("#gpu-table", DataTable)
        # Build keys_to_display: keys we actually have data for.
        keys_to_display: list[tuple[str, int]] = []
        for key in key_list:
            history = state.histories.get(key)
            if history is None:
                continue
            latest = history.latest()
            if latest is None:
                continue
            keys_to_display.append(key)

        # Produce one formatted cell value per column for a given key;
        # stale samples (older than the window) render as dashes.
        def _cell_values(key: tuple[str, int]) -> list[str]:
            host, gpu = key
            history = state.histories.get(key)
            latest = history.latest() if history else None
            if history is None or latest is None:
                return []
            stale = (now - latest.ts) > self._window
            if stale:
                return [host, str(gpu)] + ["-"] * (len(self._cell_column_keys) - 2)
            values: list[str] = []
            for col in self._columns:
                metric_value = getattr(latest, col.metric, None)
                values.append(_format_metric(col.metric, metric_value))
                if col.avg_label is not None:
                    if metric_value is None:
                        values.append("-")
                    else:
                        values.append(
                            _format_metric(
                                col.metric,
                                history.avg(col.metric, self._window, now),
                            )
                        )
            return [host, str(gpu), *values]

        if keys_to_display == self._ordered_keys:
            # Same order: incremental update to preserve scroll position and cursor.
            for key in keys_to_display:
                row_key = self._row_keys.get(key)
                if row_key is None:
                    continue
                values = _cell_values(key)
                if len(values) != len(self._cell_column_keys):
                    continue
                for col_key, val in zip(self._cell_column_keys, values, strict=True):
                    table.update_cell(row_key, col_key, val)
        else:
            # Order or set changed: full rebuild.
            table.clear(columns=False)
            self._row_keys.clear()
            self._ordered_keys = keys_to_display.copy()
            for key in keys_to_display:
                host, gpu = key
                values = _cell_values(key)
                if not values:
                    continue
                row_key = f"{host}:{gpu}"
                table.add_row(*values, key=row_key)
                self._row_keys[key] = row_key
        # ---- Update error label ----
        error_label: Label = self.query_one("#error-label", Label)
        if state.errors:
            parts = [f"{h}: {err}" for h, err in state.errors.items()]
            error_label.update(" ".join(parts))
            error_label.remove_class("hidden")
        else:
            error_label.update("")
            error_label.add_class("hidden")
        # ---- Check wait-idle exit condition ----
        if self._wait_idle:
            warmup_done = elapsed >= self._wait_idle_duration
            if warmup_done and availability_ready(
                window=self._window,
                histories=state.histories,
                monitored_keys=state.monitored_keys,
                hosts=self._hosts,
                errors=state.errors,
            ):
                self.exit(0)
        # ---- Check global timeout ----
        if self._timeout is not None and elapsed >= self._timeout:
            self.exit(124)
...@@ -45,7 +45,7 @@ def gpu( ...@@ -45,7 +45,7 @@ def gpu(
wait_idle: bool = typer.Option( wait_idle: bool = typer.Option(
False, False,
"--wait-idle", "--wait-idle",
help="Exit 0 when all monitored GPUs have zero VRAM/HCU avg.", help="Exit 0 when all monitored GPUs have zero VRAM/GPU avg.",
), ),
wait_idle_seconds: float = typer.Option( wait_idle_seconds: float = typer.Option(
10.0, 10.0,
......
...@@ -57,8 +57,8 @@ SHOW_SPECS: Final[tuple[ShowSpec, ...]] = ( ...@@ -57,8 +57,8 @@ SHOW_SPECS: Final[tuple[ShowSpec, ...]] = (
), ),
ShowSpec( ShowSpec(
flag="showuse", flag="showuse",
metric_json_keys={"hcu_pct": "HCU use (%)"}, metric_json_keys={"gpu_pct": "HCU use (%)"},
columns=(RenderColumn(label="GPU%", metric="hcu_pct", avg_label="GPU%"),), columns=(RenderColumn(label="GPU%", metric="gpu_pct", avg_label="GPU%"),),
cli_help="Display GPU utilization.", cli_help="Display GPU utilization.",
), ),
) )
......
...@@ -15,7 +15,7 @@ class Sample: ...@@ -15,7 +15,7 @@ class Sample:
temp_c: GPU core temperature in Celsius. temp_c: GPU core temperature in Celsius.
avg_pwr_w: Average power draw in Watts. avg_pwr_w: Average power draw in Watts.
vram_pct: VRAM usage percentage. vram_pct: VRAM usage percentage.
hcu_pct: HCU usage percentage. gpu_pct: GPU usage percentage.
sclk_mhz: sclk frequency in MHz. sclk_mhz: sclk frequency in MHz.
""" """
...@@ -23,7 +23,7 @@ class Sample: ...@@ -23,7 +23,7 @@ class Sample:
temp_c: float | None = None temp_c: float | None = None
avg_pwr_w: float | None = None avg_pwr_w: float | None = None
vram_pct: float | None = None vram_pct: float | None = None
hcu_pct: float | None = None gpu_pct: float | None = None
sclk_mhz: float | None = None sclk_mhz: float | None = None
......
from __future__ import annotations
import time
from collections.abc import Iterable
from rich import box
from rich.console import Group
from rich.table import Table
from hytop.core.format import fmt_elapsed, fmt_window
from hytop.core.history import SlidingHistory
from hytop.gpu.metrics import render_columns_for_show_flags
def build_renderable(
    window: float,
    hosts: list[str],
    histories: dict[tuple[str, int], SlidingHistory],
    monitored_keys: Iterable[tuple[str, int]],
    errors: dict[str, str],
    show_flags: Iterable[str],
    poll_interval: float,
    elapsed_since_start: float,
) -> Group:
    """Build the Rich renderable for the current monitor frame.

    Args:
        window: Rolling window length in seconds.
        hosts: Host order used for row ordering and error table.
        histories: Sliding histories by host+gpu key.
        monitored_keys: Effective host+gpu keys to show.
        errors: Latest host-level errors.
        show_flags: Ordered show flags controlling displayed metric columns.
        poll_interval: Configured sampling interval in seconds.
        elapsed_since_start: Total runtime in seconds.

    Returns:
        A Group containing main data table and optional error table.
    """
    now = time.monotonic()
    # Order rows by host position first, then GPU index; unknown hosts last.
    host_rank = {host: idx for idx, host in enumerate(hosts)}
    key_list = sorted(monitored_keys, key=lambda x: (host_rank.get(x[0], len(hosts)), x[1]))
    title = (
        f"hytop gpu | interval={poll_interval:.2f}s | elapsed={fmt_elapsed(elapsed_since_start)}"
    )
    table = Table(
        title=title,
        box=box.MINIMAL_HEAVY_HEAD,
        expand=True,
    )
    table.add_column("Host", justify="left", no_wrap=True)
    table.add_column("GPU", justify="right")
    columns = render_columns_for_show_flags(show_flags)
    for col in columns:
        table.add_column(col.label, justify="right")
        if col.avg_label is not None:
            # Rolling-average companion column labelled with the window size.
            table.add_column(f"{col.avg_label}@{fmt_window(window)}", justify="right")
    for key in key_list:
        history = histories.get(key)
        if history is None:
            continue
        latest = history.latest()
        if latest is None:
            continue
        host, gpu = key
        # Samples older than the window render as dashes rather than stale data.
        stale = (now - latest.ts) > window
        if stale:
            table.add_row(host, str(gpu), *["-"] * (len(table.columns) - 2))
            continue
        values: list[str] = []
        for col in columns:
            metric_value = getattr(latest, col.metric, None)
            values.append(_format_metric(col.metric, metric_value))
            if col.avg_label is not None:
                if metric_value is None:
                    values.append("-")
                else:
                    values.append(_format_metric(col.metric, history.avg(col.metric, window, now)))
        table.add_row(
            host,
            str(gpu),
            *values,
        )
    if table.row_count == 0:
        # Keep a visible placeholder row until the first data arrives.
        table.add_row("No data yet.", *[""] * (len(table.columns) - 1))
    if not errors:
        return Group(table)
    err_table = Table(title="Host errors", box=box.MINIMAL_HEAVY_HEAD, expand=True)
    err_table.add_column("Host", justify="left", no_wrap=True)
    err_table.add_column("Error", justify="left")
    for host in hosts:
        err = errors.get(host)
        if err:
            err_table.add_row(host, err)
    return Group(table, err_table)
def _format_metric(metric: str, value: object) -> str:
if value is None:
return "-"
if metric == "temp_c":
return f"{float(value):7.1f}C"
if metric == "avg_pwr_w":
return f"{float(value):8.1f}W"
if metric in {"vram_pct", "hcu_pct"}:
return f"{float(value):7.2f}%"
if metric == "sclk_mhz":
return f"{float(value):7.0f}MHz"
return str(value)
...@@ -5,15 +5,11 @@ import threading ...@@ -5,15 +5,11 @@ import threading
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from rich.console import Console
from rich.live import Live
from hytop.core.history import SlidingHistory from hytop.core.history import SlidingHistory
from hytop.core.ssh import collect_from_host from hytop.core.ssh import collect_from_host
from hytop.gpu.metrics import hy_smi_args_for_show_flags from hytop.gpu.metrics import hy_smi_args_for_show_flags
from hytop.gpu.models import HostSnapshot, MonitorState, NodeResult from hytop.gpu.models import HostSnapshot, MonitorState, NodeResult
from hytop.gpu.parser import parse_hy_smi_output from hytop.gpu.parser import parse_hy_smi_output
from hytop.gpu.render import build_renderable
def collect_node( def collect_node(
...@@ -112,11 +108,11 @@ def availability_ready( ...@@ -112,11 +108,11 @@ def availability_ready(
latest = history.latest() latest = history.latest()
if latest is None or (now - latest.ts) > window: if latest is None or (now - latest.ts) > window:
return False return False
if latest.vram_pct is None or latest.hcu_pct is None: if latest.vram_pct is None or latest.gpu_pct is None:
return False return False
if history.avg("vram_pct", window, now) != 0.0: if history.avg("vram_pct", window, now) != 0.0:
return False return False
if history.avg("hcu_pct", window, now) != 0.0: if history.avg("gpu_pct", window, now) != 0.0:
return False return False
return True return True
...@@ -262,11 +258,16 @@ def run_monitor( ...@@ -262,11 +258,16 @@ def run_monitor(
timeout: float | None, timeout: float | None,
wait_idle_duration: float = 10.0, wait_idle_duration: float = 10.0,
) -> int: ) -> int:
"""Run the asynchronous collector + periodic renderer monitor loop. """Run the GPU monitor as a Textual TUI application.
Starts per-host collector threads, then launches the Textual App which
polls the shared MonitorState and refreshes the DataTable on a timer.
The app returns an integer exit code via ``App.exit(code)``.
Args: Args:
hosts: Host list to monitor. hosts: Host list to monitor.
device_filter: Optional GPU id filter. device_filter: Optional GPU id filter.
show_flags: Ordered list of show flags controlling displayed columns.
window: Rolling window length in seconds. window: Rolling window length in seconds.
interval: Sampling interval in seconds. interval: Sampling interval in seconds.
wait_idle: Whether to exit when all monitored GPUs become idle. wait_idle: Whether to exit when all monitored GPUs become idle.
...@@ -281,8 +282,8 @@ def run_monitor( ...@@ -281,8 +282,8 @@ def run_monitor(
130 when interrupted by user. 130 when interrupted by user.
""" """
console = Console() # Import here to avoid a circular import at module load time.
err_console = Console(stderr=True) from hytop.gpu.app import GpuMonitorApp
if interval <= 0: if interval <= 0:
print("argument error: --interval must be > 0", file=sys.stderr) print("argument error: --interval must be > 0", file=sys.stderr)
...@@ -295,11 +296,8 @@ def run_monitor( ...@@ -295,11 +296,8 @@ def run_monitor(
hy_smi_args = hy_smi_args_for_show_flags(show_flags, wait_idle=wait_idle) hy_smi_args = hy_smi_args_for_show_flags(show_flags, wait_idle=wait_idle)
ssh_timeout = min(max(5 * interval, 2.0), 5.0) ssh_timeout = min(max(5 * interval, 2.0), 5.0)
cmd_timeout = min(max(10 * interval, 5.0), 10.0) cmd_timeout = min(max(10 * interval, 5.0), 10.0)
render_interval = min(interval, 0.5)
started = time.monotonic() started = time.monotonic()
try:
with Live(console=console, auto_refresh=False, screen=True) as live:
workers = start_collectors( workers = start_collectors(
hosts=hosts, hosts=hosts,
ssh_timeout=ssh_timeout, ssh_timeout=ssh_timeout,
...@@ -308,55 +306,27 @@ def run_monitor( ...@@ -308,55 +306,27 @@ def run_monitor(
interval=interval, interval=interval,
state=state, state=state,
) )
try:
while True: app = GpuMonitorApp(
loop_started = time.monotonic()
apply_node_results(
nodes=drain_pending_nodes(hosts=hosts, state=state),
device_filter=device_filter,
state=state,
)
if device_filter is None:
state.monitored_keys = state.discovered_keys.copy()
live.update(
build_renderable(
window=window,
hosts=hosts, hosts=hosts,
histories=state.histories,
monitored_keys=state.monitored_keys,
errors=state.errors,
show_flags=show_flags, show_flags=show_flags,
poll_interval=interval,
elapsed_since_start=time.monotonic() - started,
),
refresh=True,
)
elapsed_since_start = time.monotonic() - started
warmup_done = elapsed_since_start >= wait_idle_duration
if (
wait_idle
and warmup_done
and availability_ready(
window=window, window=window,
histories=state.histories, interval=interval,
monitored_keys=state.monitored_keys, wait_idle=wait_idle,
hosts=hosts, wait_idle_duration=wait_idle_duration,
errors=state.errors, timeout=timeout,
) state=state,
): device_filter=device_filter,
console.print("status: success (all monitored GPUs are available)") started=started,
return 0
if wait_idle and timeout is not None and elapsed_since_start >= timeout:
err_console.print(
"status: timeout (availability condition not met)",
style="yellow",
) )
return 124
time.sleep(max(0.0, render_interval - (time.monotonic() - loop_started))) try:
app.run()
except KeyboardInterrupt:
pass
finally: finally:
state.stop_event.set() state.stop_event.set()
for worker in workers: for worker in workers:
worker.join(timeout=min(0.2, interval)) worker.join(timeout=min(0.2, interval))
except KeyboardInterrupt:
err_console.print("status: interrupted by user", style="yellow") return app.return_code or 0
return 130
from __future__ import annotations
def sort_gpu_keys_grouped(
    monitored_keys: set[tuple[str, int]],
    hosts: list[str],
) -> list[tuple[str, int]]:
    """Default grouped ordering: host (in given host order), then GPU index.

    Hosts not present in *hosts* sort after all known hosts.
    """
    rank_of = {name: pos for pos, name in enumerate(hosts)}
    unknown_rank = len(hosts)

    def _order(key: tuple[str, int]) -> tuple[int, int]:
        host, gpu = key
        return rank_of.get(host, unknown_rank), gpu

    return sorted(monitored_keys, key=_order)
from __future__ import annotations
import time
from typing import ClassVar
from textual.app import App, ComposeResult
from textual.containers import Vertical
from textual.widgets import DataTable, Footer, Header, Label
from hytop.core.format import fmt_elapsed, fmt_window
from hytop.core.sorting import next_sort_field_index, sort_with_missing_last
from hytop.net.models import MonitorState
from hytop.net.service import apply_node_results, drain_pending_nodes
from hytop.net.sort import sort_net_keys_grouped, split_iface_key
def format_rate(value: float, iec: bool = False) -> str:
    """Humanize a byte rate, e.g. 1500 -> "1.50 kB/s" (KiB/s when *iec*)."""
    if iec:
        base, units = 1024.0, ("B/s", "KiB/s", "MiB/s", "GiB/s", "TiB/s")
    else:
        base, units = 1000.0, ("B/s", "kB/s", "MB/s", "GB/s", "TB/s")
    scaled = float(value)
    for position, unit in enumerate(units):
        # Stop scaling at the first fitting unit, or at the largest one.
        if scaled < base or position == len(units) - 1:
            return f"{scaled:.2f} {unit}"
        scaled /= base
    return f"{scaled:.2f} {units[-1]}"  # unreachable; keeps the return total
def format_iface_name(name: str, link_state: str | None) -> str:
    """Return *name*, suffixed with " (down)" when the link state is not up.

    A missing/empty state is treated as up; matching is case-insensitive.
    """
    if not link_state:
        return name
    state = link_state.strip().lower()
    # startswith("down") subsumes the exact "down" case from the old set check.
    looks_down = state.startswith("down") or state in {"disabled", "init", "inactive"}
    return f"{name} (down)" if looks_down else name
class NetMonitorApp(App[int]):
    """Textual TUI application for real-time network interface monitoring.

    Replaces the previous Rich Live + Table rendering loop. Network data
    is collected by background threads; this App polls the shared MonitorState
    on a timer and updates the DataTable in place, enabling native scrolling.
    """

    # Minimal layout: the table fills the remaining space; the error label
    # stays hidden (via the .hidden class) until a host error is reported.
    CSS = """
    Screen {
        layout: vertical;
    }
    #net-table {
        height: 1fr;
        border: none;
    }
    #error-label {
        height: auto;
        padding: 0 1;
        color: $warning;
    }
    #error-label.hidden {
        display: none;
    }
    """

    BINDINGS: ClassVar[list] = [
        ("q", "quit_with_code", "Quit"),
        ("ctrl+c", "quit_with_code", "Quit"),
        ("s", "next_sort_field", "Next Sort Field"),
        ("S", "toggle_sort_order", "Toggle Sort Order"),
        ("g", "group_sort", "Grouped Sort"),
    ]

    # (field name, display label) pairs cycled through by the "s" binding.
    SORT_FIELDS: ClassVar[list[tuple[str, str]]] = [
        ("host", "Host"),
        ("mode", "Mode"),
        ("nic", "NIC"),
        ("rx_bps", "RX"),
        ("tx_bps", "TX"),
        ("rx_avg", "RX@window"),
        ("tx_avg", "TX@window"),
    ]

    def __init__(
        self,
        hosts: list[str],
        window: float,
        interval: float,
        timeout: float | None,
        state: MonitorState,
        started: float,
        iec: bool = False,
        **kwargs: object,
    ) -> None:
        """Store the monitor configuration and the fixed column layout.

        Args:
            hosts: Host list to monitor, in display order.
            window: Rolling-average window length in seconds.
            interval: Sampling interval in seconds.
            timeout: Optional global timeout in seconds (exits with code 124).
            state: Shared state mutated by the collector threads.
            started: Monotonic timestamp taken when monitoring started.
            iec: Use IEC (KiB/s, base 1024) instead of SI (kB/s, base 1000).
            **kwargs: Forwarded to ``textual.app.App``.
        """
        super().__init__(**kwargs)
        self._hosts = hosts
        self._window = window
        self._interval = interval
        self._timeout = timeout
        self._state = state
        self._started = started
        self._iec = iec
        # Map (host, iface_key) -> DataTable row key for in-place updates.
        self._row_keys: dict[tuple[str, str], str] = {}
        # Last displayed key order; when unchanged we use update_cell to preserve scroll/cursor.
        self._ordered_keys: list[tuple[str, str]] = []
        # Fixed column key order; must match _cell_values output order.
        self._cell_column_keys: list[str] = [
            "host",
            "mode",
            "device",
            "nic",
            "rx",
            "tx",
            "rx_avg",
            "tx_avg",
        ]
        # Sorting state: "grouped" (host -> mode -> nic) or "metric".
        self._sort_mode: str = "grouped"
        self._sort_desc: bool = False
        self._sort_field_index: int = 0

    # ------------------------------------------------------------------
    # Composition
    # ------------------------------------------------------------------
    def compose(self) -> ComposeResult:
        """Yield the widget tree: header, data table + error label, footer."""
        yield Header()
        with Vertical():
            yield DataTable(id="net-table", cursor_type="row")
            yield Label("", id="error-label")
        yield Footer()

    def on_mount(self) -> None:
        """Initialise columns and start the polling timer."""
        table: DataTable = self.query_one("#net-table", DataTable)
        table.add_column("Host", key="host")
        table.add_column("Mode", key="mode")
        table.add_column("Device", key="device")
        table.add_column("NIC", key="nic")
        table.add_column("RX", key="rx")
        table.add_column("TX", key="tx")
        table.add_column(f"RX@{fmt_window(self._window)}", key="rx_avg")
        table.add_column(f"TX@{fmt_window(self._window)}", key="tx_avg")
        # Refresh at most twice per second even for long sampling intervals.
        render_interval = min(self._interval, 0.5)
        self.set_interval(render_interval, self._tick)

    # ------------------------------------------------------------------
    # Actions
    # ------------------------------------------------------------------
    def action_quit_with_code(self) -> None:
        """Quit with a user-interrupt exit code."""
        self.exit(130)

    def action_next_sort_field(self) -> None:
        """Cycle sortable fields and switch to metric sorting mode."""
        if self._sort_mode == "grouped":
            self._sort_mode = "metric"
        # Note: the index also advances when leaving grouped mode, so the
        # first press lands on the field after the previously selected one.
        self._sort_field_index = next_sort_field_index(
            self._sort_field_index, len(self.SORT_FIELDS)
        )

    def action_toggle_sort_order(self) -> None:
        """Toggle ascending / descending metric sort order."""
        self._sort_desc = not self._sort_desc

    def action_group_sort(self) -> None:
        """Reset to grouped default ordering."""
        self._sort_mode = "grouped"
        self._sort_desc = False

    def _sort_status_text(self) -> str:
        """Describe the active sort mode for the title bar."""
        if self._sort_mode == "grouped":
            return "sort=grouped(host->mode->nic)"
        _, label = self.SORT_FIELDS[self._sort_field_index]
        order = "desc" if self._sort_desc else "asc"
        return f"sort={label.lower()} {order}"

    # ------------------------------------------------------------------
    # Internal timer callback
    # ------------------------------------------------------------------
    def _tick(self) -> None:
        """Drain pending collector results and refresh the DataTable."""
        state = self._state
        apply_node_results(
            nodes=drain_pending_nodes(hosts=self._hosts, state=state),
            interval=self._interval,
            state=state,
        )
        now = time.monotonic()
        elapsed = now - self._started
        # ---- Update app title with runtime info ----
        self.title = (
            f"hytop net | interval={self._interval:.2f}s"
            f" | window={fmt_window(self._window)}"
            f" | elapsed={fmt_elapsed(elapsed)}"
            f" | {self._sort_status_text()}"
        )
        # ---- Rebuild rows ----
        key_list = sort_net_keys_grouped(state.monitored_keys, self._hosts)
        if self._sort_mode == "metric":
            field, _ = self.SORT_FIELDS[self._sort_field_index]

            # Resolve the selected sort field for one key; None means "no
            # data yet" and sorts last via sort_with_missing_last.
            def _metric_value(key: tuple[str, str]) -> float | str | None:
                if field == "host":
                    return key[0]
                history = state.histories.get(key)
                if field == "mode":
                    mode, _ = split_iface_key(key[1])
                    return mode
                if field == "nic":
                    _, nic = split_iface_key(key[1])
                    return nic
                if history is None:
                    return None
                latest_rate = history.latest()
                if latest_rate is None:
                    return None
                if field == "rx_bps":
                    return latest_rate.rx_bps
                if field == "tx_bps":
                    return latest_rate.tx_bps
                if field == "rx_avg":
                    return history.avg("rx_bps", self._window, now)
                if field == "tx_avg":
                    return history.avg("tx_bps", self._window, now)
                return None

            key_list = sort_with_missing_last(key_list, _metric_value, self._sort_desc)
        table: DataTable = self.query_one("#net-table", DataTable)
        # Build keys_to_display: keys we actually have data for.
        keys_to_display: list[tuple[str, str]] = []
        for key in key_list:
            history = state.histories.get(key)
            latest_counter = state.latest_counter_by_key.get(key)
            if history is None or latest_counter is None:
                continue
            latest_rate = history.latest()
            if latest_rate is None:
                continue
            keys_to_display.append(key)

        # Produce one formatted cell value per column for a given key;
        # stale samples (older than the window) render as dashes.
        def _cell_values(key: tuple[str, str]) -> list[str]:
            history = state.histories.get(key)
            latest_counter = state.latest_counter_by_key.get(key)
            if history is None or latest_counter is None:
                return []
            latest_rate = history.latest()
            if latest_rate is None:
                return []
            host, iface_key = key
            mode, name = split_iface_key(iface_key)
            device_text = latest_counter.device_name or name
            nic_text = format_iface_name(name, latest_counter.link_state)
            stale = (now - latest_rate.ts) > self._window
            if stale:
                return [host, mode, device_text, nic_text, "-", "-", "-", "-"]
            rx_avg = history.avg("rx_bps", self._window, now)
            tx_avg = history.avg("tx_bps", self._window, now)
            return [
                host,
                mode,
                device_text,
                nic_text,
                format_rate(latest_rate.rx_bps, self._iec),
                format_rate(latest_rate.tx_bps, self._iec),
                format_rate(rx_avg, self._iec),
                format_rate(tx_avg, self._iec),
            ]

        if keys_to_display == self._ordered_keys:
            # Same order: incremental update to preserve scroll position and cursor.
            for key in keys_to_display:
                row_key = self._row_keys.get(key)
                if row_key is None:
                    continue
                values = _cell_values(key)
                if len(values) != len(self._cell_column_keys):
                    continue
                for col_key, val in zip(self._cell_column_keys, values, strict=True):
                    table.update_cell(row_key, col_key, val)
        else:
            # Order or set changed: full rebuild.
            table.clear(columns=False)
            self._row_keys.clear()
            self._ordered_keys = keys_to_display.copy()
            for key in keys_to_display:
                values = _cell_values(key)
                if not values:
                    continue
                host, iface_key = key
                row_key = f"{host}:{iface_key}"
                table.add_row(*values, key=row_key)
                self._row_keys[key] = row_key
        # ---- Update error label ----
        error_label: Label = self.query_one("#error-label", Label)
        if state.errors:
            parts = [f"{h}: {err}" for h, err in state.errors.items()]
            error_label.update(" ".join(parts))
            error_label.remove_class("hidden")
        else:
            error_label.update("")
            error_label.add_class("hidden")
        # ---- Check global timeout ----
        if self._timeout is not None and elapsed >= self._timeout:
            self.exit(124)
from __future__ import annotations
import time
from rich import box
from rich.console import Group
from rich.table import Table
from hytop.core.format import fmt_elapsed, fmt_window
from hytop.core.history import SlidingHistory
from hytop.net.models import NetCounter, RateSample
def format_rate(value: float, iec: bool = False) -> str:
    """Render a bytes-per-second rate with an auto-scaled unit suffix.

    Args:
        value: Rate in bytes per second.
        iec: Use binary (1024-based) units instead of SI (1000-based).

    Returns:
        Fixed-width string such as "   1.50 kB/s" (value padded to 7 chars).
    """
    if iec:
        base, units = 1024.0, ("B/s", "KiB/s", "MiB/s", "GiB/s", "TiB/s")
    else:
        base, units = 1000.0, ("B/s", "kB/s", "MB/s", "GB/s", "TB/s")
    scaled = float(value)
    unit_idx = 0
    largest = len(units) - 1
    # Divide down until the value fits the unit, capping at the largest unit.
    while scaled >= base and unit_idx < largest:
        scaled /= base
        unit_idx += 1
    return f"{scaled:7.2f} {units[unit_idx]}"
def split_iface_key(iface_key: str) -> tuple[str, str]:
    """Split a "kind:name" interface key into (kind, name).

    When no ":" separator is present, the whole key becomes the kind and
    the name is empty (matching str.partition semantics).
    """
    sep_at = iface_key.find(":")
    if sep_at < 0:
        return iface_key, ""
    return iface_key[:sep_at], iface_key[sep_at + 1 :]
def format_iface_name(name: str, link_state: str | None) -> str:
    """Append a "(down)" marker to *name* when the link state looks down.

    A missing/empty link state leaves the name untouched. The comparison is
    case-insensitive; any state starting with "down" also counts as down.
    """
    if not link_state:
        return name
    state = link_state.strip().lower()
    # "down" itself is covered by the startswith check.
    if state.startswith("down") or state in {"disabled", "init", "inactive"}:
        return f"{name} (down)"
    return name
def build_renderable(
    window: float,
    hosts: list[str],
    histories: dict[tuple[str, str], SlidingHistory],
    monitored_keys: set[tuple[str, str]],
    latest_counter_by_key: dict[tuple[str, str], NetCounter],
    errors: dict[str, str],
    poll_interval: float,
    elapsed_since_start: float,
    iec: bool = False,
) -> Group:
    """Build the rich renderable for one refresh of the net monitor.

    Rows are ordered by host (CLI order, unknown hosts last) then by
    interface key. A row whose latest sample is older than *window* is shown
    with "-" placeholders. A second "Host errors" table is appended only when
    *errors* is non-empty.
    """
    ts_now = time.monotonic()
    rank = {host: idx for idx, host in enumerate(hosts)}
    ordered_keys = sorted(monitored_keys, key=lambda k: (rank.get(k[0], len(hosts)), k[1]))
    header = (
        f"hytop net | interval={poll_interval:.2f}s | elapsed={fmt_elapsed(elapsed_since_start)}"
    )
    table = Table(title=header, box=box.MINIMAL_HEAVY_HEAD, expand=True)
    table.add_column("Host", justify="left", no_wrap=True)
    for label in ("Mode", "Device", "NIC"):
        table.add_column(label, justify="left")
    for label in ("RX", "TX", f"RX@{fmt_window(window)}", f"TX@{fmt_window(window)}"):
        table.add_column(label, justify="right")
    for key in ordered_keys:
        hist = histories.get(key)
        counter = latest_counter_by_key.get(key)
        if hist is None or counter is None:
            continue
        rate = hist.latest()
        # Also rejects a None latest sample (None is never a RateSample).
        if not isinstance(rate, RateSample):
            continue
        host, iface_key = key
        mode, nic = split_iface_key(iface_key)
        device = counter.device_name or nic
        nic_label = format_iface_name(nic, counter.link_state)
        if (ts_now - rate.ts) > window:
            # Stale sample: keep the row visible but blank out the rates.
            table.add_row(host, mode, device, nic_label, "-", "-", "-", "-")
            continue
        table.add_row(
            host,
            mode,
            device,
            nic_label,
            format_rate(rate.rx_bps, iec),
            format_rate(rate.tx_bps, iec),
            format_rate(hist.avg("rx_bps", window, ts_now), iec),
            format_rate(hist.avg("tx_bps", window, ts_now), iec),
        )
    if table.row_count == 0:
        table.add_row("No data yet.", "", "", "", "", "", "", "")
    if not errors:
        return Group(table)
    err_table = Table(title="Host errors", box=box.MINIMAL_HEAVY_HEAD, expand=True)
    err_table.add_column("Host", justify="left", no_wrap=True)
    err_table.add_column("Error", justify="left")
    for host in hosts:
        message = errors.get(host)
        if message:
            err_table.add_row(host, message)
    return Group(table, err_table)
...@@ -5,14 +5,10 @@ import sys ...@@ -5,14 +5,10 @@ import sys
import threading import threading
import time import time
from rich.console import Console
from rich.live import Live
from hytop.core.history import SlidingHistory from hytop.core.history import SlidingHistory
from hytop.core.ssh import SSHOptions, collect_python_from_host from hytop.core.ssh import SSHOptions, collect_python_from_host
from hytop.net.collector import REMOTE_COLLECTOR_PY, collect_local_counters, parse_counter_payload from hytop.net.collector import REMOTE_COLLECTOR_PY, collect_local_counters, parse_counter_payload
from hytop.net.models import HostSnapshot, MonitorState, NetKind, NodeCounterSnapshot, RateSample from hytop.net.models import HostSnapshot, MonitorState, NetKind, NodeCounterSnapshot, RateSample
from hytop.net.render import build_renderable
LOCAL_HOSTS = {"localhost", "127.0.0.1", "::1"} LOCAL_HOSTS = {"localhost", "127.0.0.1", "::1"}
...@@ -220,6 +216,17 @@ def run_monitor( ...@@ -220,6 +216,17 @@ def run_monitor(
iec: bool = False, iec: bool = False,
ssh_options: SSHOptions | None = None, ssh_options: SSHOptions | None = None,
) -> int: ) -> int:
"""Run the network monitor as a Textual TUI application.
Starts per-host collector threads, then launches the Textual App which
polls the shared MonitorState and refreshes the DataTable on a timer.
Returns:
Process-style exit code (0, 2, 124, or 130).
"""
from hytop.net.app import NetMonitorApp
if interval <= 0: if interval <= 0:
print("argument error: --interval must be > 0", file=sys.stderr) print("argument error: --interval must be > 0", file=sys.stderr)
return 2 return 2
...@@ -229,15 +236,9 @@ def run_monitor( ...@@ -229,15 +236,9 @@ def run_monitor(
ssh_timeout = min(max(5 * interval, 2.0), 5.0) ssh_timeout = min(max(5 * interval, 2.0), 5.0)
cmd_timeout = min(max(10 * interval, 5.0), 10.0) cmd_timeout = min(max(10 * interval, 5.0), 10.0)
console = Console()
err_console = Console(stderr=True)
started = time.monotonic() started = time.monotonic()
render_interval = min(interval, 0.5)
state = init_monitor_state(hosts=hosts, max_window=window) state = init_monitor_state(hosts=hosts, max_window=window)
try:
with Live(console=console, auto_refresh=False, screen=True) as live:
workers = start_collectors( workers = start_collectors(
hosts=hosts, hosts=hosts,
ssh_timeout=ssh_timeout, ssh_timeout=ssh_timeout,
...@@ -248,36 +249,24 @@ def run_monitor( ...@@ -248,36 +249,24 @@ def run_monitor(
ssh_options=ssh_options, ssh_options=ssh_options,
state=state, state=state,
) )
try:
while True: app = NetMonitorApp(
loop_started = time.monotonic() hosts=hosts,
apply_node_results( window=window,
nodes=drain_pending_nodes(hosts=hosts, state=state),
interval=interval, interval=interval,
timeout=timeout,
state=state, state=state,
) started=started,
live.update(
build_renderable(
window=window,
hosts=hosts,
histories=state.histories,
monitored_keys=state.monitored_keys,
latest_counter_by_key=state.latest_counter_by_key,
errors=state.errors,
poll_interval=interval,
elapsed_since_start=time.monotonic() - started,
iec=iec, iec=iec,
),
refresh=True,
) )
if timeout is not None and (time.monotonic() - started) >= timeout:
return 124 try:
time.sleep(max(0.0, render_interval - (time.monotonic() - loop_started))) app.run()
except KeyboardInterrupt:
pass
finally: finally:
state.stop_event.set() state.stop_event.set()
for worker in workers: for worker in workers:
worker.join(timeout=min(0.2, interval)) worker.join(timeout=min(0.2, interval))
except KeyboardInterrupt:
err_console.print("status: interrupted by user", style="yellow") return app.return_code or 0
return 130
return 0
from __future__ import annotations
def split_iface_key(iface_key: str) -> tuple[str, str]:
    """Split a "kind:name" interface key into (kind, name).

    A key without ":" yields (key, "") — identical to str.partition.
    """
    parts = iface_key.split(":", 1)
    if len(parts) == 1:
        return parts[0], ""
    return parts[0], parts[1]
def sort_net_keys_grouped(
    monitored_keys: set[tuple[str, str]],
    hosts: list[str],
) -> list[tuple[str, str]]:
    """Default grouped order: Host -> Mode -> NIC.

    Hosts follow their CLI order (unknown hosts sort last); within a host,
    "eth" interfaces come before "ib" (unknown modes last), then NIC name.
    """
    host_rank = {host: idx for idx, host in enumerate(hosts)}
    mode_rank = {"eth": 0, "ib": 1}

    def grouped_key(key: tuple[str, str]) -> tuple[int, int, str]:
        host, iface_key = key
        # partition returns (kind, sep, name); [::2] drops the separator.
        mode, nic = iface_key.partition(":")[::2]
        return host_rank.get(host, len(hosts)), mode_rank.get(mode, 99), nic

    return sorted(monitored_keys, key=grouped_key)
from __future__ import annotations
from hytop.core.sorting import next_sort_field_index, sort_with_missing_last
from hytop.gpu.sort import sort_gpu_keys_grouped
from hytop.net.sort import sort_net_keys_grouped
class TestGpuSortingHelpers:
    """Checks for the grouped GPU ordering and missing-last metric sort."""

    def test_grouped_default_host_then_gpu(self):
        # Hosts follow CLI order; GPU indices sort numerically per host.
        unordered = {("node02", 1), ("node01", 2), ("node01", 0)}
        result = sort_gpu_keys_grouped(unordered, hosts=["node01", "node02"])
        assert result == [("node01", 0), ("node01", 2), ("node02", 1)]

    def test_metric_sort_keeps_missing_last(self):
        # Keys without a metric value must land at the tail even with desc=True.
        keys = [("node01", 0), ("node01", 1), ("node02", 0)]
        metric = {("node01", 0): 50.0, ("node02", 0): 10.0}
        result = sort_with_missing_last(keys, metric.get, desc=True)
        assert result == [("node01", 0), ("node02", 0), ("node01", 1)]
class TestNetSortingHelpers:
    """Checks for the grouped net-key ordering and missing-last metric sort."""

    def test_grouped_default_host_mode_nic(self):
        # Expect host rank first, then eth before ib, then NIC name.
        unordered = {
            ("node02", "ib:mlx5_1/p1"),
            ("node01", "ib:mlx5_0/p1"),
            ("node01", "eth:p6p1"),
            ("node01", "eth:p14p1"),
        }
        expected = [
            ("node01", "eth:p14p1"),
            ("node01", "eth:p6p1"),
            ("node01", "ib:mlx5_0/p1"),
            ("node02", "ib:mlx5_1/p1"),
        ]
        assert sort_net_keys_grouped(unordered, hosts=["node01", "node02"]) == expected

    def test_metric_sort_keeps_missing_last(self):
        # Keys lacking a metric value stay at the tail in ascending order too.
        keys = [("node01", "eth:p6p1"), ("node01", "ib:mlx5_0/p1"), ("node02", "eth:p14p1")]
        metric = {("node01", "ib:mlx5_0/p1"): 30.0, ("node02", "eth:p14p1"): 10.0}
        result = sort_with_missing_last(keys, metric.get, desc=False)
        assert result == [
            ("node02", "eth:p14p1"),
            ("node01", "ib:mlx5_0/p1"),
            ("node01", "eth:p6p1"),
        ]
class TestSortStateHelpers:
    """Checks for the sort-field cycling helper."""

    def test_next_sort_field_index_cycles(self):
        # (current, total) -> expected next index; total=0 must clamp to 0.
        cases = [((0, 3), 1), ((2, 3), 0), ((0, 0), 0)]
        for args, expected in cases:
            assert next_sort_field_index(*args) == expected
"""Tests for hytop.net.render helpers.""" """Tests for hytop.net formatting helpers."""
from __future__ import annotations from __future__ import annotations
from hytop.net.render import format_iface_name, format_rate, split_iface_key from hytop.net.app import format_iface_name, format_rate, split_iface_key
class TestFormatIfaceName: class TestFormatIfaceName:
...@@ -18,10 +18,10 @@ class TestFormatIfaceName: ...@@ -18,10 +18,10 @@ class TestFormatIfaceName:
class TestFormatRate: class TestFormatRate:
def test_zero(self): def test_zero(self):
assert format_rate(0.0) == " 0.00 B/s" assert format_rate(0.0) == "0.00 B/s"
def test_bytes(self): def test_bytes(self):
assert format_rate(512.0) == " 512.00 B/s" assert format_rate(512.0) == "512.00 B/s"
def test_kilobytes(self): def test_kilobytes(self):
result = format_rate(1500.0) result = format_rate(1500.0)
......
...@@ -110,14 +110,14 @@ class TestParseHySmiOutput: ...@@ -110,14 +110,14 @@ class TestParseHySmiOutput:
s = result[0] s = result[0]
assert s.temp_c == pytest.approx(30.0) assert s.temp_c == pytest.approx(30.0)
assert s.avg_pwr_w == pytest.approx(157.0) assert s.avg_pwr_w == pytest.approx(157.0)
assert s.hcu_pct == pytest.approx(0.0) assert s.gpu_pct == pytest.approx(0.0)
assert s.vram_pct == pytest.approx(89.0) # integer string "89" → 89.0 assert s.vram_pct == pytest.approx(89.0) # integer string "89" → 89.0
assert s.sclk_mhz == pytest.approx(1500.0) # "1500Mhz" → 1500.0 assert s.sclk_mhz == pytest.approx(1500.0) # "1500Mhz" → 1500.0
def test_full_output_card7_hcu_load(self): def test_full_output_card7_hcu_load(self):
raw = json.dumps(HY_SMI_FULL) raw = json.dumps(HY_SMI_FULL)
result = parse_hy_smi_output(raw, sample_ts=1.0) result = parse_hy_smi_output(raw, sample_ts=1.0)
assert result[7].hcu_pct == pytest.approx(100.0) assert result[7].gpu_pct == pytest.approx(100.0)
def test_temp_only_output(self): def test_temp_only_output(self):
raw = json.dumps(HY_SMI_TEMP_ONLY) raw = json.dumps(HY_SMI_TEMP_ONLY)
...@@ -126,7 +126,7 @@ class TestParseHySmiOutput: ...@@ -126,7 +126,7 @@ class TestParseHySmiOutput:
assert s.temp_c == pytest.approx(26.0) assert s.temp_c == pytest.approx(26.0)
# Unrelated sensor keys must not populate fields # Unrelated sensor keys must not populate fields
assert s.avg_pwr_w is None assert s.avg_pwr_w is None
assert s.hcu_pct is None assert s.gpu_pct is None
def test_sample_ts_propagated(self): def test_sample_ts_propagated(self):
raw = json.dumps(HY_SMI_FULL) raw = json.dumps(HY_SMI_FULL)
......
"""Tests for hytop.gpu.render formatting helpers.""" """Tests for hytop GPU and core formatting helpers."""
from __future__ import annotations from __future__ import annotations
from hytop.gpu.render import _format_metric, fmt_elapsed, fmt_window from hytop.core.format import fmt_elapsed, fmt_window
from hytop.gpu.app import _format_metric
class TestFmtWindow: class TestFmtWindow:
...@@ -55,8 +56,8 @@ class TestFormatMetric: ...@@ -55,8 +56,8 @@ class TestFormatMetric:
assert "89.00" in result assert "89.00" in result
assert "%" in result assert "%" in result
def test_pct_format_hcu(self): def test_pct_format_gpu(self):
result = _format_metric("hcu_pct", 0.0) result = _format_metric("gpu_pct", 0.0)
assert "0.00" in result and "%" in result assert "0.00" in result and "%" in result
def test_sclk_format(self): def test_sclk_format(self):
......
...@@ -45,8 +45,8 @@ def _state(hosts=("localhost",), device_filter=None, max_window=10.0) -> Monitor ...@@ -45,8 +45,8 @@ def _state(hosts=("localhost",), device_filter=None, max_window=10.0) -> Monitor
return init_monitor_state(hosts=list(hosts), device_filter=device_filter, max_window=max_window) return init_monitor_state(hosts=list(hosts), device_filter=device_filter, max_window=max_window)
def _sample(ts: float, hcu_pct: float = 0.0, vram_pct: float = 0.0) -> Sample: def _sample(ts: float, gpu_pct: float = 0.0, vram_pct: float = 0.0) -> Sample:
return Sample(ts=ts, hcu_pct=hcu_pct, vram_pct=vram_pct) return Sample(ts=ts, gpu_pct=gpu_pct, vram_pct=vram_pct)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -88,7 +88,7 @@ class TestCollectNode: ...@@ -88,7 +88,7 @@ class TestCollectNode:
class TestApplyNodeResults: class TestApplyNodeResults:
def test_successful_node_adds_to_history(self): def test_successful_node_adds_to_history(self):
state = _state(hosts=["localhost"]) state = _state(hosts=["localhost"])
sample = _sample(ts=1.0, hcu_pct=50.0) sample = _sample(ts=1.0, gpu_pct=50.0)
node = NodeResult(host="localhost", samples={0: sample}) node = NodeResult(host="localhost", samples={0: sample})
apply_node_results([node], device_filter=None, state=state) apply_node_results([node], device_filter=None, state=state)
assert ("localhost", 0) in state.histories assert ("localhost", 0) in state.histories
...@@ -130,15 +130,15 @@ class TestApplyNodeResults: ...@@ -130,15 +130,15 @@ class TestApplyNodeResults:
class TestAvailabilityReady: class TestAvailabilityReady:
def _make_history(self, hcu_pct: float, vram_pct: float) -> SlidingHistory: def _make_history(self, gpu_pct: float, vram_pct: float) -> SlidingHistory:
"""Build a SlidingHistory with one fresh sample using real monotonic time.""" """Build a SlidingHistory with one fresh sample using real monotonic time."""
h = SlidingHistory(max_window_s=30) h = SlidingHistory(max_window_s=30)
h.add(_sample(ts=time.monotonic(), hcu_pct=hcu_pct, vram_pct=vram_pct)) h.add(_sample(ts=time.monotonic(), gpu_pct=gpu_pct, vram_pct=vram_pct))
return h return h
def test_idle_gpu_returns_true(self): def test_idle_gpu_returns_true(self):
key = ("localhost", 0) key = ("localhost", 0)
histories = {key: self._make_history(hcu_pct=0.0, vram_pct=0.0)} histories = {key: self._make_history(gpu_pct=0.0, vram_pct=0.0)}
assert availability_ready( assert availability_ready(
window=5.0, window=5.0,
histories=histories, histories=histories,
...@@ -149,7 +149,7 @@ class TestAvailabilityReady: ...@@ -149,7 +149,7 @@ class TestAvailabilityReady:
def test_busy_gpu_returns_false(self): def test_busy_gpu_returns_false(self):
key = ("localhost", 0) key = ("localhost", 0)
histories = {key: self._make_history(hcu_pct=100.0, vram_pct=89.0)} histories = {key: self._make_history(gpu_pct=100.0, vram_pct=89.0)}
assert not availability_ready( assert not availability_ready(
window=5.0, window=5.0,
histories=histories, histories=histories,
...@@ -160,7 +160,7 @@ class TestAvailabilityReady: ...@@ -160,7 +160,7 @@ class TestAvailabilityReady:
def test_host_error_returns_false(self): def test_host_error_returns_false(self):
key = ("localhost", 0) key = ("localhost", 0)
histories = {key: self._make_history(hcu_pct=0.0, vram_pct=0.0)} histories = {key: self._make_history(gpu_pct=0.0, vram_pct=0.0)}
assert not availability_ready( assert not availability_ready(
window=5.0, window=5.0,
histories=histories, histories=histories,
......
...@@ -91,6 +91,8 @@ class TestCollectFromHostRemote: ...@@ -91,6 +91,8 @@ class TestCollectFromHostRemote:
collect_from_host("node01", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]) collect_from_host("node01", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
cmd = mock_run.call_args[0][0] cmd = mock_run.call_args[0][0]
assert "BatchMode=yes" in cmd assert "BatchMode=yes" in cmd
assert "-n" in cmd
assert "-T" in cmd
@patch("hytop.core.ssh.subprocess.run") @patch("hytop.core.ssh.subprocess.run")
def test_hy_smi_args_forwarded(self, mock_run): def test_hy_smi_args_forwarded(self, mock_run):
...@@ -143,3 +145,22 @@ class TestCollectPythonFromHost: ...@@ -143,3 +145,22 @@ class TestCollectPythonFromHost:
part for part in cmd if isinstance(part, str) and part.startswith("python3 -c ") part for part in cmd if isinstance(part, str) and part.startswith("python3 -c ")
] ]
assert len(remote_parts) == 1 assert len(remote_parts) == 1
class TestSubprocessStdinIsolation:
@patch("hytop.core.ssh.subprocess.run")
def test_collect_from_host_uses_devnull_stdin(self, mock_run):
mock_run.return_value = _make_proc(stdout='{"card0":{}}')
collect_from_host("localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
assert mock_run.call_args.kwargs["stdin"] == subprocess.DEVNULL
@patch("hytop.core.ssh.subprocess.run")
def test_collect_python_from_host_uses_devnull_stdin(self, mock_run):
mock_run.return_value = _make_proc(stdout='{"ok":1}')
collect_python_from_host(
"node01",
ssh_timeout=5,
cmd_timeout=10,
python_code="print('ok')",
)
assert mock_run.call_args.kwargs["stdin"] == subprocess.DEVNULL
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment