Commit 7fa18525 authored by one's avatar one
Browse files

Update hytop-gpu collectors

parent 24bf8df9
...@@ -7,7 +7,7 @@ uv pip install -e . ...@@ -7,7 +7,7 @@ uv pip install -e .
hytop gpu --help hytop gpu --help
``` ```
## Prerequesites ## Prerequisites
- Python >= 3.10 - Python >= 3.10
- Python packages: `rich`, `typer` - Python packages: `rich`, `typer`
...@@ -33,6 +33,10 @@ hytop gpu --devices 0,1 --wait-idle ...@@ -33,6 +33,10 @@ hytop gpu --devices 0,1 --wait-idle
# Wait at most 300s for availability (exit 124 on timeout) # Wait at most 300s for availability (exit 124 on timeout)
hytop gpu --devices 0,1 --wait-idle --timeout 300 hytop gpu --devices 0,1 --wait-idle --timeout 300
# Fine-grained columns (output order follows show-flag order)
hytop gpu --showtemp --showpower
hytop gpu --showpower --showtemp
``` ```
Queue jobs in shared environments: Queue jobs in shared environments:
...@@ -56,6 +60,19 @@ Designed to be script-friendly: ...@@ -56,6 +60,19 @@ Designed to be script-friendly:
* `130`: Interrupted by the user (Ctrl+C). * `130`: Interrupted by the user (Ctrl+C).
* `2`: Argument or input error. * `2`: Argument or input error.
### Fine-grained metric flags
`hytop gpu` uses formatted `hy-smi --json` output and supports a subset of `hy-smi` `--show*` flags:
- `--showtemp`: GPU core temperature (`Temp`)
- `--showpower`: average package power (`AvgPwr`, plus `AvgPwr@window`)
- `--showhcuclocks`: sclk frequency (`sclk`)
- `--showmemuse`: VRAM usage (`VRAM%`)
- `--showuse`: GPU utilization (`GPU%`, plus `GPU%@window`)
If no `--show*` flags are specified, hytop defaults to:
`--showtemp --showpower --showhcuclocks --showmemuse --showuse`.
## Development ## Development
### Version bump ### Version bump
......
...@@ -58,7 +58,13 @@ class SlidingHistory: ...@@ -58,7 +58,13 @@ class SlidingHistory:
if not self.samples: if not self.samples:
return 0.0 return 0.0
cutoff = now - window_s cutoff = now - window_s
values = [getattr(s, metric) for s in self.samples if s.ts >= cutoff] values = [
value
for s in self.samples
if s.ts >= cutoff
for value in [getattr(s, metric)]
if isinstance(value, (int, float))
]
if not values: if not values:
return 0.0 return 0.0
return float(sum(values) / len(values)) return float(sum(values) / len(values))
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
import subprocess import subprocess
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence
@dataclass @dataclass
...@@ -21,7 +22,12 @@ class CollectResult: ...@@ -21,7 +22,12 @@ class CollectResult:
error: str | None = None error: str | None = None
def collect_from_host(host: str, ssh_timeout: float, cmd_timeout: float) -> CollectResult: def collect_from_host(
host: str,
ssh_timeout: float,
cmd_timeout: float,
hy_smi_args: Sequence[str],
) -> CollectResult:
"""Run hy-smi locally or via SSH and return raw output. """Run hy-smi locally or via SSH and return raw output.
Args: Args:
...@@ -34,7 +40,7 @@ def collect_from_host(host: str, ssh_timeout: float, cmd_timeout: float) -> Coll ...@@ -34,7 +40,7 @@ def collect_from_host(host: str, ssh_timeout: float, cmd_timeout: float) -> Coll
""" """
local_names = {"localhost", "127.0.0.1", "::1"} local_names = {"localhost", "127.0.0.1", "::1"}
if host in local_names: if host in local_names:
cmd = ["hy-smi"] cmd = ["hy-smi", *hy_smi_args]
else: else:
connect_timeout = max(1, int(round(ssh_timeout))) connect_timeout = max(1, int(round(ssh_timeout)))
cmd = [ cmd = [
...@@ -45,6 +51,7 @@ def collect_from_host(host: str, ssh_timeout: float, cmd_timeout: float) -> Coll ...@@ -45,6 +51,7 @@ def collect_from_host(host: str, ssh_timeout: float, cmd_timeout: float) -> Coll
f"ConnectTimeout={connect_timeout}", f"ConnectTimeout={connect_timeout}",
host, host,
"hy-smi", "hy-smi",
*hy_smi_args,
] ]
try: try:
......
...@@ -5,6 +5,7 @@ from typing import Optional, Set ...@@ -5,6 +5,7 @@ from typing import Optional, Set
import typer import typer
from hytop import __version__ from hytop import __version__
from hytop.gpu.metrics import SUPPORTED_SHOW_FLAGS, normalized_show_flags
from hytop.gpu.service import run_monitor from hytop.gpu.service import run_monitor
from hytop.gpu.validators import parse_csv_ints, parse_csv_strings, parse_positive_float from hytop.gpu.validators import parse_csv_ints, parse_csv_strings, parse_positive_float
...@@ -13,6 +14,8 @@ app = typer.Typer( ...@@ -13,6 +14,8 @@ app = typer.Typer(
context_settings={"help_option_names": ["-h", "--help"]}, context_settings={"help_option_names": ["-h", "--help"]},
) )
SHOW_FLAG_ORDER_KEY = "show_flag_order"
def version_callback(value: bool) -> None: def version_callback(value: bool) -> None:
"""Handle Typer eager version option. """Handle Typer eager version option.
...@@ -29,8 +32,23 @@ def version_callback(value: bool) -> None: ...@@ -29,8 +32,23 @@ def version_callback(value: bool) -> None:
raise typer.Exit() raise typer.Exit()
def remember_show_flag_callback(ctx: typer.Context, param: object, value: bool) -> bool:
    """Track the order in which supported --show* flags appear on the CLI.

    Typer invokes this callback per option; each newly seen supported flag
    name is appended to ``ctx.meta`` so output columns can later follow the
    user's flag order. The option value is always passed through unchanged.
    """
    if value:
        name = getattr(param, "name", None)
        if isinstance(name, str) and name in SUPPORTED_SHOW_FLAGS:
            seen = ctx.meta.setdefault(SHOW_FLAG_ORDER_KEY, [])
            if name not in seen:
                seen.append(name)
    return value
@app.callback(invoke_without_command=True) @app.callback(invoke_without_command=True)
def gpu( def gpu(
ctx: typer.Context,
hosts: str = typer.Option( hosts: str = typer.Option(
"localhost", "localhost",
"--hosts", "--hosts",
...@@ -59,6 +77,36 @@ def gpu( ...@@ -59,6 +77,36 @@ def gpu(
"--wait-idle", "--wait-idle",
help="Exit 0 when all monitored GPUs have zero VRAM/HCU avg in the configured window.", help="Exit 0 when all monitored GPUs have zero VRAM/HCU avg in the configured window.",
), ),
showtemp: bool = typer.Option(
False,
"--showtemp",
callback=remember_show_flag_callback,
help="Display GPU core temperature.",
),
showpower: bool = typer.Option(
False,
"--showpower",
callback=remember_show_flag_callback,
help="Display average GPU power.",
),
showhcuclocks: bool = typer.Option(
False,
"--showhcuclocks",
callback=remember_show_flag_callback,
help="Display GPU sclk frequency.",
),
showmemuse: bool = typer.Option(
False,
"--showmemuse",
callback=remember_show_flag_callback,
help="Display GPU VRAM usage.",
),
showuse: bool = typer.Option(
False,
"--showuse",
callback=remember_show_flag_callback,
help="Display GPU utilization.",
),
timeout: Optional[float] = typer.Option( timeout: Optional[float] = typer.Option(
None, None,
"--timeout", "--timeout",
...@@ -77,6 +125,24 @@ def gpu( ...@@ -77,6 +125,24 @@ def gpu(
try: try:
host_list = parse_csv_strings(hosts, "--hosts") host_list = parse_csv_strings(hosts, "--hosts")
selected_show_flags = {
"showtemp": showtemp,
"showpower": showpower,
"showhcuclocks": showhcuclocks,
"showmemuse": showmemuse,
"showuse": showuse,
}
requested_order = [
flag
for flag in ctx.meta.get(SHOW_FLAG_ORDER_KEY, [])
if selected_show_flags.get(flag, False)
]
if requested_order:
show_flags = normalized_show_flags(requested_order)
else:
show_flags = normalized_show_flags(
[flag for flag, enabled in selected_show_flags.items() if enabled]
)
parsed_device_filter: Optional[Set[int]] = None parsed_device_filter: Optional[Set[int]] = None
if device_filter: if device_filter:
parsed_device_filter = set(parse_csv_ints(device_filter, "--devices")) parsed_device_filter = set(parse_csv_ints(device_filter, "--devices"))
...@@ -93,6 +159,7 @@ def gpu( ...@@ -93,6 +159,7 @@ def gpu(
code = run_monitor( code = run_monitor(
hosts=host_list, hosts=host_list,
device_filter=parsed_device_filter, device_filter=parsed_device_filter,
show_flags=show_flags,
window=window_value, window=window_value,
interval=interval, interval=interval,
wait_idle=wait_idle, wait_idle=wait_idle,
......
from __future__ import annotations
from dataclasses import dataclass
from typing import Final, Iterable
@dataclass(frozen=True)
class RenderColumn:
    """One display column derived from a metric.

    Attributes:
        label: Column header text.
        metric: Sample attribute name providing the value.
        avg_label: Header prefix for the windowed-average companion column,
            or None when no average column is rendered for this metric.
    """

    label: str
    metric: str
    avg_label: str | None = None


@dataclass(frozen=True)
class ShowSpec:
    """Binding of one hy-smi --show* flag to its metrics and columns.

    Attributes:
        flag: Flag name without the leading dashes (e.g. "showtemp").
        metric_json_keys: Sample attribute name -> hy-smi JSON key.
        columns: Display columns contributed when the flag is active.
    """

    flag: str
    metric_json_keys: dict[str, str]
    columns: tuple[RenderColumn, ...]


# One spec per supported hy-smi flag; tuple order defines default column order.
SHOW_SPECS: Final[tuple[ShowSpec, ...]] = tuple(
    ShowSpec(flag=flag, metric_json_keys=json_keys, columns=columns)
    for flag, json_keys, columns in (
        (
            "showtemp",
            {"temp_c": "Temperature (Sensor core) (C)"},
            (RenderColumn(label="Temp", metric="temp_c"),),
        ),
        (
            "showpower",
            {"avg_pwr_w": "Average Graphics Package Power (W)"},
            (RenderColumn(label="AvgPwr", metric="avg_pwr_w", avg_label="AvgPwr"),),
        ),
        (
            "showhcuclocks",
            {"sclk_mhz": "sclk clock speed"},
            (RenderColumn(label="sclk", metric="sclk_mhz"),),
        ),
        (
            "showmemuse",
            {"vram_pct": "HCU memory use (%)"},
            (RenderColumn(label="VRAM%", metric="vram_pct"),),
        ),
        (
            "showuse",
            {"hcu_pct": "HCU use (%)"},
            (RenderColumn(label="GPU%", metric="hcu_pct", avg_label="GPU%"),),
        ),
    )
)

# Derived lookups; dict insertion order mirrors SHOW_SPECS order.
SPEC_BY_FLAG: Final[dict[str, ShowSpec]] = {spec.flag: spec for spec in SHOW_SPECS}
SUPPORTED_SHOW_FLAGS: Final[tuple[str, ...]] = tuple(SPEC_BY_FLAG)
DEFAULT_SHOW_FLAGS: Final[tuple[str, ...]] = SUPPORTED_SHOW_FLAGS
JSON_KEY_BY_METRIC: Final[dict[str, str]] = {
    metric: json_key
    for spec in SHOW_SPECS
    for metric, json_key in spec.metric_json_keys.items()
}


def normalized_show_flags(show_flags: Iterable[str] | None) -> list[str]:
    """Return known flags deduplicated in first-seen order.

    Falls back to DEFAULT_SHOW_FLAGS when the input is missing, empty,
    or contains no supported flag at all.
    """
    if not show_flags:
        return list(DEFAULT_SHOW_FLAGS)
    known = [flag for flag in dict.fromkeys(show_flags) if flag in SPEC_BY_FLAG]
    return known if known else list(DEFAULT_SHOW_FLAGS)


def hy_smi_args_for_show_flags(show_flags: Iterable[str], wait_idle: bool) -> list[str]:
    """Build hy-smi arguments: JSON output plus the requested metric flags."""
    flags = normalized_show_flags(show_flags)
    if wait_idle:
        # Idle detection needs usage + memory metrics even when not displayed.
        flags += [extra for extra in ("showmemuse", "showuse") if extra not in flags]
    return ["--json", *(f"--{name}" for name in flags)]


def render_columns_for_show_flags(show_flags: Iterable[str]) -> list[RenderColumn]:
    """Resolve display columns from ordered show flags."""
    return [
        column
        for flag in normalized_show_flags(show_flags)
        for column in SPEC_BY_FLAG[flag].columns
    ]
...@@ -13,17 +13,19 @@ class Sample: ...@@ -13,17 +13,19 @@ class Sample:
Attributes: Attributes:
ts: Monotonic timestamp when the sample was captured. ts: Monotonic timestamp when the sample was captured.
temp_c: GPU temperature in Celsius. temp_c: GPU core temperature in Celsius.
avg_pwr_w: Average power draw in Watts. avg_pwr_w: Average power draw in Watts.
vram_pct: VRAM usage percentage. vram_pct: VRAM usage percentage.
hcu_pct: HCU usage percentage. hcu_pct: HCU usage percentage.
sclk_mhz: sclk frequency in MHz.
""" """
ts: float ts: float
temp_c: float temp_c: float | None = None
avg_pwr_w: float avg_pwr_w: float | None = None
vram_pct: float vram_pct: float | None = None
hcu_pct: float hcu_pct: float | None = None
sclk_mhz: float | None = None
@dataclass @dataclass
......
from __future__ import annotations from __future__ import annotations
import json
import re import re
from typing import Dict from typing import Dict
from hytop.gpu.metrics import JSON_KEY_BY_METRIC
from hytop.gpu.models import Sample from hytop.gpu.models import Sample
ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]") ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
CARD_KEY_RE = re.compile(r"^card(\d+)$")
def strip_ansi(text: str) -> str: def strip_ansi(text: str) -> str:
...@@ -41,7 +44,7 @@ def parse_number(text: str) -> float: ...@@ -41,7 +44,7 @@ def parse_number(text: str) -> float:
def parse_hy_smi_output(raw: str, sample_ts: float) -> Dict[int, Sample]: def parse_hy_smi_output(raw: str, sample_ts: float) -> Dict[int, Sample]:
"""Parse hy-smi stdout text into GPU keyed samples. """Parse hy-smi JSON output into GPU keyed samples.
Args: Args:
raw: Raw hy-smi stdout text. raw: Raw hy-smi stdout text.
...@@ -51,21 +54,33 @@ def parse_hy_smi_output(raw: str, sample_ts: float) -> Dict[int, Sample]: ...@@ -51,21 +54,33 @@ def parse_hy_smi_output(raw: str, sample_ts: float) -> Dict[int, Sample]:
Mapping from GPU id to parsed sample. Mapping from GPU id to parsed sample.
""" """
cleaned = strip_ansi(raw) cleaned = strip_ansi(raw).strip()
if not cleaned:
return {}
try:
payload = json.loads(cleaned)
except json.JSONDecodeError:
return {}
if not isinstance(payload, dict):
return {}
result: Dict[int, Sample] = {} result: Dict[int, Sample] = {}
for line in cleaned.splitlines(): for card_key, card_data in payload.items():
cols = line.strip().split() if not isinstance(card_key, str):
if len(cols) < 7 or not cols[0].isdigit():
continue continue
gpu_id = int(cols[0]) card_match = CARD_KEY_RE.match(card_key)
try: if card_match is None or not isinstance(card_data, dict):
result[gpu_id] = Sample(
ts=sample_ts,
temp_c=parse_number(cols[1]),
avg_pwr_w=parse_number(cols[2]),
vram_pct=parse_number(cols[5]),
hcu_pct=parse_number(cols[6]),
)
except (IndexError, ValueError):
continue continue
gpu_id = int(card_match.group(1))
sample = Sample(ts=sample_ts)
for metric_name, json_key in JSON_KEY_BY_METRIC.items():
raw_value = card_data.get(json_key)
if raw_value is None:
continue
try:
parsed_value = parse_number(str(raw_value))
except ValueError:
continue
setattr(sample, metric_name, parsed_value)
result[gpu_id] = sample
return result return result
...@@ -8,6 +8,7 @@ from rich.console import Group ...@@ -8,6 +8,7 @@ from rich.console import Group
from rich.table import Table from rich.table import Table
from hytop.core.history import SlidingHistory from hytop.core.history import SlidingHistory
from hytop.gpu.metrics import render_columns_for_show_flags
def fmt_window(window_s: float) -> str: def fmt_window(window_s: float) -> str:
...@@ -46,6 +47,7 @@ def build_renderable( ...@@ -46,6 +47,7 @@ def build_renderable(
histories: Dict[Tuple[str, int], SlidingHistory], histories: Dict[Tuple[str, int], SlidingHistory],
monitored_keys: Iterable[Tuple[str, int]], monitored_keys: Iterable[Tuple[str, int]],
errors: Dict[str, str], errors: Dict[str, str],
show_flags: Iterable[str],
poll_interval: float, poll_interval: float,
elapsed_since_start: float, elapsed_since_start: float,
) -> Group: ) -> Group:
...@@ -65,7 +67,10 @@ def build_renderable( ...@@ -65,7 +67,10 @@ def build_renderable(
""" """
now = time.monotonic() now = time.monotonic()
key_list = sorted(monitored_keys, key=lambda x: (hosts.index(x[0]), x[1])) host_rank = {host: idx for idx, host in enumerate(hosts)}
key_list = sorted(
monitored_keys, key=lambda x: (host_rank.get(x[0], len(hosts)), x[1])
)
table = Table( table = Table(
title=f"hy-smi monitor | interval={poll_interval:.2f}s | elapsed={fmt_elapsed(elapsed_since_start)}", title=f"hy-smi monitor | interval={poll_interval:.2f}s | elapsed={fmt_elapsed(elapsed_since_start)}",
box=box.MINIMAL_HEAVY_HEAD, box=box.MINIMAL_HEAVY_HEAD,
...@@ -73,14 +78,11 @@ def build_renderable( ...@@ -73,14 +78,11 @@ def build_renderable(
) )
table.add_column("Host", justify="left", no_wrap=True) table.add_column("Host", justify="left", no_wrap=True)
table.add_column("GPU", justify="right") table.add_column("GPU", justify="right")
table.add_column("Temp", justify="right") columns = render_columns_for_show_flags(show_flags)
table.add_column(f"Temp@{fmt_window(window)}", justify="right") for col in columns:
table.add_column("AvgPwr", justify="right") table.add_column(col.label, justify="right")
table.add_column(f"AvgPwr@{fmt_window(window)}", justify="right") if col.avg_label is not None:
table.add_column("VRAM%", justify="right") table.add_column(f"{col.avg_label}@{fmt_window(window)}", justify="right")
table.add_column(f"VRAM%@{fmt_window(window)}", justify="right")
table.add_column("HCU%", justify="right")
table.add_column(f"HCU%@{fmt_window(window)}", justify="right")
for key in key_list: for key in key_list:
history = histories.get(key) history = histories.get(key)
...@@ -92,19 +94,23 @@ def build_renderable( ...@@ -92,19 +94,23 @@ def build_renderable(
host, gpu = key host, gpu = key
stale = (now - latest.ts) > window stale = (now - latest.ts) > window
if stale: if stale:
table.add_row(host, str(gpu), "-", "-", "-", "-", "-", "-", "-", "-") table.add_row(host, str(gpu), *["-"] * (len(table.columns) - 2))
continue continue
values: list[str] = []
for col in columns:
metric_value = getattr(latest, col.metric, None)
values.append(_format_metric(col.metric, metric_value))
if col.avg_label is not None:
if metric_value is None:
values.append("-")
else:
values.append(
_format_metric(col.metric, history.avg(col.metric, window, now))
)
table.add_row( table.add_row(
host, host,
str(gpu), str(gpu),
f"{latest.temp_c:7.1f}C", *values,
f"{history.avg('temp_c', window, now):7.1f}C",
f"{latest.avg_pwr_w:8.1f}W",
f"{history.avg('avg_pwr_w', window, now):8.1f}W",
f"{latest.vram_pct:7.2f}%",
f"{history.avg('vram_pct', window, now):7.2f}%",
f"{latest.hcu_pct:7.2f}%",
f"{history.avg('hcu_pct', window, now):7.2f}%",
) )
if table.row_count == 0: if table.row_count == 0:
...@@ -120,3 +126,17 @@ def build_renderable( ...@@ -120,3 +126,17 @@ def build_renderable(
if err: if err:
err_table.add_row(host, err) err_table.add_row(host, err)
return Group(table, err_table) return Group(table, err_table)
def _format_metric(metric: str, value: object) -> str:
if value is None:
return "-"
if metric == "temp_c":
return f"{float(value):7.1f}C"
if metric == "avg_pwr_w":
return f"{float(value):8.1f}W"
if metric in {"vram_pct", "hcu_pct"}:
return f"{float(value):7.2f}%"
if metric == "sclk_mhz":
return f"{float(value):7.0f}MHz"
return str(value)
...@@ -3,19 +3,25 @@ from __future__ import annotations ...@@ -3,19 +3,25 @@ from __future__ import annotations
import sys import sys
import threading import threading
import time import time
from typing import List, Optional, Set from typing import List, Optional, Sequence, Set
from rich.console import Console from rich.console import Console
from rich.live import Live from rich.live import Live
from hytop.core.history import SlidingHistory from hytop.core.history import SlidingHistory
from hytop.core.ssh import collect_from_host from hytop.core.ssh import collect_from_host
from hytop.gpu.metrics import hy_smi_args_for_show_flags
from hytop.gpu.models import HostSnapshot, MonitorState, NodeResult from hytop.gpu.models import HostSnapshot, MonitorState, NodeResult
from hytop.gpu.parser import parse_hy_smi_output from hytop.gpu.parser import parse_hy_smi_output
from hytop.gpu.render import build_renderable from hytop.gpu.render import build_renderable
def collect_node(host: str, ssh_timeout: float, cmd_timeout: float) -> NodeResult: def collect_node(
host: str,
ssh_timeout: float,
cmd_timeout: float,
hy_smi_args: Sequence[str],
) -> NodeResult:
"""Collect one host snapshot and parse it into structured samples. """Collect one host snapshot and parse it into structured samples.
Args: Args:
...@@ -27,7 +33,9 @@ def collect_node(host: str, ssh_timeout: float, cmd_timeout: float) -> NodeResul ...@@ -27,7 +33,9 @@ def collect_node(host: str, ssh_timeout: float, cmd_timeout: float) -> NodeResul
Normalized collection result for the host. Normalized collection result for the host.
""" """
raw = collect_from_host(host=host, ssh_timeout=ssh_timeout, cmd_timeout=cmd_timeout) raw = collect_from_host(
host=host, ssh_timeout=ssh_timeout, cmd_timeout=cmd_timeout, hy_smi_args=hy_smi_args
)
if raw.error: if raw.error:
return NodeResult(host=host, samples={}, error=raw.error) return NodeResult(host=host, samples={}, error=raw.error)
sample_ts = time.monotonic() sample_ts = time.monotonic()
...@@ -41,6 +49,7 @@ def host_collector_loop( ...@@ -41,6 +49,7 @@ def host_collector_loop(
host: str, host: str,
ssh_timeout: float, ssh_timeout: float,
cmd_timeout: float, cmd_timeout: float,
hy_smi_args: Sequence[str],
interval: float, interval: float,
state: dict[str, HostSnapshot], state: dict[str, HostSnapshot],
state_lock: threading.Lock, state_lock: threading.Lock,
...@@ -60,7 +69,7 @@ def host_collector_loop( ...@@ -60,7 +69,7 @@ def host_collector_loop(
while not stop_event.is_set(): while not stop_event.is_set():
started = time.monotonic() started = time.monotonic()
result = collect_node(host, ssh_timeout, cmd_timeout) result = collect_node(host, ssh_timeout, cmd_timeout, hy_smi_args)
with state_lock: with state_lock:
snapshot = state[host] snapshot = state[host]
snapshot.seq += 1 snapshot.seq += 1
...@@ -103,6 +112,8 @@ def availability_ready( ...@@ -103,6 +112,8 @@ def availability_ready(
latest = history.latest() latest = history.latest()
if latest is None or (now - latest.ts) > window: if latest is None or (now - latest.ts) > window:
return False return False
if latest.vram_pct is None or latest.hcu_pct is None:
return False
if history.avg("vram_pct", window, now) != 0.0: if history.avg("vram_pct", window, now) != 0.0:
return False return False
if history.avg("hcu_pct", window, now) != 0.0: if history.avg("hcu_pct", window, now) != 0.0:
...@@ -147,6 +158,7 @@ def start_collectors( ...@@ -147,6 +158,7 @@ def start_collectors(
hosts: List[str], hosts: List[str],
ssh_timeout: float, ssh_timeout: float,
cmd_timeout: float, cmd_timeout: float,
hy_smi_args: Sequence[str],
interval: float, interval: float,
state: MonitorState, state: MonitorState,
) -> List[threading.Thread]: ) -> List[threading.Thread]:
...@@ -171,6 +183,7 @@ def start_collectors( ...@@ -171,6 +183,7 @@ def start_collectors(
host, host,
ssh_timeout, ssh_timeout,
cmd_timeout, cmd_timeout,
hy_smi_args,
interval, interval,
state.host_state, state.host_state,
state.state_lock, state.state_lock,
...@@ -244,6 +257,7 @@ def apply_node_results( ...@@ -244,6 +257,7 @@ def apply_node_results(
def run_monitor( def run_monitor(
hosts: List[str], hosts: List[str],
device_filter: Optional[Set[int]], device_filter: Optional[Set[int]],
show_flags: Sequence[str],
window: float, window: float,
interval: float, interval: float,
wait_idle: bool, wait_idle: bool,
...@@ -280,6 +294,7 @@ def run_monitor( ...@@ -280,6 +294,7 @@ def run_monitor(
state = init_monitor_state( state = init_monitor_state(
hosts=hosts, device_filter=device_filter, max_window=window hosts=hosts, device_filter=device_filter, max_window=window
) )
hy_smi_args = hy_smi_args_for_show_flags(show_flags, wait_idle=wait_idle)
ssh_timeout = min(max(5 * interval, 2.0), 5.0) ssh_timeout = min(max(5 * interval, 2.0), 5.0)
cmd_timeout = min(max(10 * interval, 5.0), 10.0) cmd_timeout = min(max(10 * interval, 5.0), 10.0)
render_interval = min(interval, 0.5) render_interval = min(interval, 0.5)
...@@ -291,6 +306,7 @@ def run_monitor( ...@@ -291,6 +306,7 @@ def run_monitor(
hosts=hosts, hosts=hosts,
ssh_timeout=ssh_timeout, ssh_timeout=ssh_timeout,
cmd_timeout=cmd_timeout, cmd_timeout=cmd_timeout,
hy_smi_args=hy_smi_args,
interval=interval, interval=interval,
state=state, state=state,
) )
...@@ -311,6 +327,7 @@ def run_monitor( ...@@ -311,6 +327,7 @@ def run_monitor(
histories=state.histories, histories=state.histories,
monitored_keys=state.monitored_keys, monitored_keys=state.monitored_keys,
errors=state.errors, errors=state.errors,
show_flags=show_flags,
poll_interval=interval, poll_interval=interval,
elapsed_since_start=time.monotonic() - started, elapsed_since_start=time.monotonic() - started,
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment