# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Argument parsing for GPU Memory Service server."""

import argparse
import logging
import math
from dataclasses import dataclass
from typing import Optional, Sequence

from gpu_memory_service.common.utils import get_socket_path

# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)


@dataclass
class Config:
    """Configuration for GPU Memory Service server.

    Instances are produced by ``parse_args`` from the command line.
    """

    # CUDA device ID to manage memory for.
    device: int
    # Logical GMS tag for this server (CLI default: "weights").
    tag: str
    # Path of the server's Unix domain socket.
    socket_path: str
    # Seconds to sleep between allocation retries on CUDA OOM.
    alloc_retry_interval: float
    # Max seconds to keep retrying an allocation; None means wait indefinitely.
    alloc_retry_timeout: Optional[float]
    # Whether verbose logging was requested.
    verbose: bool


def parse_args() -> Config:
    """Parse command line arguments for GPU Memory Service server."""
    parser = argparse.ArgumentParser(
31
        description="GPU Memory Service allocation server."
32
33
34
35
36
37
38
39
    )

    parser.add_argument(
        "--device",
        type=int,
        required=True,
        help="CUDA device ID to manage memory for.",
    )
40
41
42
43
44
45
    parser.add_argument(
        "--tag",
        type=str,
        default="weights",
        help="Logical GMS tag for this server (default: weights).",
    )
46
47
48
49
    parser.add_argument(
        "--socket-path",
        type=str,
        default=None,
50
        help="Path for Unix domain socket. Default uses GPU UUID for stability.",
51
52
53
54
55
56
57
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose logging.",
    )
58
59
60
61
62
63
64
65
66
67
68
69
    parser.add_argument(
        "--alloc-retry-interval",
        type=float,
        default=0.5,
        help="Seconds to sleep between allocation retries on CUDA OOM (default: 0.5).",
    )
    parser.add_argument(
        "--alloc-retry-timeout",
        type=float,
        default=None,
        help="Optional max seconds to wait for allocation retries before failing (default: wait indefinitely).",
    )
70
71
72

    args = parser.parse_args()

73
    # Use UUID-based socket path by default (stable across CUDA_VISIBLE_DEVICES)
74
75
76
77
78
    socket_path = args.socket_path or get_socket_path(args.device, args.tag)
    if args.alloc_retry_interval <= 0:
        parser.error("--alloc-retry-interval must be > 0")
    if args.alloc_retry_timeout is not None and args.alloc_retry_timeout <= 0:
        parser.error("--alloc-retry-timeout must be > 0 when set")
79

80
    return Config(
81
        device=args.device,
82
        tag=args.tag,
83
        socket_path=socket_path,
84
85
        alloc_retry_interval=args.alloc_retry_interval,
        alloc_retry_timeout=args.alloc_retry_timeout,
86
87
        verbose=args.verbose,
    )