"...git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "0611cb1bd9a7ef5659f3d9d6ef1877c7d5e4074e"
Commit 516fd909 authored by one's avatar one
Browse files

[hytop] Add tests and formatter

parent be036ead
...@@ -8,19 +8,39 @@ dynamic = ["version"] ...@@ -8,19 +8,39 @@ dynamic = ["version"]
description = "hytop toolkit" description = "hytop toolkit"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = ["rich>=13", "typer>=0.12"]
"rich>=13",
"typer>=0.12",
]
[project.scripts] [project.scripts]
hytop = "hytop.main:main" hytop = "hytop.main:main"
[project.optional-dependencies]
dev = ["pytest>=8", "ruff>=0.11"]
[tool.setuptools] [tool.setuptools]
package-dir = {"" = "src"} package-dir = { "" = "src" }
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src"] where = ["src"]
[tool.setuptools.dynamic] [tool.setuptools.dynamic]
version = {attr = "hytop.__version__"} version = { attr = "hytop.__version__" }
[tool.pytest.ini_options]
testpaths = ["tests"]
[tool.ruff]
target-version = "py310"
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = [
"F", # pyflakes
"E",
"W", # pycodestyle
"I", # isort
"UP", # pyupgrade
"B", # flake8-bugbear
"SIM", # flake8-simplify
"RUF", # ruff-specific
]
from __future__ import annotations from __future__ import annotations
from collections import deque from collections import deque
from typing import Deque, Optional, Protocol from typing import Protocol
class MetricSample(Protocol): class MetricSample(Protocol):
...@@ -13,7 +13,7 @@ class SlidingHistory: ...@@ -13,7 +13,7 @@ class SlidingHistory:
def __init__(self, max_window_s: float) -> None: def __init__(self, max_window_s: float) -> None:
self.max_window_s = max_window_s self.max_window_s = max_window_s
self.samples: Deque[MetricSample] = deque() self.samples: deque[MetricSample] = deque()
def add(self, sample: MetricSample) -> None: def add(self, sample: MetricSample) -> None:
"""Append one sample and prune data outside the max window. """Append one sample and prune data outside the max window.
...@@ -34,7 +34,7 @@ class SlidingHistory: ...@@ -34,7 +34,7 @@ class SlidingHistory:
while self.samples and self.samples[0].ts < cutoff: while self.samples and self.samples[0].ts < cutoff:
self.samples.popleft() self.samples.popleft()
def latest(self) -> Optional[MetricSample]: def latest(self) -> MetricSample | None:
"""Return the latest sample if available. """Return the latest sample if available.
Returns: Returns:
......
from __future__ import annotations from __future__ import annotations
import subprocess import subprocess
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence
@dataclass @dataclass
...@@ -42,7 +42,7 @@ def collect_from_host( ...@@ -42,7 +42,7 @@ def collect_from_host(
if host in local_names: if host in local_names:
cmd = ["hy-smi", *hy_smi_args] cmd = ["hy-smi", *hy_smi_args]
else: else:
connect_timeout = max(1, int(round(ssh_timeout))) connect_timeout = max(1, round(ssh_timeout))
cmd = [ cmd = [
"ssh", "ssh",
"-o", "-o",
......
from __future__ import annotations from __future__ import annotations
from typing import List
def parse_csv_ints(value: str, flag: str) -> list[int]:
def parse_csv_ints(value: str, flag: str) -> List[int]:
"""Parse a non-empty comma-separated integer list. """Parse a non-empty comma-separated integer list.
Args: Args:
...@@ -17,7 +15,7 @@ def parse_csv_ints(value: str, flag: str) -> List[int]: ...@@ -17,7 +15,7 @@ def parse_csv_ints(value: str, flag: str) -> List[int]:
ValueError: If list is empty or contains non-integer tokens. ValueError: If list is empty or contains non-integer tokens.
""" """
out: List[int] = [] out: list[int] = []
for token in value.split(","): for token in value.split(","):
item = token.strip() item = token.strip()
if not item: if not item:
...@@ -30,7 +28,7 @@ def parse_csv_ints(value: str, flag: str) -> List[int]: ...@@ -30,7 +28,7 @@ def parse_csv_ints(value: str, flag: str) -> List[int]:
return out return out
def parse_csv_strings(value: str, flag: str) -> List[str]: def parse_csv_strings(value: str, flag: str) -> list[str]:
"""Parse a non-empty comma-separated string list. """Parse a non-empty comma-separated string list.
Args: Args:
......
from __future__ import annotations from __future__ import annotations
from typing import Optional, Set
import typer import typer
from hytop.core.validators import parse_csv_ints
from hytop.gpu.metrics import ( from hytop.gpu.metrics import (
SHOW_FLAG_HELP, SHOW_FLAG_HELP,
SUPPORTED_SHOW_FLAGS, SUPPORTED_SHOW_FLAGS,
...@@ -11,7 +10,6 @@ from hytop.gpu.metrics import ( ...@@ -11,7 +10,6 @@ from hytop.gpu.metrics import (
normalized_show_flags, normalized_show_flags,
) )
from hytop.gpu.service import run_monitor from hytop.gpu.service import run_monitor
from hytop.core.validators import parse_csv_ints
app = typer.Typer( app = typer.Typer(
add_completion=False, add_completion=False,
...@@ -111,7 +109,7 @@ def gpu( ...@@ -111,7 +109,7 @@ def gpu(
show_flags = normalized_show_flags( show_flags = normalized_show_flags(
[flag for flag, enabled in selected_show_flags.items() if enabled] [flag for flag, enabled in selected_show_flags.items() if enabled]
) )
parsed_device_filter: Optional[Set[int]] = None parsed_device_filter: set[int] | None = None
if device_filter: if device_filter:
parsed_device_filter = set(parse_csv_ints(device_filter, "--devices")) parsed_device_filter = set(parse_csv_ints(device_filter, "--devices"))
except ValueError as exc: except ValueError as exc:
......
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Final, Iterable, Literal, TypeAlias, TypeGuard from typing import Final, Literal, TypeAlias, TypeGuard
ShowFlag: TypeAlias = Literal[ ShowFlag: TypeAlias = Literal[
"showtemp", "showtemp",
...@@ -74,9 +74,7 @@ SHOW_FLAG_HELP: Final[dict[ShowFlag, str]] = {spec.flag: spec.cli_help for spec ...@@ -74,9 +74,7 @@ SHOW_FLAG_HELP: Final[dict[ShowFlag, str]] = {spec.flag: spec.cli_help for spec
WAIT_IDLE_REQUIRED_SHOW_FLAGS: Final[tuple[ShowFlag, ...]] = ("showmemuse", "showuse") WAIT_IDLE_REQUIRED_SHOW_FLAGS: Final[tuple[ShowFlag, ...]] = ("showmemuse", "showuse")
JSON_KEY_BY_METRIC: Final[dict[str, str]] = { JSON_KEY_BY_METRIC: Final[dict[str, str]] = {
metric: json_key metric: json_key for spec in SHOW_SPECS for metric, json_key in spec.metric_json_keys.items()
for spec in SHOW_SPECS
for metric, json_key in spec.metric_json_keys.items()
} }
......
...@@ -2,7 +2,6 @@ from __future__ import annotations ...@@ -2,7 +2,6 @@ from __future__ import annotations
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Set, Tuple
from hytop.core.history import SlidingHistory from hytop.core.history import SlidingHistory
...@@ -39,8 +38,8 @@ class NodeResult: ...@@ -39,8 +38,8 @@ class NodeResult:
""" """
host: str host: str
samples: Dict[int, Sample] samples: dict[int, Sample]
error: Optional[str] = None error: str | None = None
@dataclass @dataclass
...@@ -55,7 +54,7 @@ class HostSnapshot: ...@@ -55,7 +54,7 @@ class HostSnapshot:
seq: int = 0 seq: int = 0
updated_ts: float = 0.0 updated_ts: float = 0.0
result: Optional[NodeResult] = None result: NodeResult | None = None
@dataclass @dataclass
...@@ -76,12 +75,12 @@ class MonitorState: ...@@ -76,12 +75,12 @@ class MonitorState:
""" """
max_window: float max_window: float
histories: Dict[Tuple[str, int], SlidingHistory] histories: dict[tuple[str, int], SlidingHistory]
discovered_keys: Set[Tuple[str, int]] discovered_keys: set[tuple[str, int]]
last_applied_sample_ts: Dict[Tuple[str, int], float] last_applied_sample_ts: dict[tuple[str, int], float]
monitored_keys: Set[Tuple[str, int]] monitored_keys: set[tuple[str, int]]
errors: Dict[str, str] errors: dict[str, str]
host_state: Dict[str, HostSnapshot] host_state: dict[str, HostSnapshot]
processed_seq: Dict[str, int] processed_seq: dict[str, int]
state_lock: threading.Lock state_lock: threading.Lock
stop_event: threading.Event stop_event: threading.Event
...@@ -2,7 +2,6 @@ from __future__ import annotations ...@@ -2,7 +2,6 @@ from __future__ import annotations
import json import json
import re import re
from typing import Dict
from hytop.gpu.metrics import JSON_KEY_BY_METRIC from hytop.gpu.metrics import JSON_KEY_BY_METRIC
from hytop.gpu.models import Sample from hytop.gpu.models import Sample
...@@ -43,7 +42,7 @@ def parse_number(text: str) -> float: ...@@ -43,7 +42,7 @@ def parse_number(text: str) -> float:
return float(match.group(0)) return float(match.group(0))
def parse_hy_smi_output(raw: str, sample_ts: float) -> Dict[int, Sample]: def parse_hy_smi_output(raw: str, sample_ts: float) -> dict[int, Sample]:
"""Parse hy-smi JSON output into GPU keyed samples. """Parse hy-smi JSON output into GPU keyed samples.
Args: Args:
...@@ -64,7 +63,7 @@ def parse_hy_smi_output(raw: str, sample_ts: float) -> Dict[int, Sample]: ...@@ -64,7 +63,7 @@ def parse_hy_smi_output(raw: str, sample_ts: float) -> Dict[int, Sample]:
if not isinstance(payload, dict): if not isinstance(payload, dict):
return {} return {}
result: Dict[int, Sample] = {} result: dict[int, Sample] = {}
for card_key, card_data in payload.items(): for card_key, card_data in payload.items():
if not isinstance(card_key, str): if not isinstance(card_key, str):
continue continue
......
from __future__ import annotations from __future__ import annotations
import time import time
from typing import Dict, Iterable, List, Tuple from collections.abc import Iterable
from rich import box from rich import box
from rich.console import Group from rich.console import Group
...@@ -43,10 +43,10 @@ def fmt_elapsed(elapsed_s: float) -> str: ...@@ -43,10 +43,10 @@ def fmt_elapsed(elapsed_s: float) -> str:
def build_renderable( def build_renderable(
window: float, window: float,
hosts: List[str], hosts: list[str],
histories: Dict[Tuple[str, int], SlidingHistory], histories: dict[tuple[str, int], SlidingHistory],
monitored_keys: Iterable[Tuple[str, int]], monitored_keys: Iterable[tuple[str, int]],
errors: Dict[str, str], errors: dict[str, str],
show_flags: Iterable[str], show_flags: Iterable[str],
poll_interval: float, poll_interval: float,
elapsed_since_start: float, elapsed_since_start: float,
...@@ -68,11 +68,12 @@ def build_renderable( ...@@ -68,11 +68,12 @@ def build_renderable(
now = time.monotonic() now = time.monotonic()
host_rank = {host: idx for idx, host in enumerate(hosts)} host_rank = {host: idx for idx, host in enumerate(hosts)}
key_list = sorted( key_list = sorted(monitored_keys, key=lambda x: (host_rank.get(x[0], len(hosts)), x[1]))
monitored_keys, key=lambda x: (host_rank.get(x[0], len(hosts)), x[1]) title = (
f"hytop gpu | interval={poll_interval:.2f}s | elapsed={fmt_elapsed(elapsed_since_start)}"
) )
table = Table( table = Table(
title=f"hytop gpu | interval={poll_interval:.2f}s | elapsed={fmt_elapsed(elapsed_since_start)}", title=title,
box=box.MINIMAL_HEAVY_HEAD, box=box.MINIMAL_HEAVY_HEAD,
expand=True, expand=True,
) )
...@@ -104,9 +105,7 @@ def build_renderable( ...@@ -104,9 +105,7 @@ def build_renderable(
if metric_value is None: if metric_value is None:
values.append("-") values.append("-")
else: else:
values.append( values.append(_format_metric(col.metric, history.avg(col.metric, window, now)))
_format_metric(col.metric, history.avg(col.metric, window, now))
)
table.add_row( table.add_row(
host, host,
str(gpu), str(gpu),
......
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
import sys import sys
import threading import threading
import time import time
from typing import List, Optional, Sequence, Set from collections.abc import Sequence
from rich.console import Console from rich.console import Console
from rich.live import Live from rich.live import Live
...@@ -122,8 +122,8 @@ def availability_ready( ...@@ -122,8 +122,8 @@ def availability_ready(
def init_monitor_state( def init_monitor_state(
hosts: List[str], hosts: list[str],
device_filter: Optional[Set[int]], device_filter: set[int] | None,
max_window: float, max_window: float,
) -> MonitorState: ) -> MonitorState:
"""Create initial monitor state for the run. """Create initial monitor state for the run.
...@@ -137,9 +137,7 @@ def init_monitor_state( ...@@ -137,9 +137,7 @@ def init_monitor_state(
Initialized monitor state object. Initialized monitor state object.
""" """
monitored_keys = ( monitored_keys = {(h, d) for h in hosts for d in device_filter} if device_filter else set()
{(h, d) for h in hosts for d in device_filter} if device_filter else set()
)
return MonitorState( return MonitorState(
max_window=max_window, max_window=max_window,
histories={}, histories={},
...@@ -155,13 +153,13 @@ def init_monitor_state( ...@@ -155,13 +153,13 @@ def init_monitor_state(
def start_collectors( def start_collectors(
hosts: List[str], hosts: list[str],
ssh_timeout: float, ssh_timeout: float,
cmd_timeout: float, cmd_timeout: float,
hy_smi_args: Sequence[str], hy_smi_args: Sequence[str],
interval: float, interval: float,
state: MonitorState, state: MonitorState,
) -> List[threading.Thread]: ) -> list[threading.Thread]:
"""Start one daemon collector thread per host. """Start one daemon collector thread per host.
Args: Args:
...@@ -175,7 +173,7 @@ def start_collectors( ...@@ -175,7 +173,7 @@ def start_collectors(
Started collector thread list. Started collector thread list.
""" """
workers: List[threading.Thread] = [] workers: list[threading.Thread] = []
for host in hosts: for host in hosts:
worker = threading.Thread( worker = threading.Thread(
target=host_collector_loop, target=host_collector_loop,
...@@ -197,7 +195,7 @@ def start_collectors( ...@@ -197,7 +195,7 @@ def start_collectors(
return workers return workers
def drain_pending_nodes(hosts: List[str], state: MonitorState) -> List[NodeResult]: def drain_pending_nodes(hosts: list[str], state: MonitorState) -> list[NodeResult]:
"""Fetch unseen host snapshots since the previous render tick. """Fetch unseen host snapshots since the previous render tick.
Args: Args:
...@@ -208,7 +206,7 @@ def drain_pending_nodes(hosts: List[str], state: MonitorState) -> List[NodeResul ...@@ -208,7 +206,7 @@ def drain_pending_nodes(hosts: List[str], state: MonitorState) -> List[NodeResul
Newly published node results to apply this tick. Newly published node results to apply this tick.
""" """
nodes: List[NodeResult] = [] nodes: list[NodeResult] = []
with state.state_lock: with state.state_lock:
for host in hosts: for host in hosts:
snapshot = state.host_state[host] snapshot = state.host_state[host]
...@@ -221,8 +219,8 @@ def drain_pending_nodes(hosts: List[str], state: MonitorState) -> List[NodeResul ...@@ -221,8 +219,8 @@ def drain_pending_nodes(hosts: List[str], state: MonitorState) -> List[NodeResul
def apply_node_results( def apply_node_results(
nodes: List[NodeResult], nodes: list[NodeResult],
device_filter: Optional[Set[int]], device_filter: set[int] | None,
state: MonitorState, state: MonitorState,
) -> None: ) -> None:
"""Apply collected node results into histories and error state. """Apply collected node results into histories and error state.
...@@ -255,13 +253,13 @@ def apply_node_results( ...@@ -255,13 +253,13 @@ def apply_node_results(
def run_monitor( def run_monitor(
hosts: List[str], hosts: list[str],
device_filter: Optional[Set[int]], device_filter: set[int] | None,
show_flags: Sequence[str], show_flags: Sequence[str],
window: float, window: float,
interval: float, interval: float,
wait_idle: bool, wait_idle: bool,
timeout: Optional[float], timeout: float | None,
wait_idle_duration: float = 10.0, wait_idle_duration: float = 10.0,
) -> int: ) -> int:
"""Run the asynchronous collector + periodic renderer monitor loop. """Run the asynchronous collector + periodic renderer monitor loop.
...@@ -293,9 +291,7 @@ def run_monitor( ...@@ -293,9 +291,7 @@ def run_monitor(
print("argument error: --interval must be <= --window value", file=sys.stderr) print("argument error: --interval must be <= --window value", file=sys.stderr)
return 2 return 2
state = init_monitor_state( state = init_monitor_state(hosts=hosts, device_filter=device_filter, max_window=window)
hosts=hosts, device_filter=device_filter, max_window=window
)
hy_smi_args = hy_smi_args_for_show_flags(show_flags, wait_idle=wait_idle) hy_smi_args = hy_smi_args_for_show_flags(show_flags, wait_idle=wait_idle)
ssh_timeout = min(max(5 * interval, 2.0), 5.0) ssh_timeout = min(max(5 * interval, 2.0), 5.0)
cmd_timeout = min(max(10 * interval, 5.0), 10.0) cmd_timeout = min(max(10 * interval, 5.0), 10.0)
...@@ -348,23 +344,15 @@ def run_monitor( ...@@ -348,23 +344,15 @@ def run_monitor(
errors=state.errors, errors=state.errors,
) )
): ):
console.print( console.print("status: success (all monitored GPUs are available)")
"status: success (all monitored GPUs are available)"
)
return 0 return 0
if ( if wait_idle and timeout is not None and elapsed_since_start >= timeout:
wait_idle
and timeout is not None
and elapsed_since_start >= timeout
):
err_console.print( err_console.print(
"status: timeout (availability condition not met)", "status: timeout (availability condition not met)",
style="yellow", style="yellow",
) )
return 124 return 124
time.sleep( time.sleep(max(0.0, render_interval - (time.monotonic() - loop_started)))
max(0.0, render_interval - (time.monotonic() - loop_started))
)
finally: finally:
state.stop_event.set() state.stop_event.set()
for worker in workers: for worker in workers:
......
from __future__ import annotations from __future__ import annotations
from typing import Optional
import typer import typer
from hytop import __version__ from hytop import __version__
from hytop.core.validators import parse_csv_strings, parse_positive_float
from hytop.cpu.cli import app as cpu_app from hytop.cpu.cli import app as cpu_app
from hytop.gpu.cli import app as gpu_app from hytop.gpu.cli import app as gpu_app
from hytop.core.validators import parse_csv_strings, parse_positive_float
app = typer.Typer( app = typer.Typer(
help="hytop toolkit command line", help="hytop toolkit command line",
...@@ -52,7 +50,7 @@ def root( ...@@ -52,7 +50,7 @@ def root(
"--window", "--window",
help="Single rolling window in seconds. Default: 5.0", help="Single rolling window in seconds. Default: 5.0",
), ),
timeout: Optional[float] = typer.Option( timeout: float | None = typer.Option(
None, None,
"--timeout", "--timeout",
help="Max runtime in seconds.", help="Max runtime in seconds.",
......
"""Tests for hytop.core.history.SlidingHistory."""
from __future__ import annotations
from dataclasses import dataclass
import pytest
from hytop.core.history import SlidingHistory
@dataclass
class _Sample:
    """Minimal MetricSample for testing.

    Structurally satisfies the MetricSample Protocol (only ``ts`` is
    required); ``value`` is the extra metric attribute the avg tests read.
    """

    # Timestamp in seconds; SlidingHistory prunes by comparing these.
    ts: float
    # Metric payload; defaults to 0.0 for tests that only exercise pruning.
    value: float = 0.0
class TestSlidingHistoryBasics:
    """add/latest semantics and max-window pruning."""

    def test_empty_latest_is_none(self):
        history = SlidingHistory(max_window_s=10)
        assert history.latest() is None

    def test_add_and_latest(self):
        history = SlidingHistory(max_window_s=10)
        sample = _Sample(ts=1.0, value=42.0)
        history.add(sample)
        assert history.latest() is sample

    def test_latest_returns_most_recent(self):
        history = SlidingHistory(max_window_s=10)
        older = _Sample(ts=1.0, value=1.0)
        newer = _Sample(ts=2.0, value=2.0)
        for item in (older, newer):
            history.add(item)
        assert history.latest() is newer

    def test_add_prunes_outside_window(self):
        history = SlidingHistory(max_window_s=5)
        history.add(_Sample(ts=0.0))
        # now=6.0 → cutoff=1.0, so the ts=0.0 sample must be dropped.
        history.add(_Sample(ts=6.0))
        assert len(history.samples) == 1

    def test_add_keeps_samples_within_window(self):
        history = SlidingHistory(max_window_s=10)
        history.add(_Sample(ts=0.0))
        # now=8.0 → cutoff=-2.0, so both samples survive.
        history.add(_Sample(ts=8.0))
        assert len(history.samples) == 2
class TestSlidingHistoryAvg:
    """Windowed-average computation, including exclusion rules."""

    @staticmethod
    def _loaded(max_window_s, *samples):
        # Build a history pre-populated with the given samples, in order.
        history = SlidingHistory(max_window_s=max_window_s)
        for sample in samples:
            history.add(sample)
        return history

    def test_empty_returns_zero(self):
        history = self._loaded(10)
        assert history.avg("value", window_s=5, now=10.0) == 0.0

    def test_single_sample_in_window(self):
        history = self._loaded(10, _Sample(ts=9.0, value=4.0))
        assert history.avg("value", window_s=5, now=10.0) == pytest.approx(4.0)

    def test_average_of_multiple(self):
        history = self._loaded(10, _Sample(ts=8.0, value=2.0), _Sample(ts=9.0, value=4.0))
        assert history.avg("value", window_s=5, now=10.0) == pytest.approx(3.0)

    def test_sample_outside_window_excluded(self):
        # ts=0.0 sits inside max_window but outside the averaging window
        # (window=5, now=10 → cutoff=5.0), so only ts=9.0 contributes.
        history = self._loaded(20, _Sample(ts=0.0, value=100.0), _Sample(ts=9.0, value=2.0))
        assert history.avg("value", window_s=5, now=10.0) == pytest.approx(2.0)

    def test_no_samples_in_window_returns_zero(self):
        # window=2, now=10 → cutoff=8; the lone ts=0 sample is excluded.
        history = self._loaded(100, _Sample(ts=0.0, value=5.0))
        assert history.avg("value", window_s=2, now=10.0) == 0.0

    def test_none_value_excluded_from_avg(self):
        """avg should skip samples whose attribute value is None."""

        @dataclass
        class _NoneValue:
            # MetricSample only requires ts; 'value' is present but None here.
            ts: float
            value: None = None

        history = self._loaded(10, _Sample(ts=9.0, value=6.0), _NoneValue(ts=9.5))
        # Only the sample carrying a float should count toward the average.
        assert history.avg("value", window_s=5, now=10.0) == pytest.approx(6.0)
"""Tests for hytop.gpu.metrics."""
from __future__ import annotations
from hytop.gpu.metrics import (
hy_smi_args_for_show_flags,
normalized_show_flags,
render_columns_for_show_flags,
)
class TestNormalizedShowFlags:
    """Normalization: defaults, ordering, dedup, unknown-flag handling."""

    def test_none_returns_all_defaults(self):
        result = normalized_show_flags(None)
        assert result == ["showtemp", "showpower", "showsclk", "showmemuse", "showuse"]

    def test_empty_list_returns_all_defaults(self):
        # An empty selection behaves exactly like passing None.
        assert normalized_show_flags([]) == normalized_show_flags(None)

    def test_single_flag_preserved(self):
        assert normalized_show_flags(["showtemp"]) == ["showtemp"]

    def test_order_preserved(self):
        result = normalized_show_flags(["showuse", "showtemp"])
        assert result == ["showuse", "showtemp"]

    def test_duplicates_deduplicated(self):
        result = normalized_show_flags(["showtemp", "showtemp"])
        assert result == ["showtemp"]

    def test_unknown_flag_ignored(self):
        # When nothing valid remains, normalization falls back to defaults.
        assert normalized_show_flags(["unknown_flag"]) == normalized_show_flags(None)

    def test_mix_valid_and_invalid(self):
        assert normalized_show_flags(["showtemp", "INVALID"]) == ["showtemp"]
class TestHySmiArgsForShowFlags:
    """Mapping from show flags to the hy-smi CLI argument list."""

    @staticmethod
    def _args(flags, wait_idle=False):
        # All tests here share the same call shape.
        return hy_smi_args_for_show_flags(flags, wait_idle=wait_idle)

    def test_includes_json_flag(self):
        assert "--json" in self._args(["showtemp"])

    def test_showtemp_maps_to_showtemp(self):
        assert "--showtemp" in self._args(["showtemp"])

    def test_showsclk_maps_to_showhcuclocks(self):
        # showsclk carries a hy_smi_flag override, so it must emit
        # --showhcuclocks rather than a literal --showsclk.
        produced = self._args(["showsclk"])
        assert "--showhcuclocks" in produced
        assert "--showsclk" not in produced

    def test_wait_idle_injects_required_metrics(self):
        # wait-idle needs memuse + use even when only showtemp was asked for.
        produced = self._args(["showtemp"], wait_idle=True)
        assert "--showmemuse" in produced
        assert "--showuse" in produced

    def test_wait_idle_no_duplication(self):
        # Flags already requested must not be appended a second time.
        produced = self._args(["showmemuse", "showuse"], wait_idle=True)
        assert produced.count("--showmemuse") == 1
        assert produced.count("--showuse") == 1
class TestRenderColumnsForShowFlags:
    """Column construction for the render table."""

    def test_showtemp_gives_temp_column(self):
        labels = [column.label for column in render_columns_for_show_flags(["showtemp"])]
        assert "Temp" in labels

    def test_showuse_gives_avg_column(self):
        # showuse defines avg_label, yielding an instant + averaged pair.
        columns = render_columns_for_show_flags(["showuse"])
        assert any(column.avg_label is not None for column in columns)

    def test_order_matches_input(self):
        columns = render_columns_for_show_flags(["showuse", "showtemp"])
        labels = [column.label for column in columns]
        # The GPU% columns (from showuse) must precede Temp (from showtemp).
        gpu_position = next(pos for pos, text in enumerate(labels) if text == "GPU%")
        temp_position = next(pos for pos, text in enumerate(labels) if text == "Temp")
        assert gpu_position < temp_position
"""Tests for hytop.gpu.parser."""
from __future__ import annotations
import json
import pytest
from hytop.gpu.parser import parse_hy_smi_output, parse_number, strip_ansi
# ---------------------------------------------------------------------------
# Real hy-smi JSON fixture (from actual 8-card Hygon DCU node)
# ---------------------------------------------------------------------------
# Full-flag output (--showtemp --showpower --showhcuclocks --showmemuse --showuse --json)
# Representative cards: card0 (idle) and card7 (100% HCU load)
# Cards 0 (idle) and 7 (100% HCU load) from a full-flag run; the other six
# cards are omitted for brevity. Tests assert on the "core" sensor value,
# so the parser is expected to select that one among the four sensors.
HY_SMI_FULL = {
    "card0": {
        "Average Graphics Package Power (W)": "157.0",
        "Temperature (Sensor edge) (C)": "31.0",
        "Temperature (Sensor junction) (C)": "34.0",
        "Temperature (Sensor mem) (C)": "28.0",
        "Temperature (Sensor core) (C)": "30.0",
        "HCU use (%)": "0.0",
        "HCU memory use (%)": "89",  # integer string — parser must accept it
        "sclk clock level": "10",
        "sclk clock speed": "1500Mhz",  # value carries a unit suffix
    },
    "card7": {
        "Average Graphics Package Power (W)": "141.0",
        "Temperature (Sensor edge) (C)": "28.0",
        "Temperature (Sensor junction) (C)": "33.0",
        "Temperature (Sensor mem) (C)": "25.0",
        "Temperature (Sensor core) (C)": "25.0",
        "HCU use (%)": "100.0",
        "HCU memory use (%)": "89",
        "sclk clock level": "10",
        "sclk clock speed": "1500Mhz",
    },
}
# Temp-only output (--showtemp --json): extra sensor keys should be ignored
# and unrelated metric fields (power, HCU use, ...) must stay unset.
HY_SMI_TEMP_ONLY = {
    "card0": {
        "Temperature (Sensor edge) (C)": "28.0",
        "Temperature (Sensor junction) (C)": "31.0",
        "Temperature (Sensor mem) (C)": "25.0",
        "Temperature (Sensor core) (C)": "26.0",
    },
}
# ---------------------------------------------------------------------------
# strip_ansi
# ---------------------------------------------------------------------------
class TestStripAnsi:
    """ANSI escape-sequence removal."""

    def test_plain_text_unchanged(self):
        assert strip_ansi("hello") == "hello"

    def test_color_codes_removed(self):
        colored = "\x1b[31mred\x1b[0m"
        assert strip_ansi(colored) == "red"

    def test_empty_string(self):
        assert strip_ansi("") == ""

    def test_multiple_codes(self):
        # Consecutive escape sequences must all be removed.
        decorated = "\x1b[1m\x1b[4mbold\x1b[0m"
        assert strip_ansi(decorated) == "bold"
# ---------------------------------------------------------------------------
# parse_number
# ---------------------------------------------------------------------------
class TestParseNumber:
    """Extracting a float from hy-smi value strings."""

    def test_integer_string(self):
        parsed = parse_number("89")
        assert parsed == pytest.approx(89.0)

    def test_float_string(self):
        parsed = parse_number("157.0")
        assert parsed == pytest.approx(157.0)

    def test_value_with_unit_suffix(self):
        # "1500Mhz" is the real sclk clock speed format emitted by hy-smi.
        parsed = parse_number("1500Mhz")
        assert parsed == pytest.approx(1500.0)

    def test_no_number_raises(self):
        with pytest.raises(ValueError, match="cannot parse"):
            parse_number("N/A")

    def test_negative_number(self):
        parsed = parse_number("-5.5")
        assert parsed == pytest.approx(-5.5)
# ---------------------------------------------------------------------------
# parse_hy_smi_output — with real fixture data
# ---------------------------------------------------------------------------
class TestParseHySmiOutput:
    """End-to-end parsing of hy-smi JSON payloads (real fixture data)."""

    @staticmethod
    def _parse(payload, sample_ts=1.0):
        # Serialize a fixture dict exactly as hy-smi would emit it.
        return parse_hy_smi_output(json.dumps(payload), sample_ts=sample_ts)

    def test_full_output_card_count(self):
        parsed = self._parse(HY_SMI_FULL)
        assert set(parsed.keys()) == {0, 7}

    def test_full_output_card0_metrics(self):
        sample = self._parse(HY_SMI_FULL)[0]
        assert sample.temp_c == pytest.approx(30.0)
        assert sample.avg_pwr_w == pytest.approx(157.0)
        assert sample.hcu_pct == pytest.approx(0.0)
        # Integer string "89" must come through as 89.0.
        assert sample.vram_pct == pytest.approx(89.0)
        # "1500Mhz" must come through as 1500.0.
        assert sample.sclk_mhz == pytest.approx(1500.0)

    def test_full_output_card7_hcu_load(self):
        parsed = self._parse(HY_SMI_FULL)
        assert parsed[7].hcu_pct == pytest.approx(100.0)

    def test_temp_only_output(self):
        sample = self._parse(HY_SMI_TEMP_ONLY)[0]
        assert sample.temp_c == pytest.approx(26.0)
        # Sensor keys unrelated to other metrics must not populate fields.
        assert sample.avg_pwr_w is None
        assert sample.hcu_pct is None

    def test_sample_ts_propagated(self):
        parsed = self._parse(HY_SMI_FULL, sample_ts=42.5)
        assert parsed[0].ts == pytest.approx(42.5)

    def test_unknown_card_keys_ignored(self):
        payload = {"sys_info": {"foo": "bar"}, "card0": HY_SMI_FULL["card0"]}
        parsed = self._parse(payload)
        assert list(parsed.keys()) == [0]

    def test_empty_string_returns_empty(self):
        assert parse_hy_smi_output("", sample_ts=1.0) == {}

    def test_invalid_json_returns_empty(self):
        assert parse_hy_smi_output("not json", sample_ts=1.0) == {}

    def test_ansi_stripped_before_parse(self):
        # Some hy-smi versions colorize output; the parser must strip ANSI
        # escapes before attempting to decode the JSON.
        wrapped = "\x1b[0m" + json.dumps(HY_SMI_TEMP_ONLY) + "\x1b[0m"
        parsed = parse_hy_smi_output(wrapped, sample_ts=1.0)
        assert 0 in parsed
"""Tests for hytop.gpu.render formatting helpers."""
from __future__ import annotations
from hytop.gpu.render import _format_metric, fmt_elapsed, fmt_window
class TestFmtWindow:
    """Window-duration rendering for the table title."""

    def test_integer_seconds(self):
        rendered = fmt_window(5.0)
        assert rendered == "5s"

    def test_fractional_seconds(self):
        rendered = fmt_window(0.5)
        assert rendered == "0.5s"

    def test_large_integer(self):
        rendered = fmt_window(300.0)
        assert rendered == "300s"

    def test_non_round_float(self):
        # Non-round values render with a single decimal place.
        rendered = fmt_window(1.25)
        assert rendered == "1.2s"
class TestFmtElapsed:
    """Elapsed-time rendering in HH:MM:SS form."""

    def test_zero(self):
        rendered = fmt_elapsed(0)
        assert rendered == "00:00:00"

    def test_seconds_only(self):
        rendered = fmt_elapsed(45)
        assert rendered == "00:00:45"

    def test_minutes_and_seconds(self):
        rendered = fmt_elapsed(90)
        assert rendered == "00:01:30"

    def test_hours(self):
        rendered = fmt_elapsed(3661)
        assert rendered == "01:01:01"

    def test_negative_clamped_to_zero(self):
        # Negative input must clamp rather than render a negative duration.
        rendered = fmt_elapsed(-5)
        assert rendered == "00:00:00"
class TestFormatMetric:
    """Per-metric value rendering (units and precision)."""

    def test_none_returns_dash(self):
        assert _format_metric("temp_c", None) == "-"

    def test_temp_format(self):
        text = _format_metric("temp_c", 30.0)
        for fragment in ("30.0", "C"):
            assert fragment in text

    def test_power_format(self):
        text = _format_metric("avg_pwr_w", 157.0)
        for fragment in ("157.0", "W"):
            assert fragment in text

    def test_pct_format_vram(self):
        text = _format_metric("vram_pct", 89.0)
        for fragment in ("89.00", "%"):
            assert fragment in text

    def test_pct_format_hcu(self):
        text = _format_metric("hcu_pct", 0.0)
        for fragment in ("0.00", "%"):
            assert fragment in text

    def test_sclk_format(self):
        text = _format_metric("sclk_mhz", 1500.0)
        for fragment in ("1500", "MHz"):
            assert fragment in text

    def test_unknown_metric_str(self):
        # Unknown metric keys fall back to a plain str() of the value.
        assert _format_metric("unknown", 42) == "42"
"""Tests for hytop.gpu.service — business logic (collect_from_host mocked)."""
from __future__ import annotations
import json
import time
from unittest.mock import patch
from hytop.core.history import SlidingHistory
from hytop.core.ssh import CollectResult
from hytop.gpu.models import MonitorState, NodeResult, Sample
from hytop.gpu.service import (
apply_node_results,
availability_ready,
collect_node,
drain_pending_nodes,
init_monitor_state,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Pre-serialized hy-smi JSON (cards 0 and 7 only) fed to collect_node via a
# mocked collect_from_host; a trimmed variant of the parser-test fixture.
HY_SMI_FULL_JSON = json.dumps(
    {
        "card0": {
            "Average Graphics Package Power (W)": "157.0",
            "Temperature (Sensor core) (C)": "30.0",
            "HCU use (%)": "0.0",
            "HCU memory use (%)": "89",
            "sclk clock speed": "1500Mhz",
        },
        "card7": {
            "Average Graphics Package Power (W)": "141.0",
            "Temperature (Sensor core) (C)": "25.0",
            "HCU use (%)": "100.0",
            "HCU memory use (%)": "89",
            "sclk clock speed": "1500Mhz",
        },
    }
)
def _state(hosts=("localhost",), device_filter=None, max_window=10.0) -> MonitorState:
    """Build a fresh MonitorState with test-friendly defaults."""
    # init_monitor_state expects a list; the default is a tuple so it can
    # safely be a (immutable) default argument.
    host_list = list(hosts)
    return init_monitor_state(
        hosts=host_list,
        device_filter=device_filter,
        max_window=max_window,
    )
def _sample(ts: float, hcu_pct: float = 0.0, vram_pct: float = 0.0) -> Sample:
    """Build a Sample carrying only the fields these tests read."""
    return Sample(
        ts=ts,
        hcu_pct=hcu_pct,
        vram_pct=vram_pct,
    )
# ---------------------------------------------------------------------------
# collect_node
# ---------------------------------------------------------------------------
class TestCollectNode:
    """Behavior of collect_node with the host-collection layer mocked out."""

    @patch("hytop.gpu.service.collect_from_host")
    def test_success_returns_samples(self, fake_collect):
        """A well-formed hy-smi payload yields samples for cards 0 and 7."""
        fake_collect.return_value = CollectResult(
            host="localhost",
            stdout=HY_SMI_FULL_JSON,
            stderr="",
        )
        outcome = collect_node(
            "localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]
        )
        assert outcome.error is None
        assert set(outcome.samples) == {0, 7}

    @patch("hytop.gpu.service.collect_from_host")
    def test_host_error_propagated(self, fake_collect):
        """An error reported by the collector surfaces on the node result."""
        fake_collect.return_value = CollectResult(
            host="node01", stdout="", stderr="", error="timeout after 10.0s"
        )
        outcome = collect_node(
            "node01", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]
        )
        assert outcome.error is not None
        assert "timeout" in outcome.error

    @patch("hytop.gpu.service.collect_from_host")
    def test_empty_output_yields_error(self, fake_collect):
        """Empty stdout is treated as a collection failure, not silence."""
        fake_collect.return_value = CollectResult(host="localhost", stdout="", stderr="")
        outcome = collect_node(
            "localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]
        )
        assert outcome.error is not None
# ---------------------------------------------------------------------------
# apply_node_results
# ---------------------------------------------------------------------------
class TestApplyNodeResults:
    """apply_node_results: merging per-node samples into monitor state."""

    def test_successful_node_adds_to_history(self):
        """A sample for (host, gpu) creates a history entry for that key."""
        mon = _state(hosts=["localhost"])
        node = NodeResult(host="localhost", samples={0: _sample(ts=1.0, hcu_pct=50.0)})
        apply_node_results([node], device_filter=None, state=mon)
        assert ("localhost", 0) in mon.histories

    def test_error_node_sets_error(self):
        """A failed node records its error string under its hostname."""
        mon = _state(hosts=["localhost"])
        failed = NodeResult(host="localhost", samples={}, error="connection refused")
        apply_node_results([failed], device_filter=None, state=mon)
        assert mon.errors["localhost"] == "connection refused"

    def test_device_filter_excludes_other_gpus(self):
        """Only GPUs named in the device filter are tracked."""
        mon = _state(hosts=["localhost"], device_filter={0})
        node = NodeResult(
            host="localhost", samples={0: _sample(ts=1.0), 1: _sample(ts=1.0)}
        )
        apply_node_results([node], device_filter={0}, state=mon)
        assert ("localhost", 0) in mon.histories
        assert ("localhost", 1) not in mon.histories

    def test_success_clears_previous_error(self):
        """A successful collection removes a stale error for the host."""
        mon = _state(hosts=["localhost"])
        mon.errors["localhost"] = "old error"
        node = NodeResult(host="localhost", samples={0: _sample(ts=1.0)})
        apply_node_results([node], device_filter=None, state=mon)
        assert "localhost" not in mon.errors

    def test_duplicate_sample_not_added(self):
        """Re-applying a result with an identical timestamp is a no-op."""
        mon = _state(hosts=["localhost"])
        node = NodeResult(host="localhost", samples={0: _sample(ts=5.0)})
        for _ in range(2):  # second pass re-sends the same ts
            apply_node_results([node], device_filter=None, state=mon)
        assert len(mon.histories[("localhost", 0)].samples) == 1
# ---------------------------------------------------------------------------
# availability_ready
# ---------------------------------------------------------------------------
class TestAvailabilityReady:
    """availability_ready: window-based idle detection across monitored GPUs."""

    def _make_history(self, hcu_pct: float, vram_pct: float) -> SlidingHistory:
        """SlidingHistory holding one sample stamped with the current monotonic clock."""
        history = SlidingHistory(max_window_s=30)
        now = time.monotonic()
        history.add(_sample(ts=now, hcu_pct=hcu_pct, vram_pct=vram_pct))
        return history

    def test_idle_gpu_returns_true(self):
        gpu = ("localhost", 0)
        ready = availability_ready(
            window=5.0,
            histories={gpu: self._make_history(hcu_pct=0.0, vram_pct=0.0)},
            monitored_keys={gpu},
            hosts=["localhost"],
            errors={},
        )
        assert ready

    def test_busy_gpu_returns_false(self):
        gpu = ("localhost", 0)
        ready = availability_ready(
            window=5.0,
            histories={gpu: self._make_history(hcu_pct=100.0, vram_pct=89.0)},
            monitored_keys={gpu},
            hosts=["localhost"],
            errors={},
        )
        assert not ready

    def test_host_error_returns_false(self):
        """An idle GPU on an erroring host is not considered available."""
        gpu = ("localhost", 0)
        ready = availability_ready(
            window=5.0,
            histories={gpu: self._make_history(hcu_pct=0.0, vram_pct=0.0)},
            monitored_keys={gpu},
            hosts=["localhost"],
            errors={"localhost": "connection refused"},
        )
        assert not ready

    def test_empty_monitored_keys_returns_false(self):
        ready = availability_ready(
            window=5.0,
            histories={},
            monitored_keys=set(),
            hosts=["localhost"],
            errors={},
        )
        assert not ready

    def test_missing_history_returns_false(self):
        """A monitored key with no recorded samples cannot be declared ready."""
        gpu = ("localhost", 0)
        ready = availability_ready(
            window=5.0,
            histories={},  # nothing collected for this key yet
            monitored_keys={gpu},
            hosts=["localhost"],
            errors={},
        )
        assert not ready
# ---------------------------------------------------------------------------
# drain_pending_nodes
# ---------------------------------------------------------------------------
class TestDrainPendingNodes:
    """drain_pending_nodes: sequence-number-gated draining of host results."""

    def test_drains_new_result(self):
        """A result with a seq the state has not processed yet is returned."""
        mon = _state(hosts=["localhost"])
        fresh = NodeResult(host="localhost", samples={})
        with mon.state_lock:
            entry = mon.host_state["localhost"]
            entry.seq = 1
            entry.result = fresh
        drained = drain_pending_nodes(["localhost"], mon)
        assert len(drained) == 1
        assert drained[0] is fresh

    def test_does_not_drain_already_processed(self):
        """A seq already recorded as processed is skipped."""
        mon = _state(hosts=["localhost"])
        mon.processed_seq["localhost"] = 1  # seq 1 already consumed
        with mon.state_lock:
            entry = mon.host_state["localhost"]
            entry.seq = 1
            entry.result = NodeResult(host="localhost", samples={})
        assert drain_pending_nodes(["localhost"], mon) == []
"""Tests for hytop.core.ssh.collect_from_host (subprocess mocked)."""
from __future__ import annotations
import subprocess
from unittest.mock import MagicMock, patch
from hytop.core.ssh import collect_from_host
def _make_proc(returncode=0, stdout="", stderr=""):
m = MagicMock()
m.returncode = returncode
m.stdout = stdout
m.stderr = stderr
return m
class TestCollectFromHostLocal:
    """Local-host paths of collect_from_host: hy-smi runs without an ssh wrapper."""

    @patch("hytop.core.ssh.subprocess.run")
    def test_success_returns_no_error(self, fake_run):
        fake_run.return_value = _make_proc(stdout='{"card0":{}}')
        outcome = collect_from_host(
            "localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]
        )
        assert outcome.error is None
        assert outcome.host == "localhost"

    @patch("hytop.core.ssh.subprocess.run")
    def test_local_invokes_hy_smi_directly(self, fake_run):
        fake_run.return_value = _make_proc()
        collect_from_host("localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
        argv = fake_run.call_args[0][0]
        assert argv[0] == "hy-smi"
        assert "ssh" not in argv

    @patch("hytop.core.ssh.subprocess.run")
    def test_127_0_0_1_treated_as_local(self, fake_run):
        fake_run.return_value = _make_proc()
        collect_from_host("127.0.0.1", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
        argv = fake_run.call_args[0][0]
        assert argv[0] == "hy-smi"

    @patch("hytop.core.ssh.subprocess.run")
    def test_nonzero_exit_returns_error(self, fake_run):
        fake_run.return_value = _make_proc(returncode=1, stderr="permission denied")
        outcome = collect_from_host(
            "localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]
        )
        assert outcome.error is not None
        assert "exit 1" in outcome.error

    @patch(
        "hytop.core.ssh.subprocess.run",
        side_effect=subprocess.TimeoutExpired("cmd", 10),
    )
    def test_timeout_returns_error(self, fake_run):
        outcome = collect_from_host(
            "localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]
        )
        assert outcome.error is not None
        assert "timeout" in outcome.error

    @patch("hytop.core.ssh.subprocess.run", side_effect=OSError("no such file"))
    def test_oserror_returns_error(self, fake_run):
        outcome = collect_from_host(
            "localhost", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"]
        )
        assert outcome.error is not None
        assert "no such file" in outcome.error
class TestCollectFromHostRemote:
    """Remote-host paths: the hy-smi invocation must be wrapped in ssh."""

    @patch("hytop.core.ssh.subprocess.run")
    def test_remote_uses_ssh(self, fake_run):
        fake_run.return_value = _make_proc(stdout="{}")
        collect_from_host("node01", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
        argv = fake_run.call_args[0][0]
        assert argv[0] == "ssh"

    @patch("hytop.core.ssh.subprocess.run")
    def test_remote_hostname_in_cmd(self, fake_run):
        fake_run.return_value = _make_proc(stdout="{}")
        collect_from_host("node01", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
        assert "node01" in fake_run.call_args[0][0]

    @patch("hytop.core.ssh.subprocess.run")
    def test_remote_batch_mode_set(self, fake_run):
        """Non-interactive ssh: BatchMode must be forced on."""
        fake_run.return_value = _make_proc(stdout="{}")
        collect_from_host("node01", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
        assert "BatchMode=yes" in fake_run.call_args[0][0]

    @patch("hytop.core.ssh.subprocess.run")
    def test_hy_smi_args_forwarded(self, fake_run):
        fake_run.return_value = _make_proc(stdout="{}")
        collect_from_host(
            "node01",
            ssh_timeout=5,
            cmd_timeout=10,
            hy_smi_args=["--json", "--showtemp"],
        )
        argv = fake_run.call_args[0][0]
        assert "--json" in argv
        assert "--showtemp" in argv
"""Tests for hytop.core.validators."""
from __future__ import annotations
import pytest
from hytop.core.validators import (
parse_csv_ints,
parse_csv_strings,
parse_positive_float,
)
# ---------------------------------------------------------------------------
# parse_csv_ints
# ---------------------------------------------------------------------------
class TestParseCsvInts:
    """parse_csv_ints: comma-separated lists of non-negative integers."""

    def test_single(self):
        assert parse_csv_ints("0", "--devices") == [0]

    def test_multiple(self):
        assert parse_csv_ints("0,1,2", "--devices") == list(range(3))

    def test_whitespace_trimmed(self):
        expected = [0, 1]
        assert parse_csv_ints(" 0 , 1 ", "--devices") == expected

    def test_empty_string_raises(self):
        with pytest.raises(ValueError, match="cannot be empty"):
            parse_csv_ints("", "--devices")

    def test_only_commas_raises(self):
        """A string of separators with no values counts as empty."""
        with pytest.raises(ValueError, match="cannot be empty"):
            parse_csv_ints(",,,", "--devices")

    def test_non_integer_raises(self):
        with pytest.raises(ValueError, match="non-integer"):
            parse_csv_ints("0,abc", "--devices")

    def test_negative_rejected(self):
        # "-1" fails str.isdigit(), so it is reported as non-integer
        with pytest.raises(ValueError, match="non-integer"):
            parse_csv_ints("-1", "--devices")

    def test_flag_in_error_message(self):
        """The offending flag name is echoed back in the error text."""
        with pytest.raises(ValueError, match="--devices"):
            parse_csv_ints("x", "--devices")
# ---------------------------------------------------------------------------
# parse_csv_strings
# ---------------------------------------------------------------------------
class TestParseCsvStrings:
    """parse_csv_strings: comma-separated lists of non-empty strings."""

    def test_single(self):
        assert parse_csv_strings("localhost", "--hosts") == ["localhost"]

    def test_multiple(self):
        expected = ["node01", "node02"]
        assert parse_csv_strings("node01,node02", "--hosts") == expected

    def test_whitespace_trimmed(self):
        expected = ["node01", "node02"]
        assert parse_csv_strings(" node01 , node02 ", "--hosts") == expected

    def test_empty_string_raises(self):
        with pytest.raises(ValueError, match="cannot be empty"):
            parse_csv_strings("", "--hosts")

    def test_only_commas_raises(self):
        """A lone separator with no values counts as empty."""
        with pytest.raises(ValueError, match="cannot be empty"):
            parse_csv_strings(",", "--hosts")

    def test_flag_in_error_message(self):
        """The offending flag name is echoed back in the error text."""
        with pytest.raises(ValueError, match="--hosts"):
            parse_csv_strings("", "--hosts")
# ---------------------------------------------------------------------------
# parse_positive_float
# ---------------------------------------------------------------------------
class TestParsePositiveFloat:
    """parse_positive_float: strictly positive numeric values only."""

    def test_integer_string(self):
        assert parse_positive_float("1", "--window") == 1.0

    def test_float_string(self):
        assert parse_positive_float("0.5", "--window") == pytest.approx(0.5)

    def test_zero_raises(self):
        """Zero is not positive and must be rejected."""
        with pytest.raises(ValueError, match="positive"):
            parse_positive_float("0", "--window")

    def test_negative_raises(self):
        with pytest.raises(ValueError, match="positive"):
            parse_positive_float("-1", "--window")

    def test_non_numeric_raises(self):
        with pytest.raises(ValueError, match="non-numeric"):
            parse_positive_float("abc", "--window")

    def test_flag_in_error_message(self):
        """The offending flag name is echoed back in the error text."""
        with pytest.raises(ValueError, match="--window"):
            parse_positive_float("0", "--window")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment