service.py 10 KB
Newer Older
one's avatar
one committed
1
2
3
4
5
from __future__ import annotations

import sys
import threading
import time
one's avatar
one committed
6
from collections.abc import Sequence
one's avatar
one committed
7
8
9

from hytop.core.history import SlidingHistory
from hytop.core.ssh import collect_from_host
one's avatar
one committed
10
from hytop.gpu.metrics import hy_smi_args_for_show_flags
one's avatar
one committed
11
12
13
14
from hytop.gpu.models import HostSnapshot, MonitorState, NodeResult
from hytop.gpu.parser import parse_hy_smi_output


one's avatar
one committed
15
16
17
18
19
20
def collect_node(
    host: str,
    ssh_timeout: float,
    cmd_timeout: float,
    hy_smi_args: Sequence[str],
) -> NodeResult:
    """Collect one host snapshot and parse it into structured samples.

    Args:
        host: Hostname or localhost alias.
        ssh_timeout: SSH connect timeout in seconds.
        cmd_timeout: Command timeout in seconds.
        hy_smi_args: Arguments forwarded to the hy-smi invocation on the host.

    Returns:
        Normalized collection result for the host.
    """

    raw = collect_from_host(
        host=host, ssh_timeout=ssh_timeout, cmd_timeout=cmd_timeout, hy_smi_args=hy_smi_args
    )
    if raw.error:
        return NodeResult(host=host, samples={}, error=raw.error)
    # Timestamp with a monotonic clock so downstream freshness checks are
    # immune to wall-clock adjustments.
    sample_ts = time.monotonic()
    samples, parse_error = parse_hy_smi_output(raw.stdout, sample_ts=sample_ts)
    if not samples:
        reason = parse_error or "unknown parse error"
        return NodeResult(host=host, samples={}, error=f"no gpu rows parsed ({reason})")
    return NodeResult(host=host, samples=samples)


def host_collector_loop(
    host: str,
    ssh_timeout: float,
    cmd_timeout: float,
    hy_smi_args: Sequence[str],
    interval: float,
    state: dict[str, HostSnapshot],
    state_lock: threading.Lock,
    stop_event: threading.Event,
) -> None:
    """Continuously collect one host and publish latest snapshot state.

    Args:
        host: Hostname to collect.
        ssh_timeout: SSH connect timeout in seconds.
        cmd_timeout: Command timeout in seconds.
        hy_smi_args: Arguments forwarded to the hy-smi invocation on the host.
        interval: Desired collection interval in seconds.
        state: Shared per-host snapshot map.
        state_lock: Lock guarding shared state writes.
        stop_event: Stop signal for graceful shutdown.
    """

    while not stop_event.is_set():
        started = time.monotonic()
        result = collect_node(host, ssh_timeout, cmd_timeout, hy_smi_args)
        with state_lock:
            snapshot = state[host]
            # Bumping seq signals the render loop that a new result exists.
            snapshot.seq += 1
            snapshot.updated_ts = time.monotonic()
            snapshot.result = result
        # Sleep only the remainder of the interval; waiting on the stop event
        # lets shutdown interrupt the sleep immediately.
        sleep_s = max(0.0, interval - (time.monotonic() - started))
        if stop_event.wait(sleep_s):
            break


def availability_ready(
    window: float,
    histories: dict[tuple[str, int], SlidingHistory],
    monitored_keys: set[tuple[str, int]],
    hosts: list[str],
    errors: dict[str, str],
) -> bool:
    """Check whether every monitored GPU satisfies idle availability criteria.

    Args:
        window: Rolling window length in seconds.
        histories: Sliding histories by host+gpu key.
        monitored_keys: Effective host+gpu keys to evaluate.
        hosts: Host list used for host-level error checks.
        errors: Latest host-level errors.

    Returns:
        True when all monitored GPUs are fresh and window averages are idle.
    """

    if not monitored_keys:
        return False
    now = time.monotonic()
    # Any host-level collection error disqualifies the whole fleet.
    if any(errors.get(h) for h in hosts):
        return False

    def _gpu_idle(key: tuple[str, int]) -> bool:
        history = histories.get(key)
        if history is None:
            return False
        newest = history.latest()
        # Require a fresh sample with both utilization fields present.
        if newest is None or (now - newest.ts) > window:
            return False
        if newest.vram_pct is None or newest.gpu_pct is None:
            return False
        # Both rolling averages must be exactly zero over the window.
        return (
            history.avg("vram_pct", window, now) == 0.0
            and history.avg("gpu_pct", window, now) == 0.0
        )

    return all(_gpu_idle(key) for key in monitored_keys)


def init_monitor_state(
    hosts: list[str],
    device_filter: set[int] | None,
    max_window: float,
) -> MonitorState:
    """Create initial monitor state for the run.

    Args:
        hosts: Host list.
        device_filter: Optional set of GPU ids to monitor.
        max_window: Sliding window length in seconds.

    Returns:
        Initialized monitor state object.
    """

    # With an explicit device filter the monitored set is the cross product of
    # hosts and requested GPU ids; otherwise keys are discovered at runtime.
    if device_filter:
        monitored_keys = {(host, dev) for host in hosts for dev in device_filter}
    else:
        monitored_keys = set()
    return MonitorState(
        max_window=max_window,
        histories={},
        discovered_keys=set(),
        last_applied_sample_ts={},
        monitored_keys=monitored_keys,
        errors={},
        host_state={host: HostSnapshot() for host in hosts},
        processed_seq={host: 0 for host in hosts},
        state_lock=threading.Lock(),
        stop_event=threading.Event(),
    )


def start_collectors(
    hosts: list[str],
    ssh_timeout: float,
    cmd_timeout: float,
    hy_smi_args: Sequence[str],
    interval: float,
    state: MonitorState,
) -> list[threading.Thread]:
    """Start one daemon collector thread per host.

    Args:
        hosts: Host list.
        ssh_timeout: SSH connect timeout in seconds.
        cmd_timeout: Command timeout in seconds.
        hy_smi_args: Arguments forwarded to each hy-smi invocation.
        interval: Desired collection interval in seconds.
        state: Shared monitor state.

    Returns:
        Started collector thread list.
    """

    workers: list[threading.Thread] = []
    for host in hosts:
        # Daemon threads never block interpreter exit; graceful shutdown is
        # coordinated through state.stop_event instead.
        worker = threading.Thread(
            target=host_collector_loop,
            args=(
                host,
                ssh_timeout,
                cmd_timeout,
                hy_smi_args,
                interval,
                state.host_state,
                state.state_lock,
                state.stop_event,
            ),
            daemon=True,
            name=f"collector-{host}",
        )
        worker.start()
        workers.append(worker)
    return workers


one's avatar
one committed
195
def drain_pending_nodes(hosts: list[str], state: MonitorState) -> list[NodeResult]:
    """Fetch unseen host snapshots since the previous render tick.

    Args:
        hosts: Host list used to preserve deterministic ordering.
        state: Shared monitor state.

    Returns:
        Newly published node results to apply this tick.
    """

    fresh: list[NodeResult] = []
    with state.state_lock:
        for host in hosts:
            snap = state.host_state[host]
            if snap.seq > state.processed_seq[host]:
                # Mark the snapshot consumed even when its result is missing,
                # so the same sequence number is never re-examined.
                state.processed_seq[host] = snap.seq
                if snap.result is not None:
                    fresh.append(snap.result)
    return fresh


def apply_node_results(
    nodes: list[NodeResult],
    device_filter: set[int] | None,
    state: MonitorState,
) -> None:
    """Apply collected node results into histories and error state.

    Args:
        nodes: Newly collected node results.
        device_filter: Optional GPU id filter.
        state: Shared monitor state.
    """

    for node in nodes:
        if node.error:
            state.errors[node.host] = node.error
            continue
        # A clean result clears any stale error for this host.
        state.errors.pop(node.host, None)
        for gpu_id, sample in node.samples.items():
            key = (node.host, gpu_id)
            state.discovered_keys.add(key)
            if device_filter is not None and gpu_id not in device_filter:
                continue
            if key not in state.histories:
                state.histories[key] = SlidingHistory(max_window_s=state.max_window)
            history = state.histories[key]
            prev_ts = state.last_applied_sample_ts.get(key)
            # Skip duplicates: only samples strictly newer than the last
            # applied timestamp enter the history.
            if prev_ts is None or sample.ts > prev_ts:
                history.add(sample)
                state.last_applied_sample_ts[key] = sample.ts


def run_monitor(
    hosts: list[str],
    device_filter: set[int] | None,
    show_flags: Sequence[str],
    window: float,
    interval: float,
    wait_idle: bool,
    timeout: float | None,
    wait_idle_duration: float = 10.0,
) -> int:
    """Run the GPU monitor as a Textual TUI application.

    Starts per-host collector threads, then launches the Textual App which
    polls the shared MonitorState and refreshes the DataTable on a timer.
    The app returns an integer exit code via ``App.exit(code)``.

    Args:
        hosts: Host list to monitor.
        device_filter: Optional GPU id filter.
        show_flags: Ordered list of show flags controlling displayed columns.
        window: Rolling window length in seconds.
        interval: Sampling interval in seconds.
        wait_idle: Whether to exit when all monitored GPUs become idle.
        timeout: Optional timeout for wait-idle mode.
        wait_idle_duration: How long GPUs must stay idle before exiting.

    Returns:
        Process-style exit code:
            0 for success,
            2 for invalid arguments,
            124 for timeout in wait-idle mode,
            130 when interrupted by user.
    """

    # Import here to avoid a circular import at module load time.
    from hytop.gpu.app import GpuMonitorApp

    if interval <= 0:
        print("argument error: --interval must be > 0", file=sys.stderr)
        return 2
    if interval > window:
        print("argument error: --interval must be <= --window value", file=sys.stderr)
        return 2

    state = init_monitor_state(hosts=hosts, device_filter=device_filter, max_window=window)
    hy_smi_args = hy_smi_args_for_show_flags(show_flags, wait_idle=wait_idle)
    # Derive SSH/command timeouts from the sampling interval, clamped so slow
    # intervals never yield unbounded timeouts and fast ones stay workable.
    ssh_timeout = min(max(5 * interval, 2.0), 5.0)
    cmd_timeout = min(max(10 * interval, 5.0), 10.0)
    started = time.monotonic()

    workers = start_collectors(
        hosts=hosts,
        ssh_timeout=ssh_timeout,
        cmd_timeout=cmd_timeout,
        hy_smi_args=hy_smi_args,
        interval=interval,
        state=state,
    )

    app = GpuMonitorApp(
        hosts=hosts,
        show_flags=show_flags,
        window=window,
        interval=interval,
        wait_idle=wait_idle,
        wait_idle_duration=wait_idle_duration,
        timeout=timeout,
        state=state,
        device_filter=device_filter,
        started=started,
    )

    try:
        app.run()
    except KeyboardInterrupt:
        # Ctrl-C is treated as a normal shutdown path; the finally block stops
        # the collector threads either way.
        pass
    finally:
        state.stop_event.set()
        for worker in workers:
            # Bounded join keeps shutdown snappy even if a collector is
            # mid-SSH; daemon threads cannot block process exit.
            worker.join(timeout=min(0.2, interval))

    return app.return_code or 0