"vscode:/vscode.git/clone" did not exist on "de5c39df32449fd5d84fff613e260f5a73f06d19"
sgl_http_server.py 4.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import logging

import uvicorn
import uvloop
from fastapi import FastAPI

from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging

# Endpoint name looked up on workers; used as the default value of the
# ``--endpoint`` command-line option.
FLUSH_CACHE_ENDPOINT = "flush_cache"

# Executed at import time, before any logging call made by this module.
# NOTE(review): presumably installs Dynamo's logging configuration for the
# process — confirm against dynamo.runtime.logging.
configure_dynamo_logging()


class SglangHttpServer:
    """HTTP server exposing ``POST /flush_cache`` for matching endpoint instances.

    Each request looks up endpoint registrations under the ``instances/`` etcd
    prefix, filters them by the CLI options, and calls every matching instance
    directly. The JSON response always carries ``message`` and ``success``.
    """

    def __init__(self, port: int, runtime: DistributedRuntime, args):
        """Store the configuration and register the HTTP routes.

        Args:
            port: TCP port the uvicorn server listens on.
            runtime: Runtime used for etcd access and for building endpoint clients.
            args: Parsed CLI namespace; ``ns``, ``comp`` and ``endpoint`` are read.
        """
        self.port = port
        self.app = FastAPI()
        self.runtime = runtime
        self.args = args
        self.setup_routes()

    async def _discover_endpoints(self):
        """Return the set of ``(namespace, component)`` pairs exposing the endpoint.

        Keys are read from etcd under ``instances/``. A falsy ``args.ns`` or
        ``args.comp`` disables the corresponding filter.

        Raises:
            RuntimeError: If the runtime has no etcd client.
        """
        etcd_client = self.runtime.etcd_client()
        if etcd_client is None:
            raise RuntimeError("Runtime has no etcd client; cannot discover endpoints")

        prefix = "instances/"
        kvs = await etcd_client.kv_get_prefix(prefix)

        # Collect (namespace, component) combos that expose the requested endpoint
        discovered = set()
        for kv in kvs:
            # Entries may be dicts or objects, and keys may be bytes or str.
            key = kv["key"] if isinstance(kv, dict) else kv.key
            if isinstance(key, bytes):
                key = key.decode()
            if not key.startswith(prefix):
                continue

            segments = key.split("/")
            # Format: instances/<ns>/<comp>/<endpoint:lease>
            if len(segments) < 4:
                continue
            ns, comp, ep_with_lease = segments[1], segments[2], segments[3]

            if self.args.ns and ns != self.args.ns:
                continue
            if self.args.comp and comp != self.args.comp:
                continue

            # Drop the ":<lease>" suffix before comparing endpoint names.
            ep_name = ep_with_lease.split(":", 1)[0]
            if ep_name == self.args.endpoint:
                discovered.add((ns, comp))
                logging.debug("Discovered endpoint: %s.%s", ns, comp)

        logging.debug(
            "Endpoint discovery complete. Found %d matching endpoints",
            len(discovered),
        )
        return discovered

    async def _flush_instance(self, client, ns, comp, inst_id):
        """Call one instance and drain its response stream.

        Failures are logged rather than raised so that one broken instance does
        not stop the remaining ones from being flushed.

        Returns:
            True if the call and the whole stream completed, False otherwise.
        """
        try:
            stream = await client.direct("{}", inst_id)
            async for payload in stream:
                logging.debug("[%s.%s][%s] -> %s", ns, comp, inst_id, payload)
        except Exception as e:
            logging.error("[%s.%s][%s] flush error: %s", ns, comp, inst_id, e)
            return False
        return True

    async def _flush_all(self):
        """Discover matching endpoints, flush every instance, and build the response.

        Returns:
            A dict with ``message`` and ``success``. ``success`` is False when
            nothing matched or when at least one instance call failed, so the
            caller is not told the flush worked when it did not.
        """
        discovered = await self._discover_endpoints()
        if not discovered:
            return {"message": "No matching endpoints found", "success": False}

        logging.debug(
            "Found components: %s",
            ", ".join(f"{ns}.{comp}" for ns, comp in discovered),
        )

        total = 0
        failed = 0
        for ns, comp in discovered:
            ep = (
                self.runtime.namespace(ns)
                .component(comp)
                .endpoint(self.args.endpoint)
            )
            client = await ep.client()
            await client.wait_for_instances()
            ids = client.instance_ids()

            logging.debug("-- %s.%s : %d instances --", ns, comp, len(ids))

            for inst_id in ids:
                total += 1
                if not await self._flush_instance(client, ns, comp, inst_id):
                    failed += 1

        if failed:
            return {
                "message": f"Cache flush failed on {failed} of {total} instances",
                "success": False,
            }
        return {"message": "Cache flush initiated", "success": True}

    def setup_routes(self):
        """Register the ``POST /flush_cache`` route on the FastAPI app."""

        @self.app.post("/flush_cache")
        async def flush_cache():
            """Flush the radix cache."""
            try:
                return await self._flush_all()
            except Exception as e:
                # Top-level boundary: report the failure in the response body
                # instead of letting the exception escape the handler.
                logging.error("Cache flush error: %s", e)
                return {"message": f"Cache flush failed: {str(e)}", "success": False}

    async def start_server(self):
        """Start the HTTP server and run until uvicorn's ``serve()`` returns."""
        # "0.0.0.0" binds on all IPv4 interfaces.
        config = uvicorn.Config(
            self.app,
            host="0.0.0.0",
            port=self.port,
        )
        server = uvicorn.Server(config)

        # Single nice log with available endpoints
        logging.info(
            f"🚀 SGL engine HTTP server running on http://0.0.0.0:{self.port} - Endpoints: POST /flush_cache"
        )

        await server.serve()


def parse_args(argv=None):
    """Parse the command-line options of the cache-management HTTP server.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what a call without
            arguments has always done.

    Returns:
        argparse.Namespace with ``port``, ``ns``, ``comp`` and ``endpoint``.
    """
    p = argparse.ArgumentParser(description="SGLang HTTP server for cache management")
    p.add_argument("--port", type=int, default=9001, help="Port to listen on")
    p.add_argument(
        "--ns",
        "--namespace",
        default="dynamo",
        # The default is a concrete namespace, so the help advertises it rather
        # than claiming that all namespaces are discovered by default.
        help=(
            "Specify Dynamo namespace "
            "(default: %(default)s; pass an empty string to discover all)"
        ),
    )
    p.add_argument(
        "--comp",
        "--component",
        default=None,
        help="Specify component name (default: discover all)",
    )
    p.add_argument(
        "--endpoint", default=FLUSH_CACHE_ENDPOINT, help="Specify endpoint name"
    )
    return p.parse_args(argv)


@dynamo_worker(static=False)
async def main(runtime: DistributedRuntime):
    """Parse the CLI options, build the HTTP server and serve until it exits."""
    cli_args = parse_args()
    await SglangHttpServer(cli_args.port, runtime, cli_args).start_server()


if __name__ == "__main__":
    # Make asyncio use uvloop's event loop before the loop is created below.
    uvloop.install()
    # main is wrapped by @dynamo_worker and is called without arguments here.
    # NOTE(review): presumably the decorator supplies the DistributedRuntime — confirm.
    asyncio.run(main())