# sglang.py
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""SGLang-specific utilities for GPU Memory Service tests."""

import logging
import os
import shutil

import requests

from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api

16
17
from .runtime import REPO_ROOT

18
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Directory under REPO_ROOT that is added to PATH in the environment of the
# engine subprocess launched by SGLangWithGMSProcess.
SGLANG_BIN = REPO_ROOT / "dynamo-sglang" / "bin"
20
21
22
23
24
25
26
27
28
29
30
31


class SGLangWithGMSProcess(ManagedProcess):
    """SGLang engine with GPU Memory Service integration.

    Launches ``python -m dynamo.sglang`` with ``--load-format gms`` as a
    managed subprocess, and exposes helpers that call the engine's
    ``release_memory_occupation`` / ``resume_memory_occupation`` endpoints
    on the system port.
    """

    def __init__(
        self,
        request,
        engine_id: str,
        system_port: int,
        sglang_port: int,
        frontend_port: int,
        *,
        read_only_weights: bool = False,
    ):
        """Build the launch command and environment, then start via the base class.

        Args:
            request: Object exposing ``request.node.name``, used to name the
                log directory (presumably the pytest ``request`` fixture --
                confirm against callers).
            engine_id: Identifier used in the log directory name, as the
                process display name, and in log messages.
            system_port: Port exported as ``DYN_SYSTEM_PORT``; also the port
                used for the ``/health`` readiness probe and the
                ``/engine/*`` memory-occupation endpoints.
            sglang_port: Value passed to the engine's ``--port`` flag.
            frontend_port: Port probed for ``/v1/models`` and ``/health``
                during health checking.
            read_only_weights: When true, appends
                ``--model-loader-extra-config '{"gms_read_only": true}'`` to
                the command line.
        """
        self.engine_id = engine_id
        self.system_port = system_port

        # Start from a clean log directory; a missing directory is not an error.
        log_dir = f"{request.node.name}_{engine_id}"
        shutil.rmtree(log_dir, ignore_errors=True)

        command = [
            "python",
            "-m",
            "dynamo.sglang",
            "--model-path",
            FAULT_TOLERANCE_MODEL_NAME,
            "--load-format",
            "gms",
            "--enable-memory-saver",
            "--mem-fraction-static",
            "0.9",
            "--port",
            str(sglang_port),
        ]
        if read_only_weights:
            command.extend(
                [
                    "--model-loader-extra-config",
                    '{"gms_read_only": true}',
                ]
            )

        super().__init__(
            command=command,
            env={
                # Inherit the caller's environment, then override selectively.
                **os.environ,
                "PATH": f"/usr/local/cuda/bin:{SGLANG_BIN}:{os.environ.get('PATH', '')}",
                "CC": "/usr/bin/gcc",
                "CXX": "/usr/bin/g++",
                "DYN_LOG": "debug",
                "DYN_SYSTEM_PORT": str(system_port),
            },
            health_check_urls=[
                (f"http://localhost:{system_port}/health", self._is_ready),
                (f"http://localhost:{frontend_port}/v1/models", check_models_api),
                (f"http://localhost:{frontend_port}/health", check_health_generate),
            ],
            timeout=300,
            display_output=True,
            terminate_all_matching_process_names=False,
            stragglers=[],
            log_dir=log_dir,
            display_name=engine_id,
        )

    def _is_ready(self, response) -> bool:
        """Return True only if *response* carries a JSON object with ``status == "ready"``.

        A body that is not valid JSON, or that is valid JSON but not an
        object (e.g. a list or bare string), counts as "not ready" rather
        than raising.
        """
        try:
            payload = response.json()
        except ValueError:
            return False
        return isinstance(payload, dict) and payload.get("status") == "ready"

    def _post_memory_occupation(self, action: str, timeout: int) -> dict:
        """POST the weights/KV-cache tags to ``/engine/<action>`` and return the JSON reply.

        Args:
            action: Endpoint name under ``/engine/`` on the system port.
            timeout: Request timeout in seconds, forwarded to ``requests.post``.

        Raises:
            requests.HTTPError: If the engine answers with an error status.
        """
        r = requests.post(
            f"http://localhost:{self.system_port}/engine/{action}",
            json={"tags": ["weights", "kv_cache"]},
            timeout=timeout,
        )
        r.raise_for_status()
        # Decode once and reuse the result for both the log line and the return.
        payload = r.json()
        logger.info("%s %s: %s", self.engine_id, action, payload)
        return payload

    def sleep(self, timeout: int = 30) -> dict:
        """Put the engine to sleep, offloading weights and KV cache.

        Args:
            timeout: Request timeout in seconds (defaults to 30).

        Returns:
            The decoded JSON body of the engine's response.

        Raises:
            requests.HTTPError: If the engine answers with an error status.
        """
        return self._post_memory_occupation("release_memory_occupation", timeout)

    def wake(self, timeout: int = 30) -> dict:
        """Wake the engine, restoring weights and KV cache.

        Args:
            timeout: Request timeout in seconds (defaults to 30).

        Returns:
            The decoded JSON body of the engine's response.

        Raises:
            requests.HTTPError: If the engine answers with an error status.
        """
        return self._post_memory_occupation("resume_memory_occupation", timeout)