router.py 11.5 KB
Newer Older
1
from typing import Dict, List, Optional
2
3
4
5
6
7
8
9
10
11

from sglang_router_rs import PolicyType
from sglang_router_rs import Router as _Router


class Router:
    """
    A high-performance router for distributing requests across worker nodes.

    This is a thin Python wrapper around the Rust ``sglang_router_rs.Router``.
    Every constructor argument is forwarded unchanged to the Rust router. The only
    exception is that ``None`` selectors / CORS origins are first replaced with
    empty containers.

    Args:
        worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
            the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
        policy: Load balancing policy to use. Default: PolicyType.RoundRobin. Options:
            - PolicyType.Random: Randomly select workers
            - PolicyType.RoundRobin: Distribute requests in round-robin fashion
            - PolicyType.CacheAware: Distribute requests based on cache state and load balance
            - PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
        host: Host address to bind the router server. Default: '127.0.0.1'
        port: Port number to bind the router server. Default: 3001
        worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 600
        worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 30
        cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
            if the match rate exceeds threshold, otherwise routes to the worker with the smallest
            tree. Default: 0.3
        balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 64
        balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
            AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.5
        eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
            routing. Default: 120
        max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^26
        max_payload_size: Maximum payload size in bytes. Default: 512MB
        dp_aware: Enable data parallelism aware schedule. Default: False
        api_key: The api key used for the authorization with the worker.
            Useful when the dp aware scheduling strategy is enabled.
            Default: None
        log_dir: Directory to store log files. If None, logs are only output to console. Default: None
        log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
            Default: None
        service_discovery: Enable Kubernetes service discovery. When enabled, the router will
            automatically discover worker pods based on the selector. Default: False
        selector: Dictionary mapping of label keys to values for Kubernetes pod selection.
            Example: {"app": "sglang-worker"}. Default: {} (None is treated as {})
        service_discovery_port: Port to use for service discovery. The router will generate
            worker URLs using this port. Default: 80
        service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
            watches pods across all namespaces (requires cluster-wide permissions). Default: None
        prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
            for prefill servers (PD mode only). Default: {} (None is treated as {})
        decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
            for decode servers (PD mode only). Default: {} (None is treated as {})
        bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
            Default: 'sglang.ai/bootstrap-port'
        prometheus_port: Port to expose Prometheus metrics. Default: None
        prometheus_host: Host address to bind the Prometheus metrics server. Default: None
        request_timeout_secs: Request timeout in seconds. Default: 1800
        request_id_headers: List of HTTP headers to check for request IDs. If not specified,
            uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
            Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
        pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
        prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only).
            Default: None
        decode_urls: List of URLs for decode servers (PD mode only). Default: None
        prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        decode_policy: Specific load balancing policy for decode nodes (PD mode only).
            If not specified, uses the main policy. Default: None
        max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
        queue_size: Queue size for pending requests when max concurrent limit reached
            (0 = no queue, return 429 immediately). Default: 100
        queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
        rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set,
            defaults to max_concurrent_requests. Default: None
        cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins.
            Default: [] (None is treated as [])
        retry_max_retries: Maximum number of request retries. Default: 5
        retry_initial_backoff_ms: Initial retry backoff in milliseconds. Default: 50
        retry_max_backoff_ms: Upper bound on the retry backoff in milliseconds. Default: 30000
        retry_backoff_multiplier: Multiplier applied to the retry backoff. Default: 1.5
        retry_jitter_factor: Jitter factor applied to the retry backoff. Default: 0.2
        cb_failure_threshold: Circuit-breaker failure threshold. Default: 10
        cb_success_threshold: Circuit-breaker success threshold. Default: 3
        cb_timeout_duration_secs: Circuit-breaker timeout duration in seconds. Default: 60
        cb_window_duration_secs: Circuit-breaker window duration in seconds. Default: 120
        disable_retries: Disable request retries. Default: False
        disable_circuit_breaker: Disable the circuit breaker. Default: False
        health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
        health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
        health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
        health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
        health_check_endpoint: Health check endpoint path. Default: '/health'
        model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None
        tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
    """

    def __init__(
        self,
        worker_urls: List[str],
        policy: PolicyType = PolicyType.RoundRobin,
        host: str = "127.0.0.1",
        port: int = 3001,
        worker_startup_timeout_secs: int = 600,
        worker_startup_check_interval: int = 30,
        cache_threshold: float = 0.3,
        balance_abs_threshold: int = 64,
        balance_rel_threshold: float = 1.5,
        eviction_interval_secs: int = 120,
        max_tree_size: int = 2**26,
        max_payload_size: int = 512 * 1024 * 1024,  # 512MB
        dp_aware: bool = False,
        api_key: Optional[str] = None,
        log_dir: Optional[str] = None,
        log_level: Optional[str] = None,
        service_discovery: bool = False,
        selector: Optional[Dict[str, str]] = None,
        service_discovery_port: int = 80,
        service_discovery_namespace: Optional[str] = None,
        prefill_selector: Optional[Dict[str, str]] = None,
        decode_selector: Optional[Dict[str, str]] = None,
        bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
        prometheus_port: Optional[int] = None,
        prometheus_host: Optional[str] = None,
        request_timeout_secs: int = 1800,
        request_id_headers: Optional[List[str]] = None,
        pd_disaggregation: bool = False,
        prefill_urls: Optional[List[tuple]] = None,
        decode_urls: Optional[List[str]] = None,
        prefill_policy: Optional[PolicyType] = None,
        decode_policy: Optional[PolicyType] = None,
        max_concurrent_requests: int = 256,
        queue_size: int = 100,
        queue_timeout_secs: int = 60,
        rate_limit_tokens_per_second: Optional[int] = None,
        cors_allowed_origins: Optional[List[str]] = None,
        retry_max_retries: int = 5,
        retry_initial_backoff_ms: int = 50,
        retry_max_backoff_ms: int = 30_000,
        retry_backoff_multiplier: float = 1.5,
        retry_jitter_factor: float = 0.2,
        cb_failure_threshold: int = 10,
        cb_success_threshold: int = 3,
        cb_timeout_duration_secs: int = 60,
        cb_window_duration_secs: int = 120,
        disable_retries: bool = False,
        disable_circuit_breaker: bool = False,
        health_failure_threshold: int = 3,
        health_success_threshold: int = 2,
        health_check_timeout_secs: int = 5,
        health_check_interval_secs: int = 60,
        health_check_endpoint: str = "/health",
        model_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
    ):
        # Container arguments default to None instead of {} / [] so that each
        # call gets its own fresh container (no shared mutable defaults); the
        # Rust router always receives a concrete dict/list.
        if selector is None:
            selector = {}
        if prefill_selector is None:
            prefill_selector = {}
        if decode_selector is None:
            decode_selector = {}
        if cors_allowed_origins is None:
            cors_allowed_origins = []

        self._router = _Router(
            worker_urls=worker_urls,
            policy=policy,
            host=host,
            port=port,
            worker_startup_timeout_secs=worker_startup_timeout_secs,
            worker_startup_check_interval=worker_startup_check_interval,
            cache_threshold=cache_threshold,
            balance_abs_threshold=balance_abs_threshold,
            balance_rel_threshold=balance_rel_threshold,
            eviction_interval_secs=eviction_interval_secs,
            max_tree_size=max_tree_size,
            max_payload_size=max_payload_size,
            dp_aware=dp_aware,
            api_key=api_key,
            log_dir=log_dir,
            log_level=log_level,
            service_discovery=service_discovery,
            selector=selector,
            service_discovery_port=service_discovery_port,
            service_discovery_namespace=service_discovery_namespace,
            prefill_selector=prefill_selector,
            decode_selector=decode_selector,
            bootstrap_port_annotation=bootstrap_port_annotation,
            prometheus_port=prometheus_port,
            prometheus_host=prometheus_host,
            request_timeout_secs=request_timeout_secs,
            request_id_headers=request_id_headers,
            pd_disaggregation=pd_disaggregation,
            prefill_urls=prefill_urls,
            decode_urls=decode_urls,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
            max_concurrent_requests=max_concurrent_requests,
            queue_size=queue_size,
            queue_timeout_secs=queue_timeout_secs,
            rate_limit_tokens_per_second=rate_limit_tokens_per_second,
            cors_allowed_origins=cors_allowed_origins,
            retry_max_retries=retry_max_retries,
            retry_initial_backoff_ms=retry_initial_backoff_ms,
            retry_max_backoff_ms=retry_max_backoff_ms,
            retry_backoff_multiplier=retry_backoff_multiplier,
            retry_jitter_factor=retry_jitter_factor,
            cb_failure_threshold=cb_failure_threshold,
            cb_success_threshold=cb_success_threshold,
            cb_timeout_duration_secs=cb_timeout_duration_secs,
            cb_window_duration_secs=cb_window_duration_secs,
            disable_retries=disable_retries,
            disable_circuit_breaker=disable_circuit_breaker,
            health_failure_threshold=health_failure_threshold,
            health_success_threshold=health_success_threshold,
            health_check_timeout_secs=health_check_timeout_secs,
            health_check_interval_secs=health_check_interval_secs,
            health_check_endpoint=health_check_endpoint,
            model_path=model_path,
            tokenizer_path=tokenizer_path,
        )

    def start(self) -> None:
        """Start the router server.

        Delegates to the underlying Rust router's ``start()``.
        This method blocks until the server is shut down.
        """
        self._router.start()