Unverified Commit c4fd7612 authored by one's avatar one Committed by GitHub
Browse files

[hytop] Support ssh port (#2)

parent 3a7b8869
...@@ -60,6 +60,9 @@ hytop -n 0.5 --window 5 [COMMAND] ...@@ -60,6 +60,9 @@ hytop -n 0.5 --window 5 [COMMAND]
# Specify a list of nodes for the subcommand # Specify a list of nodes for the subcommand
hytop -H node01,node02 [COMMAND] hytop -H node01,node02 [COMMAND]
# Specify a list of nodes with non-standard ssh ports for the subcommand
hytop -H node01:3333,node02:3333 [COMMAND]
``` ```
### SSH transport ### SSH transport
......
...@@ -23,6 +23,24 @@ class CollectResult: ...@@ -23,6 +23,24 @@ class CollectResult:
error: str | None = None error: str | None = None
@dataclass(frozen=True)
class HostTarget:
"""Parsed host connection target details."""
name: str
hostname: str
port: int | None
def parse_host_target(raw_host: str) -> HostTarget:
"""Parse a host string into a hostname and optional port."""
if ":" in raw_host:
host_part, port_part = raw_host.rsplit(":", 1)
if port_part.isdigit():
return HostTarget(name=raw_host, hostname=host_part, port=int(port_part))
return HostTarget(name=raw_host, hostname=raw_host, port=None)
@dataclass(frozen=True) @dataclass(frozen=True)
class SSHOptions: class SSHOptions:
"""Optional SSH transport tuning options.""" """Optional SSH transport tuning options."""
...@@ -89,17 +107,23 @@ def collect_from_host( ...@@ -89,17 +107,23 @@ def collect_from_host(
Returns: Returns:
Raw command output with normalized error information. Raw command output with normalized error information.
""" """
target = parse_host_target(host)
local_names = {"localhost", "127.0.0.1", "::1"} local_names = {"localhost", "127.0.0.1", "::1"}
if host in local_names:
if target.port is None and target.hostname in local_names:
cmd = ["hy-smi", *hy_smi_args] cmd = ["hy-smi", *hy_smi_args]
else: else:
cmd = [ cmd = ["ssh"]
"ssh", if target.port is not None:
*_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options), cmd.extend(["-p", str(target.port)])
host, cmd.extend(
"hy-smi", [
*hy_smi_args, *_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options),
] target.hostname,
"hy-smi",
*hy_smi_args,
]
)
try: try:
proc = subprocess.run( proc = subprocess.run(
...@@ -141,16 +165,22 @@ def collect_python_from_host( ...@@ -141,16 +165,22 @@ def collect_python_from_host(
) -> CollectResult: ) -> CollectResult:
"""Run Python code locally or via SSH and return raw output.""" """Run Python code locally or via SSH and return raw output."""
target = parse_host_target(host)
local_names = {"localhost", "127.0.0.1", "::1"} local_names = {"localhost", "127.0.0.1", "::1"}
if host in local_names:
if target.port is None and target.hostname in local_names:
cmd = ["python3", "-c", python_code] cmd = ["python3", "-c", python_code]
else: else:
cmd = [ cmd = ["ssh"]
"ssh", if target.port is not None:
*_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options), cmd.extend(["-p", str(target.port)])
host, cmd.extend(
_build_remote_python_shell_command(python_code), [
] *_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options),
target.hostname,
_build_remote_python_shell_command(python_code),
]
)
try: try:
proc = subprocess.run( proc = subprocess.run(
...@@ -191,9 +221,15 @@ def build_remote_python_command( ...@@ -191,9 +221,15 @@ def build_remote_python_command(
) -> list[str]: ) -> list[str]:
"""Build command for remote Python execution.""" """Build command for remote Python execution."""
return [ target = parse_host_target(host)
"ssh", cmd = ["ssh"]
*_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options), if target.port is not None:
host, cmd.extend(["-p", str(target.port)])
_build_remote_python_shell_command(python_code), cmd.extend(
] [
*_build_ssh_option_args(ssh_timeout=ssh_timeout, ssh_options=ssh_options),
target.hostname,
_build_remote_python_shell_command(python_code),
]
)
return cmd
...@@ -164,3 +164,49 @@ class TestSubprocessStdinIsolation: ...@@ -164,3 +164,49 @@ class TestSubprocessStdinIsolation:
python_code="print('ok')", python_code="print('ok')",
) )
assert mock_run.call_args.kwargs["stdin"] == subprocess.DEVNULL assert mock_run.call_args.kwargs["stdin"] == subprocess.DEVNULL
class TestHostTargetParsing:
def test_parse_host_target_simple(self):
from hytop.core.ssh import parse_host_target
target = parse_host_target("node01")
assert target.name == "node01"
assert target.hostname == "node01"
assert target.port is None
def test_parse_host_target_with_port(self):
from hytop.core.ssh import parse_host_target
target = parse_host_target("node01:3333")
assert target.name == "node01:3333"
assert target.hostname == "node01"
assert target.port == 3333
def test_parse_host_target_invalid_port(self):
from hytop.core.ssh import parse_host_target
target = parse_host_target("node01:abc")
assert target.name == "node01:abc"
assert target.hostname == "node01:abc"
assert target.port is None
@patch("hytop.core.ssh.subprocess.run")
def test_collect_from_host_custom_port(self, mock_run):
mock_run.return_value = _make_proc(stdout="{}")
collect_from_host("node01:3333", ssh_timeout=5, cmd_timeout=10, hy_smi_args=["--json"])
cmd = mock_run.call_args[0][0]
assert cmd[0] == "ssh"
assert "-p" in cmd
assert "3333" in cmd
@patch("hytop.core.ssh.subprocess.run")
def test_collect_python_from_host_custom_port(self, mock_run):
mock_run.return_value = _make_proc(stdout='{"ok":1}')
collect_python_from_host(
"node01:3333", ssh_timeout=5, cmd_timeout=10, python_code="print()"
)
cmd = mock_run.call_args[0][0]
assert cmd[0] == "ssh"
assert "-p" in cmd
assert "3333" in cmd
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment