vllm.py 3.62 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""vLLM-specific utilities for GPU Memory Service tests."""

6
import json
7
8
9
10
11
12
13
14
15
16
import logging
import os
import shutil

import requests

from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api

17
18
from .runtime import DYNAMO_BIN

19
20
21
22
23
24
25
26
27
28
29
30
31
32
logger = logging.getLogger(__name__)


class VLLMWithGMSProcess(ManagedProcess):
    """vLLM engine with GPU Memory Service integration.

    Launches ``python -m dynamo.vllm`` with ``--load-format gms`` as a
    managed subprocess and exposes helpers that call the engine's
    ``/engine/sleep`` and ``/engine/wake_up`` endpoints on the system port.
    """

    def __init__(
        self,
        request,
        engine_id: str,
        system_port: int,
        kv_event_port: int,
        nixl_port: int,
        frontend_port: int,
        *,
        read_only_weights: bool = False,
    ) -> None:
        """Build the engine command line and environment, then start tracking it.

        Args:
            request: Object exposing ``request.node.name``; the name is used
                as the prefix of the log directory (presumably the pytest
                ``request`` fixture -- confirm against callers).
            engine_id: Identifier used as the display name and as the suffix
                of the log directory.
            system_port: Port exported as ``DYN_SYSTEM_PORT``; also used for
                the readiness probe and the sleep/wake requests.
            kv_event_port: Port embedded in the ZMQ endpoint of the
                ``--kv-events-config`` JSON.
            nixl_port: Port exported as ``VLLM_NIXL_SIDE_CHANNEL_PORT``.
            frontend_port: Port probed for ``/v1/models`` and ``/health``.
            read_only_weights: When true, pass
                ``--model-loader-extra-config '{"gms_read_only": true}'``.
        """
        self.engine_id = engine_id
        self.system_port = system_port

        # Start from a clean per-test, per-engine log directory.
        log_dir = f"{request.node.name}_{engine_id}"
        shutil.rmtree(log_dir, ignore_errors=True)

        kv_events_cfg = json.dumps(
            {
                "publisher": "zmq",
                "topic": "kv-events",
                "endpoint": f"tcp://*:{kv_event_port}",
                "enable_kv_cache_events": True,
            }
        )

        command = [
            "python",
            "-m",
            "dynamo.vllm",
            "--model",
            FAULT_TOLERANCE_MODEL_NAME,
            "--load-format",
            "gms",
            "--enforce-eager",
            "--enable-sleep-mode",
            "--gpu-memory-utilization",
            "0.9",
            "--kv-events-config",
            kv_events_cfg,
        ]
        if read_only_weights:
            command.extend(
                [
                    "--model-loader-extra-config",
                    json.dumps({"gms_read_only": True}),
                ]
            )

        super().__init__(
            command=command,
            env={
                **os.environ,
                # Prepend DYNAMO_BIN so its executables win over anything
                # already on the inherited PATH.
                "PATH": f"{DYNAMO_BIN}:{os.environ.get('PATH', '')}",
                "DYN_LOG": "debug",
                "DYN_SYSTEM_PORT": str(system_port),
                "VLLM_NIXL_SIDE_CHANNEL_PORT": str(nixl_port),
            },
            health_check_urls=[
                (f"http://localhost:{system_port}/health", self._is_ready),
                (f"http://localhost:{frontend_port}/v1/models", check_models_api),
                (f"http://localhost:{frontend_port}/health", check_health_generate),
            ],
            timeout=300,
            display_output=True,
            terminate_all_matching_process_names=False,
            stragglers=[],
            log_dir=log_dir,
            display_name=engine_id,
        )

    def _is_ready(self, response) -> bool:
        """Return True when the health response body reports ``status == "ready"``.

        A body that is not valid JSON, or whose JSON is not an object, is
        treated as "not ready" instead of raising out of the health check.
        """
        try:
            body = response.json()
        except ValueError:
            return False
        # A JSON list/string/null has no ``.get``; report it as not ready.
        if not isinstance(body, dict):
            return False
        return body.get("status") == "ready"

    def sleep(self, timeout: int = 30) -> dict:
        """Put the engine to sleep, offloading weights and KV cache.

        Args:
            timeout: Seconds to wait for the HTTP request (default 30,
                matching the previously hard-coded value).

        Returns:
            The decoded JSON body of the ``/engine/sleep`` response.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        r = requests.post(
            f"http://localhost:{self.system_port}/engine/sleep",
            json={"level": 2},
            timeout=timeout,
        )
        r.raise_for_status()
        # Decode once and reuse for both the log line and the return value.
        result = r.json()
        logger.info("%s sleep: %s", self.engine_id, result)
        return result

    def wake(self, timeout: int = 30) -> dict:
        """Wake the engine, restoring weights and KV cache.

        Args:
            timeout: Seconds to wait for the HTTP request.

        Returns:
            The decoded JSON body of the ``/engine/wake_up`` response.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        r = requests.post(
            f"http://localhost:{self.system_port}/engine/wake_up",
            json={"tags": ["weights", "kv_cache"]},
            timeout=timeout,
        )
        r.raise_for_status()
        # Decode once and reuse for both the log line and the return value.
        result = r.json()
        logger.info("%s wake: %s", self.engine_id, result)
        return result