mock_nim_backend.py 4.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Mock NIM Backend Server for Metrics Testing

This server mocks a NIM (NVIDIA Inference Microservices) backend that exposes
runtime statistics via the runtime_stats endpoint.

NOTE: This is temporary code for testing purposes only. Once NIM starts using
Dynamo backend components natively, this mock server and the associated NIM
metrics polling code in the frontend will be removed. The NIM-specific metrics
collection exists only as a bridge until NIM adopts the Dynamo runtime.

The server demonstrates:
- Dynamic metric generation (gauges only; this mock does not emit counters)
- Proper async generator pattern for Dynamo endpoints
- JSON-encoded metric responses compatible with the frontend metrics collector
"""
import asyncio
import json
import time
from typing import Any, AsyncGenerator

import uvloop

from dynamo.runtime import DistributedRuntime, dynamo_worker

# Running total of stats requests handled; every synthetic gauge is derived from it
request_count = 0


async def handle_stats_request(request: Any) -> AsyncGenerator[str, None]:
    """Produce a single mock stats payload for one incoming request.

    Each call bumps the module-level ``request_count`` and derives the gauge
    values from it, so successive calls return different numbers that repeat
    on fixed cycles.

    Args:
        request: JsonLike input from the client (can be dict, list, str, int,
            float, bool, or None). It is only logged, never interpreted.

    Yields:
        str: One JSON string holding the stats dict (schema version, worker
        id, backend name, integer timestamp and the gauges).
    """
    global request_count
    request_count = request_count + 1
    seq = request_count

    print(f"Received stats request #{request_count}: {request!r}")

    # All gauges are pure functions of the request sequence number:
    #   kv cache usage  -> 0.3 .. 0.93 in steps of 0.07 (period 10)
    #   gpu utilization -> 50 .. 97.5 in steps of 2.5 (period 20)
    #   active requests -> 0 .. 14 (period 15)
    gauges = {
        "kv_cache_usage_perc": round(0.3 + (seq % 10) * 0.07, 2),
        "gpu_utilization_perc": round(50 + (seq % 20) * 2.5, 2),
        "active_requests": seq % 15,
    }

    payload = {
        "schema_version": 1,
        "worker_id": "mock-worker-1",
        "backend": "vllm",
        "ts": int(time.time()),
        "metrics": {
            "gauges": gauges,
        },
    }

    # Serialized to a JSON string for Rust Annotated<String> compatibility
    yield json.dumps(payload)


async def worker(runtime: DistributedRuntime):
    """Register the mock stats endpoint on ``runtime`` and serve it.

    Parses the process command line, splits the configured endpoint path into
    namespace, component and endpoint names, creates the component's service
    and then serves ``handle_stats_request`` on that endpoint.

    Args:
        runtime: Runtime used to resolve the namespace and component.

    Raises:
        ValueError: If ``--custom-backend-metrics-endpoint`` is not exactly
            three non-empty, dot-separated segments.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Mock NIM Backend Server")
    parser.add_argument(
        "--custom-backend-metrics-endpoint",
        type=str,
        default="nim.backend.runtime_stats",
        help="Custom backend metrics endpoint in format 'namespace.component.endpoint' (default: 'nim.backend.runtime_stats')",
    )
    # Declared here as well so parse_args() accepts the flag; its value is
    # consumed in main() before the worker is created.
    parser.add_argument(
        "--use-etcd",
        action="store_true",
        help="Use etcd for service discovery (dynamic mode). Default is static mode (no etcd).",
    )
    args = parser.parse_args()

    # Parse endpoint (namespace.component.endpoint)
    endpoint_path = args.custom_backend_metrics_endpoint
    parts = endpoint_path.split(".")
    # Counting segments alone is not enough: "nim..runtime_stats" also splits
    # into three parts, one of them empty, so empty names are rejected too.
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"Invalid endpoint format. Expected 'namespace.component.endpoint', got: {endpoint_path}"
        )

    namespace, comp_name, endpoint_name = parts

    component = runtime.namespace(namespace).component(comp_name)
    await component.create_service()

    stats_endpoint = component.endpoint(endpoint_name)
    print(
        f"Mock NIM stats server started on {namespace}/{comp_name}/{endpoint_name} endpoint"
    )
    # List only the gauges that handle_stats_request actually emits.
    print(
        "Exposing incrementing metrics: kv_cache_usage_perc, gpu_utilization_perc, active_requests"
    )

    await stats_endpoint.serve_endpoint(handle_stats_request)  # type: ignore[arg-type]


def main():
    """Script entry point: choose static or etcd mode, then run the worker.

    Only ``--use-etcd`` is inspected here; every other command-line argument
    is ignored at this stage and left for ``worker`` to parse.
    """
    import argparse

    # Lenient pre-parse: the mode must be known before dynamo_worker wraps
    # the worker, so unknown arguments are tolerated rather than rejected.
    pre_parser = argparse.ArgumentParser(
        add_help=False, description="Mock NIM Backend Server"
    )
    pre_parser.add_argument("--use-etcd", action="store_true")
    known_args, _ignored = pre_parser.parse_known_args()

    # Static mode (no etcd) is the default; --use-etcd switches it off.
    decorate = dynamo_worker(static=not known_args.use_etcd)
    entrypoint = decorate(worker)

    uvloop.install()
    asyncio.run(entrypoint())  # type: ignore[arg-type]


# Start the mock server only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()