Unverified Commit a7d51b39 authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

feat: route requests by device type and load for sglang epd (#7215)


Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>
parent a3e6468d
...@@ -241,7 +241,8 @@ class FrontendArgGroup(ArgGroup): ...@@ -241,7 +241,8 @@ class FrontendArgGroup(ArgGroup):
default="round-robin", default="round-robin",
help="How to route the request. power-of-two picks 2 random workers and " help="How to route the request. power-of-two picks 2 random workers and "
"routes to the one with fewer in-flight requests. least-loaded routes to " "routes to the one with fewer in-flight requests. least-loaded routes to "
"the worker with the fewest active requests. In disaggregated prefill mode, " "the worker with the fewest active requests. device-aware-weighted routes "
"based on worker device type (CPU/CUDA). In disaggregated prefill mode, "
"both power-of-two and least-loaded skip bootstrap optimization and fall " "both power-of-two and least-loaded skip bootstrap optimization and fall "
"back to the synchronous prefill path.", "back to the synchronous prefill path.",
choices=[ choices=[
...@@ -251,6 +252,7 @@ class FrontendArgGroup(ArgGroup): ...@@ -251,6 +252,7 @@ class FrontendArgGroup(ArgGroup):
"kv", "kv",
"direct", "direct",
"least-loaded", "least-loaded",
"device-aware-weighted",
], ],
) )
add_argument( add_argument(
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# - Auto-discovery: Watches etcd for engine/worker registration (via `register_model`). # - Auto-discovery: Watches etcd for engine/worker registration (via `register_model`).
# - Pre-processor: Prompt templating and tokenization. # - Pre-processor: Prompt templating and tokenization.
# - Router, defaulting to round-robin. Use --router-mode to switch # - Router, defaulting to round-robin. Use --router-mode to switch
# (round-robin, random, kv, direct, least-loaded). # (round-robin, random, kv, direct, least-loaded, device-aware-weighted).
# #
# Pass `--interactive` or `-i` for text chat instead of HTTP server. # Pass `--interactive` or `-i` for text chat instead of HTTP server.
# #
...@@ -251,6 +251,9 @@ async def async_main(): ...@@ -251,6 +251,9 @@ async def async_main():
elif config.router_mode == "least-loaded": elif config.router_mode == "least-loaded":
router_mode = RouterMode.LeastLoaded router_mode = RouterMode.LeastLoaded
kv_router_config = None kv_router_config = None
elif config.router_mode == "device-aware-weighted":
router_mode = RouterMode.DeviceAwareWeighted
kv_router_config = None
else: else:
router_mode = RouterMode.RoundRobin router_mode = RouterMode.RoundRobin
kv_router_config = None kv_router_config = None
......
...@@ -85,7 +85,7 @@ spec: ...@@ -85,7 +85,7 @@ spec:
|-----------|---------|-------------| |-----------|---------|-------------|
| `--http-port` | 8000 | HTTP server port | | `--http-port` | 8000 | HTTP server port |
| `--kserve-grpc-server` | false | Enable KServe gRPC server | | `--kserve-grpc-server` | false | Enable KServe gRPC server |
| `--router-mode` | `round-robin` | Routing strategy: `round-robin`, `random`, `kv`, `direct`, `least-loaded` (`power-of-two` and `least-loaded` use synchronous prefill fallback in disaggregated prefill mode) | | `--router-mode` | `round-robin` | Routing strategy: `round-robin`, `random`, `kv`, `direct`, `least-loaded`, `device-aware-weighted` (`power-of-two` and `least-loaded` use synchronous prefill fallback in disaggregated prefill mode) |
See the [Frontend Guide](frontend-guide.md) for full configuration options. See the [Frontend Guide](frontend-guide.md) for full configuration options.
......
...@@ -46,7 +46,7 @@ For all CLI arguments, environment variables, K8s deployment examples, and tunin ...@@ -46,7 +46,7 @@ For all CLI arguments, environment variables, K8s deployment examples, and tunin
**Limitations:** **Limitations:**
- Static endpoints not supported—KV router requires dynamic model discovery via etcd to track worker instances and their KV cache states - Static endpoints not supported—KV router requires dynamic model discovery via etcd to track worker instances and their KV cache states
For basic model registration without KV routing, use `--router-mode round-robin`, `--router-mode random`, or `--router-mode least-loaded` with both static and dynamic endpoints. For basic model registration without KV routing, use `--router-mode round-robin`, `--router-mode random`, `--router-mode least-loaded`, or `--router-mode device-aware-weighted` with both static and dynamic endpoints.
## Next Steps ## Next Steps
......
...@@ -21,6 +21,7 @@ The Dynamo router can be deployed in several configurations. The table below sho ...@@ -21,6 +21,7 @@ The Dynamo router can be deployed in several configurations. The table below sho
| **Frontend + KV (Aggregated)** | `python -m dynamo.frontend --router-mode kv` | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Aggregated | Production single-pool serving with cache reuse | | **Frontend + KV (Aggregated)** | `python -m dynamo.frontend --router-mode kv` | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Aggregated | Production single-pool serving with cache reuse |
| **Frontend + KV (Disaggregated)** | `python -m dynamo.frontend --router-mode kv` with prefill + decode workers | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Disaggregated (prefill + decode pools) | Separate prefill/decode for large-scale serving | | **Frontend + KV (Disaggregated)** | `python -m dynamo.frontend --router-mode kv` with prefill + decode workers | KV cache overlap + load | NATS Core / JetStream / ZMQ / Approx | Disaggregated (prefill + decode pools) | Separate prefill/decode for large-scale serving |
| **Frontend + Least-Loaded** | `python -m dynamo.frontend --router-mode least-loaded` | Fewest active connections | None | Aggregated or disaggregated fallback | Simple load-aware balancing without KV awareness | | **Frontend + Least-Loaded** | `python -m dynamo.frontend --router-mode least-loaded` | Fewest active connections | None | Aggregated or disaggregated fallback | Simple load-aware balancing without KV awareness |
| **Frontend + Device-Aware Weighted** | `python -m dynamo.frontend --router-mode device-aware-weighted` | Device-aware budget + least-loaded within selected device group | None | Aggregated or disaggregated fallback | Heterogeneous fleet balancing (CPU/non-CPU); degenerates to least-loaded when only one device class is present |
| **Frontend + Direct** | `python -m dynamo.frontend --router-mode direct` | Worker ID from request hints | None | Aggregated | External orchestrator (e.g., EPP/GAIE) selects workers | | **Frontend + Direct** | `python -m dynamo.frontend --router-mode direct` | Worker ID from request hints | None | Aggregated | External orchestrator (e.g., EPP/GAIE) selects workers |
| **Standalone Router** | `python -m dynamo.router` | KV cache overlap + load | NATS Core / JetStream / ZMQ | Any | Routing without the HTTP frontend (multi-tier, custom pipelines) | | **Standalone Router** | `python -m dynamo.router` | KV cache overlap + load | NATS Core / JetStream / ZMQ | Any | Routing without the HTTP frontend (multi-tier, custom pipelines) |
...@@ -32,8 +33,71 @@ The Dynamo router can be deployed in several configurations. The table below sho ...@@ -32,8 +33,71 @@ The Dynamo router can be deployed in several configurations. The table below sho
| **Random** | `random` | Selects a random worker for each request | | **Random** | `random` | Selects a random worker for each request |
| **KV** | `kv` | Evaluates KV cache overlap and decode load per worker; picks lowest cost | | **KV** | `kv` | Evaluates KV cache overlap and decode load per worker; picks lowest cost |
| **Least-Loaded** | `least-loaded` | Routes to the worker with fewest active connections; in disaggregated prefill paths it skips bootstrap optimization and falls back to synchronous prefill | | **Least-Loaded** | `least-loaded` | Routes to the worker with fewest active connections; in disaggregated prefill paths it skips bootstrap optimization and falls back to synchronous prefill |
| **Device-Aware Weighted** | `device-aware-weighted` | Partitions workers into CPU and non-CPU groups, applies capability-normalized ratio budgeting using `DYN_ENCODER_CUDA_TO_CPU_RATIO` to decide which group receives the request, then selects the least-loaded worker within that group; when only one device class exists, behavior degenerates to least-loaded |
| **Direct** | `direct` | Reads the target `worker_id` from the request's routing hints; no selection logic | | **Direct** | `direct` | Reads the target `worker_id` from the request's routing hints; no selection logic |
### Device-Aware Weighted Routing
`device-aware-weighted` is designed for **heterogeneous fleets** where workers of different compute capability — for example CPU embedding encoders alongside GPU embedding encoders — share the same endpoint.
Raw in-flight counts are not directly comparable across device types: a GPU can sustain far more concurrent requests than a CPU before reaching the same relative load. Comparing raw counts would permanently starve CPU workers because GPUs would always look less loaded.
#### Budget policy
Workers are split into two groups: *CPU* and *non-CPU* (CUDA/GPU/etc.). A **capability-normalized load** is compared between the two groups:
```
normalized_load = total_inflight(group) / (instance_count(group) × throughput_weight)
```
where `throughput_weight` is **1** for CPU workers and **`DYN_ENCODER_CUDA_TO_CPU_RATIO`** for non-CPU workers. The request is routed to the group with the **lower normalized load**, i.e. the group that has more headroom relative to its compute capacity.
This comparison rearranges to an equivalent integer budget check (avoiding floating-point division):
```
CPU group is selected when:
total_cpu_inflight < total_non_cpu_inflight × cpu_count / (ratio × non_cpu_count)
```
The right-hand side is the *allowed CPU in-flight budget*: the number of concurrent CPU requests that would produce the same relative load as the current non-CPU workload. When actual CPU in-flight is below that budget, CPUs are underloaded relative to GPUs and get the next request; once the budget is exhausted the router switches back to the non-CPU group.
#### Example
Ratio = 8 (default; each GPU handles 8× the requests of a CPU at equal normalized load).
Current per-instance load:
- GPU instance g1: 8 in-flight
- GPU instance g2: 8 in-flight
- CPU instance c1: 0 in-flight
- CPU instance c2: 0 in-flight
So `non_cpu_count = 2`, `cpu_count = 2`, `total_non_cpu_inflight = 16`, and `total_cpu_inflight = 0`:
```
allowed_cpu_inflight = 16 × 2 / (8 × 2) = 2
total_cpu_inflight = 0 < 2 → route to CPU group
```
Interpretation: with this non-CPU load, the CPU group budget is 2 total in-flight requests, i.e. roughly 1 request per CPU instance before normalized load matches the GPU group.
After c1=1 and c2=1 (total CPU in-flight = 2):
```
total_cpu_inflight = 2 < 2 is false → route to non-CPU group
```
This keeps both groups at roughly equal *normalized* load.
#### Configuration
| Variable | Default | Description |
|----------|---------|-------------|
| `DYN_ENCODER_CUDA_TO_CPU_RATIO` | `8` | Throughput ratio of a non-CPU (GPU) worker relative to one CPU worker. Set to the approximate ratio of requests-per-second each device type can sustain under your workload. |
> [!TIP]
> When only one device class is present (all GPU or all CPU) the policy returns all instances and behavior degenerates to standard least-loaded selection within that single group.
### KV Event Transport Modes (within `--router-mode kv`) ### KV Event Transport Modes (within `--router-mode kv`)
When using KV routing, the router needs to know what each worker has cached. There are four ways to get this information: When using KV routing, the router needs to know what each worker has cached. There are four ways to get this information:
...@@ -268,6 +332,7 @@ We can then use the default routing methods exposed by the client class to send ...@@ -268,6 +332,7 @@ We can then use the default routing methods exposed by the client class to send
- **Direct routing**: Explicitly targets a specific worker via `client.direct(input, component_id)` - **Direct routing**: Explicitly targets a specific worker via `client.direct(input, component_id)`
- **Least-loaded routing**: Routes to the worker with fewest active connections via `--router-mode least-loaded` - **Least-loaded routing**: Routes to the worker with fewest active connections via `--router-mode least-loaded`
In disaggregated prefill paths it skips bootstrap optimization and uses the synchronous prefill path, matching power-of-two routing. In disaggregated prefill paths it skips bootstrap optimization and uses the synchronous prefill path, matching power-of-two routing.
- **Device-aware weighted routing**: Routes using CPU/non-CPU ratio budgeting plus least-loaded selection within the selected device group via `--router-mode device-aware-weighted`. Tune ratio with `DYN_ENCODER_CUDA_TO_CPU_RATIO` (default `8`). See [Device-Aware Weighted Routing](#device-aware-weighted-routing) below for a detailed explanation of the budget policy.
KV Cache routing uses direct routing with a special worker selection algorithm. KV Cache routing uses direct routing with a special worker selection algorithm.
......
...@@ -52,6 +52,7 @@ pub enum RouterMode { ...@@ -52,6 +52,7 @@ pub enum RouterMode {
/// Used when an external orchestrator (e.g., EPP) handles worker selection. /// Used when an external orchestrator (e.g., EPP) handles worker selection.
Direct, Direct,
LeastLoaded, LeastLoaded,
DeviceAwareWeighted,
} }
impl From<RouterMode> for RsRouterMode { impl From<RouterMode> for RsRouterMode {
...@@ -63,6 +64,7 @@ impl From<RouterMode> for RsRouterMode { ...@@ -63,6 +64,7 @@ impl From<RouterMode> for RsRouterMode {
RouterMode::KV => Self::KV, RouterMode::KV => Self::KV,
RouterMode::Direct => Self::Direct, RouterMode::Direct => Self::Direct,
RouterMode::LeastLoaded => Self::LeastLoaded, RouterMode::LeastLoaded => Self::LeastLoaded,
RouterMode::DeviceAwareWeighted => Self::DeviceAwareWeighted,
} }
} }
} }
......
...@@ -1132,6 +1132,7 @@ class RouterMode: ...@@ -1132,6 +1132,7 @@ class RouterMode:
KV: "RouterMode" KV: "RouterMode"
Direct: "RouterMode" Direct: "RouterMode"
LeastLoaded: "RouterMode" LeastLoaded: "RouterMode"
DeviceAwareWeighted: "RouterMode"
... ...
class RouterConfig: class RouterConfig:
...@@ -1152,7 +1153,7 @@ class RouterConfig: ...@@ -1152,7 +1153,7 @@ class RouterConfig:
Create a RouterConfig. Create a RouterConfig.
Args: Args:
mode: The router mode (RoundRobin, Random, KV, Direct, or LeastLoaded) mode: The router mode (RoundRobin, Random, KV, Direct, LeastLoaded, or DeviceAwareWeighted)
config: Optional KV router configuration (used when mode is KV) config: Optional KV router configuration (used when mode is KV)
active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection
active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection
......
...@@ -590,6 +590,7 @@ impl ModelManager { ...@@ -590,6 +590,7 @@ impl ModelManager {
component: router_endpoint_id.component.clone(), component: router_endpoint_id.component.clone(),
endpoint: router_endpoint_id.name.clone(), endpoint: router_endpoint_id.name.clone(),
transport, transport,
device_type: None,
}; };
discovery.register(discovery_spec).await?; discovery.register(discovery_spec).await?;
......
...@@ -333,7 +333,8 @@ where ...@@ -333,7 +333,8 @@ where
RouterMode::Random RouterMode::Random
| RouterMode::RoundRobin | RouterMode::RoundRobin
| RouterMode::PowerOfTwoChoices | RouterMode::PowerOfTwoChoices
| RouterMode::LeastLoaded => ServiceBackend::from_engine(Arc::new(router)), | RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted => ServiceBackend::from_engine(Arc::new(router)),
RouterMode::KV => { RouterMode::KV => {
let Some(chooser) = chooser else { let Some(chooser) = chooser else {
anyhow::bail!("RouterMode::KV requires KVRouter to not be null"); anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
......
...@@ -77,6 +77,13 @@ pub enum TransportType { ...@@ -77,6 +77,13 @@ pub enum TransportType {
Tcp(String), Tcp(String),
} }
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum DeviceType {
Cpu,
Cuda,
}
#[derive(Default)] #[derive(Default)]
pub struct RegistryInner { pub struct RegistryInner {
pub(crate) services: HashMap<String, Service>, pub(crate) services: HashMap<String, Service>,
...@@ -94,6 +101,8 @@ pub struct Instance { ...@@ -94,6 +101,8 @@ pub struct Instance {
pub namespace: String, pub namespace: String,
pub instance_id: u64, pub instance_id: u64,
pub transport: TransportType, pub transport: TransportType,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub device_type: Option<DeviceType>,
} }
impl Instance { impl Instance {
......
...@@ -10,7 +10,7 @@ use educe::Educe; ...@@ -10,7 +10,7 @@ use educe::Educe;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::{ use crate::{
component::{Endpoint, Instance, TransportType}, component::{DeviceType, Endpoint, Instance, TransportType},
distributed::RequestPlaneMode, distributed::RequestPlaneMode,
pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint}, pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint},
protocols::EndpointId, protocols::EndpointId,
...@@ -18,6 +18,35 @@ use crate::{ ...@@ -18,6 +18,35 @@ use crate::{
transports::nats, transports::nats,
}; };
fn endpoint_device_type() -> Option<DeviceType> {
// Common CUDA masks that explicitly disable GPU visibility.
if std::env::var("CUDA_VISIBLE_DEVICES")
.ok()
.map(|v| {
let l = v.trim().to_ascii_lowercase();
l.is_empty() || l == "-1" || l == "none" || l == "void"
})
.unwrap_or(false)
{
return Some(DeviceType::Cpu);
}
// Container runtimes often use NVIDIA_VISIBLE_DEVICES to gate GPU visibility.
if std::env::var("NVIDIA_VISIBLE_DEVICES")
.ok()
.map(|v| {
let l = v.trim().to_ascii_lowercase();
l == "none" || l == "void"
})
.unwrap_or(false)
{
return Some(DeviceType::Cpu);
}
// Default: no explicit CPU override means this endpoint is CUDA-capable.
Some(DeviceType::Cuda)
}
#[derive(Educe, Builder, Dissolve)] #[derive(Educe, Builder, Dissolve)]
#[educe(Debug)] #[educe(Debug)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))] #[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
...@@ -124,6 +153,7 @@ impl EndpointConfigBuilder { ...@@ -124,6 +153,7 @@ impl EndpointConfigBuilder {
namespace: endpoint_id.namespace.clone(), namespace: endpoint_id.namespace.clone(),
instance_id: connection_id, instance_id: connection_id,
transport, transport,
device_type: endpoint_device_type(),
}; };
tracing::debug!(endpoint_name = %endpoint.name, "Registering endpoint health check target"); tracing::debug!(endpoint_name = %endpoint.name, "Registering endpoint health check target");
let guard = system_health.lock(); let guard = system_health.lock();
...@@ -202,6 +232,7 @@ impl EndpointConfigBuilder { ...@@ -202,6 +232,7 @@ impl EndpointConfigBuilder {
component: endpoint_id.component.clone(), component: endpoint_id.component.clone(),
endpoint: endpoint_id.name.clone(), endpoint: endpoint_id.name.clone(),
transport, transport,
device_type: endpoint_device_type(),
}; };
if let Err(e) = discovery.register(discovery_spec).await { if let Err(e) = discovery.register(discovery_spec).await {
...@@ -341,6 +372,7 @@ impl Endpoint { ...@@ -341,6 +372,7 @@ impl Endpoint {
endpoint: endpoint_id.name, endpoint: endpoint_id.name,
instance_id, instance_id,
transport, transport,
device_type: endpoint_device_type(),
}); });
let discovery = drt.discovery(); let discovery = drt.discovery();
...@@ -382,6 +414,7 @@ impl Endpoint { ...@@ -382,6 +414,7 @@ impl Endpoint {
component: endpoint_id.component, component: endpoint_id.component,
endpoint: endpoint_id.name, endpoint: endpoint_id.name,
transport, transport,
device_type: endpoint_device_type(),
}; };
let discovery = drt.discovery(); let discovery = drt.discovery();
......
...@@ -613,6 +613,7 @@ mod tests { ...@@ -613,6 +613,7 @@ mod tests {
component: "comp1".to_string(), component: "comp1".to_string(),
endpoint: "ep1".to_string(), endpoint: "ep1".to_string(),
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
device_type: None,
}; };
let instance = client.register(spec).await.unwrap(); let instance = client.register(spec).await.unwrap();
...@@ -638,6 +639,7 @@ mod tests { ...@@ -638,6 +639,7 @@ mod tests {
namespace: "ns1".to_string(), namespace: "ns1".to_string(),
component: "comp1".to_string(), component: "comp1".to_string(),
endpoint: "ep1".to_string(), endpoint: "ep1".to_string(),
device_type: None,
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
}; };
client.register(spec1).await.unwrap(); client.register(spec1).await.unwrap();
...@@ -645,6 +647,7 @@ mod tests { ...@@ -645,6 +647,7 @@ mod tests {
let spec2 = DiscoverySpec::Endpoint { let spec2 = DiscoverySpec::Endpoint {
namespace: "ns1".to_string(), namespace: "ns1".to_string(),
component: "comp1".to_string(), component: "comp1".to_string(),
device_type: None,
endpoint: "ep2".to_string(), endpoint: "ep2".to_string(),
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
}; };
...@@ -652,6 +655,7 @@ mod tests { ...@@ -652,6 +655,7 @@ mod tests {
let spec3 = DiscoverySpec::Endpoint { let spec3 = DiscoverySpec::Endpoint {
namespace: "ns2".to_string(), namespace: "ns2".to_string(),
device_type: None,
component: "comp2".to_string(), component: "comp2".to_string(),
endpoint: "ep1".to_string(), endpoint: "ep1".to_string(),
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
...@@ -699,6 +703,7 @@ mod tests { ...@@ -699,6 +703,7 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let spec = DiscoverySpec::Endpoint { let spec = DiscoverySpec::Endpoint {
device_type: None,
namespace: "test".to_string(), namespace: "test".to_string(),
component: "comp1".to_string(), component: "comp1".to_string(),
endpoint: "ep1".to_string(), endpoint: "ep1".to_string(),
......
...@@ -404,6 +404,7 @@ mod tests { ...@@ -404,6 +404,7 @@ mod tests {
endpoint: "ep1".to_string(), endpoint: "ep1".to_string(),
instance_id: 123, instance_id: 123,
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
device_type: None,
}); });
metadata.register_endpoint(instance).unwrap(); metadata.register_endpoint(instance).unwrap();
...@@ -436,6 +437,7 @@ mod tests { ...@@ -436,6 +437,7 @@ mod tests {
endpoint: format!("ep{}", i), endpoint: format!("ep{}", i),
instance_id: i, instance_id: i,
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
device_type: None,
}); });
meta.register_endpoint(instance).unwrap(); meta.register_endpoint(instance).unwrap();
}) })
...@@ -464,6 +466,7 @@ mod tests { ...@@ -464,6 +466,7 @@ mod tests {
endpoint: format!("ep{}", i), endpoint: format!("ep{}", i),
instance_id: i, instance_id: i,
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
device_type: None,
}); });
metadata.register_endpoint(instance).unwrap(); metadata.register_endpoint(instance).unwrap();
} }
...@@ -551,6 +554,7 @@ mod tests { ...@@ -551,6 +554,7 @@ mod tests {
endpoint: "ep1".to_string(), endpoint: "ep1".to_string(),
instance_id: 1, instance_id: 1,
transport: TransportType::Nats("nats://localhost:4222".to_string()), transport: TransportType::Nats("nats://localhost:4222".to_string()),
device_type: None,
}); });
metadata.register_endpoint(endpoint).unwrap(); metadata.register_endpoint(endpoint).unwrap();
......
...@@ -292,6 +292,7 @@ mod tests { ...@@ -292,6 +292,7 @@ mod tests {
component: "test-comp".to_string(), component: "test-comp".to_string(),
endpoint: "test-ep".to_string(), endpoint: "test-ep".to_string(),
transport: crate::component::TransportType::Nats("test-subject".to_string()), transport: crate::component::TransportType::Nats("test-subject".to_string()),
device_type: None,
}; };
let query = DiscoveryQuery::Endpoint { let query = DiscoveryQuery::Endpoint {
......
...@@ -20,7 +20,7 @@ mod kube; ...@@ -20,7 +20,7 @@ mod kube;
pub use kube::{KubeDiscoveryClient, hash_pod_name}; pub use kube::{KubeDiscoveryClient, hash_pod_name};
pub mod utils; pub mod utils;
use crate::component::TransportType; use crate::component::{DeviceType, TransportType};
pub use utils::watch_and_extract_field; pub use utils::watch_and_extract_field;
/// Transport kind for event plane - used for configuration and env var selection. /// Transport kind for event plane - used for configuration and env var selection.
...@@ -298,6 +298,9 @@ pub enum DiscoverySpec { ...@@ -298,6 +298,9 @@ pub enum DiscoverySpec {
endpoint: String, endpoint: String,
/// Transport type and routing information /// Transport type and routing information
transport: TransportType, transport: TransportType,
/// Optional execution device for this endpoint instance.
/// Used by hetero routing to distinguish CPU and CUDA workers.
device_type: Option<DeviceType>,
}, },
Model { Model {
namespace: String, namespace: String,
...@@ -368,12 +371,14 @@ impl DiscoverySpec { ...@@ -368,12 +371,14 @@ impl DiscoverySpec {
component, component,
endpoint, endpoint,
transport, transport,
device_type,
} => DiscoveryInstance::Endpoint(crate::component::Instance { } => DiscoveryInstance::Endpoint(crate::component::Instance {
namespace, namespace,
component, component,
endpoint, endpoint,
instance_id, instance_id,
transport, transport,
device_type,
}), }),
Self::Model { Self::Model {
namespace, namespace,
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
use super::{AsyncEngineContextProvider, ResponseStream}; use super::{AsyncEngineContextProvider, ResponseStream};
use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain}; use crate::error::{BackendError, DynamoError, ErrorType, match_error_chain};
use crate::{ use crate::{
component::{Client, Endpoint, RoutingOccupancyState, get_or_create_routing_occupancy_state}, component::{
Client, DeviceType, Endpoint, RoutingOccupancyState, get_or_create_routing_occupancy_state,
},
dynamo_nvtx_range, dynamo_nvtx_range,
engine::{AsyncEngine, AsyncEngineContext, Data}, engine::{AsyncEngine, AsyncEngineContext, Data},
metrics::frontend_perf::STAGE_DURATION_SECONDS, metrics::frontend_perf::STAGE_DURATION_SECONDS,
...@@ -20,6 +22,7 @@ use futures::Stream; ...@@ -20,6 +22,7 @@ use futures::Stream;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
collections::HashMap,
marker::PhantomData, marker::PhantomData,
pin::Pin, pin::Pin,
sync::{ sync::{
...@@ -163,6 +166,8 @@ pub enum RouterMode { ...@@ -163,6 +166,8 @@ pub enum RouterMode {
KV, KV,
Direct, Direct,
LeastLoaded, LeastLoaded,
/// Device-aware weighted routing for heterogeneous workers.
DeviceAwareWeighted,
} }
impl RouterMode { impl RouterMode {
...@@ -201,6 +206,56 @@ fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64] ...@@ -201,6 +206,56 @@ fn p2c_select_from(occupancy_state: &RoutingOccupancyState, instance_ids: &[u64]
selected selected
} }
/// Select the target device group for the next request in `DeviceAwareWeighted` mode.
///
/// If only one class exists (all CPU or all non-CPU), returns that class directly.
/// If both classes exist, compares capability-normalized load and returns the less-loaded group.
///
/// Budget check (integer form):
/// `allowed_cpu_inflight = total_non_cpu_inflight * cpu_count / (ratio * non_cpu_count)`
/// and choose CPU when `total_cpu_inflight < allowed_cpu_inflight`.
///
/// `ratio` is `non_cpu_to_cpu_ratio` (from `DYN_ENCODER_CUDA_TO_CPU_RATIO`,
/// default `8` in `device_aware_weighted`).
fn device_aware_candidate_group(
state: &RoutingOccupancyState,
instance_ids: &[u64],
device_type_map: &HashMap<u64, Option<DeviceType>>,
non_cpu_to_cpu_ratio: usize,
) -> Vec<u64> {
let cpu_ids: Vec<u64> = instance_ids
.iter()
.copied()
.filter(|id| matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
.collect();
let non_cpu_ids: Vec<u64> = instance_ids
.iter()
.copied()
.filter(|id| !matches!(device_type_map.get(id), Some(Some(DeviceType::Cpu))))
.collect();
if cpu_ids.is_empty() {
return non_cpu_ids;
}
if non_cpu_ids.is_empty() {
return cpu_ids;
}
// Both classes exist: compute a budget for CPU in-flight requests.
let total_non_cpu_inflight: u64 = non_cpu_ids.iter().map(|id| state.load(*id)).sum();
let total_cpu_inflight: u64 = cpu_ids.iter().map(|id| state.load(*id)).sum();
let cpu_count = cpu_ids.len() as u64;
let non_cpu_count = non_cpu_ids.len() as u64;
let allowed_cpu_inflight = total_non_cpu_inflight.saturating_mul(cpu_count)
/ ((non_cpu_to_cpu_ratio as u64).saturating_mul(non_cpu_count));
if total_cpu_inflight < allowed_cpu_inflight {
cpu_ids
} else {
non_cpu_ids
}
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> { async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
// Get network manager and create client (no mode checks!) // Get network manager and create client (no mode checks!)
let manager = endpoint.drt().network_manager(); let manager = endpoint.drt().network_manager();
...@@ -238,7 +293,9 @@ where ...@@ -238,7 +293,9 @@ where
let occupancy_state = if matches!( let occupancy_state = if matches!(
router_mode, router_mode,
RouterMode::PowerOfTwoChoices | RouterMode::LeastLoaded RouterMode::PowerOfTwoChoices
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted
) { ) {
Some(get_or_create_routing_occupancy_state(&client.endpoint).await) Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
} else { } else {
...@@ -274,7 +331,9 @@ where ...@@ -274,7 +331,9 @@ where
let occupancy_state = if matches!( let occupancy_state = if matches!(
router_mode, router_mode,
RouterMode::PowerOfTwoChoices | RouterMode::LeastLoaded RouterMode::PowerOfTwoChoices
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted
) { ) {
Some(get_or_create_routing_occupancy_state(&client.endpoint).await) Some(get_or_create_routing_occupancy_state(&client.endpoint).await)
} else { } else {
...@@ -394,6 +453,84 @@ where ...@@ -394,6 +453,84 @@ where
.await .await
} }
/// Issue a request using device-aware weighted routing.
///
/// Instances are partitioned by device type (CPU vs non-CPU), then the router
/// applies a budget policy and selects the least-loaded instance within the
/// chosen group.
///
/// If only one device class exists (all CPU or all non-CPU), this naturally
/// degenerates to least-loaded routing over the available instances.
pub async fn device_aware_weighted(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let state = self.occupancy_state()?;
let instance_ids = self
.client
.instance_ids_avail()
.iter()
.copied()
.collect::<Vec<_>>();
if instance_ids.is_empty() {
return Err(anyhow::anyhow!(
"no instances found for endpoint {}",
self.client.endpoint.id()
));
}
// Apply a unified policy for all endpoints.
let endpoint_id = self.client.endpoint.id();
// For encoder endpoints, partition by device type
let instances = self.client.instances();
let device_type_map: std::collections::HashMap<u64, Option<DeviceType>> = instances
.iter()
.map(|inst| (inst.instance_id, inst.device_type.clone()))
.collect();
// Apply budget-based routing to determine which group to send to
let cuda_to_cpu_ratio = std::env::var("DYN_ENCODER_CUDA_TO_CPU_RATIO")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.filter(|v| *v >= 1)
.unwrap_or(8);
let candidates = device_aware_candidate_group(
state.as_ref(),
&instance_ids,
&device_type_map,
cuda_to_cpu_ratio,
);
// Select least-loaded within the chosen group
let instance_id = state
.select_exact_min_and_increment(&candidates)
.await
.ok_or_else(|| {
anyhow::anyhow!(
"no instances in selected device group for endpoint {}",
endpoint_id
)
})?;
let permit = OccupancyPermit::new(state.clone(), instance_id);
let is_cpu = matches!(
device_type_map.get(&instance_id),
Some(Some(DeviceType::Cpu))
);
tracing::info!(
endpoint = %endpoint_id,
selected_instance = instance_id,
is_cpu,
"DeviceAwareWeighted selected instance"
);
match self
.generate_with_fault_detection(instance_id, request)
.await
{
Ok(stream) => Ok(permit.into_tracked_stream(stream)),
Err(err) => Err(err),
}
}
/// Issue a request to the instance with the fewest active connections. /// Issue a request to the instance with the fewest active connections.
pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn least_loaded(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let state = self.occupancy_state()?; let state = self.occupancy_state()?;
...@@ -446,7 +583,10 @@ where ...@@ -446,7 +583,10 @@ where
let counter = rand::rng().random::<u64>() as usize; let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count]) Some(instance_ids[counter % count])
} }
RouterMode::PowerOfTwoChoices | RouterMode::Direct | RouterMode::LeastLoaded => None, RouterMode::PowerOfTwoChoices
| RouterMode::Direct
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted => None,
RouterMode::KV => { RouterMode::KV => {
panic!( panic!(
"select_next_worker should not be called for {:?} routing mode", "select_next_worker should not be called for {:?} routing mode",
...@@ -478,7 +618,10 @@ where ...@@ -478,7 +618,10 @@ where
let counter = rand::rng().random::<u64>() as usize; let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count]) Some(instance_ids[counter % count])
} }
RouterMode::PowerOfTwoChoices | RouterMode::Direct | RouterMode::LeastLoaded => None, RouterMode::PowerOfTwoChoices
| RouterMode::Direct
| RouterMode::LeastLoaded
| RouterMode::DeviceAwareWeighted => None,
RouterMode::KV => { RouterMode::KV => {
panic!( panic!(
"peek_next_worker should not be called for {:?} routing mode", "peek_next_worker should not be called for {:?} routing mode",
...@@ -730,6 +873,7 @@ where ...@@ -730,6 +873,7 @@ where
); );
} }
RouterMode::LeastLoaded => self.least_loaded(request).await, RouterMode::LeastLoaded => self.least_loaded(request).await,
RouterMode::DeviceAwareWeighted => self.device_aware_weighted(request).await,
} }
} }
} }
...@@ -959,6 +1103,117 @@ mod tests { ...@@ -959,6 +1103,117 @@ mod tests {
rt.shutdown(); rt.shutdown();
} }
#[tokio::test]
async fn device_aware_cpu_only_selects_least_loaded_instance() {
let state = RoutingOccupancyState::default();
// All candidates are CPU. Make worker 2 the least-loaded one.
for _ in 0..3 {
state.increment(1);
}
state.increment(3);
let instance_ids = vec![1, 2, 3];
let device_type_map = HashMap::from([
(1, Some(DeviceType::Cpu)),
(2, Some(DeviceType::Cpu)),
(3, Some(DeviceType::Cpu)),
]);
let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
assert_eq!(candidates, vec![1, 2, 3]);
let selected = state
.select_exact_min_and_increment(&candidates)
.await
.unwrap();
assert_eq!(selected, 2);
}
#[tokio::test]
async fn device_aware_non_cpu_only_selects_least_loaded_instance() {
let state = RoutingOccupancyState::default();
// All candidates are non-CPU. Make worker 2 the least-loaded one.
for _ in 0..3 {
state.increment(1);
}
state.increment(3);
let instance_ids = vec![1, 2, 3];
let device_type_map = HashMap::from([
(1, Some(DeviceType::Cuda)),
(2, Some(DeviceType::Cuda)),
(3, Some(DeviceType::Cuda)),
]);
let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 8);
assert_eq!(candidates, vec![1, 2, 3]);
let selected = state
.select_exact_min_and_increment(&candidates)
.await
.unwrap();
assert_eq!(selected, 2);
}
#[test]
fn device_aware_group_uses_ratio_budget() {
let state = RoutingOccupancyState::default();
// CPU ids: 1,2 ; non-CPU ids: 3,4
for _ in 0..4 {
state.increment(3);
state.increment(4);
}
// CPU inflight can differ across instances; budgeting uses total CPU inflight.
for _ in 0..3 {
state.increment(1);
}
// total_non_cpu_inflight=8, cpu_count=2, non_cpu_count=2, ratio=2
// allowed_cpu_inflight = 8*2/(2*2)=4
// total_cpu_inflight=3 < 4 => choose CPU group.
let instance_ids = vec![1, 2, 3, 4];
let device_type_map = HashMap::from([
(1, Some(DeviceType::Cpu)),
(2, Some(DeviceType::Cpu)),
(3, Some(DeviceType::Cuda)),
(4, Some(DeviceType::Cuda)),
]);
let candidates = device_aware_candidate_group(&state, &instance_ids, &device_type_map, 2);
assert_eq!(candidates, vec![1, 2]);
// Within selected CPU group, final choice should be the least-loaded instance (id=2).
let selected =
futures::executor::block_on(state.select_exact_min_and_increment(&candidates)).unwrap();
assert_eq!(selected, 2);
}
#[tokio::test]
async fn device_aware_weighted_select_and_peek_return_none_with_available_worker() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_device_aware_router".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let router =
PushRouter::<u64, TestResponse>::from_client(client, RouterMode::DeviceAwareWeighted)
.await
.unwrap();
assert_eq!(router.select_next_worker(), None);
assert_eq!(router.peek_next_worker(), None);
rt.shutdown();
}
/// When the router selects an instance that has deregistered between selection /// When the router selects an instance that has deregistered between selection
/// and transport resolution, it should fall back to another available instance /// and transport resolution, it should fall back to another available instance
/// rather than returning a 500 error. /// rather than returning a 500 error.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment