---
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
title: SGLang for Agentic Workloads
subtitle: Priority scheduling and KV cache eviction policies for multi-turn agentic serving
---

# SGLang for Agentic Workloads

This guide covers SGLang-specific configuration for agentic serving with Dynamo. It explains which SGLang engine flags to enable and how Dynamo's [agent hints](../../components/frontend/nvext.md#agent-hints) map to SGLang behavior.

## Overview

Agentic workloads (tool-calling loops, multi-turn reasoning, code generation pipelines) have different performance characteristics than batch inference:

- **Prefix-heavy**: Successive turns share a growing conversation prefix. KV cache reuse is critical for low TTFT.
- **Priority-sensitive**: Some requests (user-facing agent turns) matter more than background tasks.
- **Long-lived**: Conversations span minutes to hours. Cache eviction under memory pressure can destroy accumulated KV state.

Dynamo's agent hints give the router per-request metadata. SGLang's engine flags control how that metadata affects scheduling and eviction on the worker.

## SGLang Engine Flags

### Priority Scheduling

Enable priority-based scheduling so the engine respects the `priority` value from `nvext.agent_hints.priority`:

```bash
python -m dynamo.sglang \
  --model-path \
  --enable-priority-scheduling \
  ...
```

| Flag | Description |
|------|-------------|
| `--enable-priority-scheduling` | Enables priority-based request scheduling instead of FCFS. |

When priority scheduling is enabled, the engine uses the `priority` field from `nvext.agent_hints` to order requests in its internal queue. Requests with higher effective priority are scheduled before lower-priority ones. Ties are broken by arrival time.

### Priority-Based KV Cache Eviction

By default, SGLang evicts radix tree nodes using LRU.
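For intuition about what the eviction policy changes, the toy model below computes the order in which leaf nodes would be evicted under the default `lru` policy and under the `priority` policy described next. It is a hypothetical sketch, not SGLang's radix cache implementation: `Leaf` and `eviction_order` are invented names, larger `priority` values are assumed to mean "more important", and the model encodes only the rules stated on this page (least recently used first for `lru`; lowest priority first, with LRU among equal priorities, for `priority`).

```python
import heapq
from dataclasses import dataclass


@dataclass
class Leaf:
    """Hypothetical stand-in for a radix tree leaf (not an SGLang class)."""
    name: str
    priority: int       # assumption: larger value = more important
    last_access: float  # time of the most recent cache hit


def eviction_order(leaves, policy="lru"):
    """Return leaf names in the order this toy cache would evict them."""
    if policy == "lru":
        # Least recently used first.
        def key(leaf):
            return (leaf.last_access,)
    elif policy == "priority":
        # Lowest priority first; equal priorities fall back to LRU.
        def key(leaf):
            return (leaf.priority, leaf.last_access)
    else:
        raise ValueError(f"unknown policy: {policy!r}")
    # Min-heap of (sort key, insertion index, name); the index breaks exact ties.
    heap = [(key(leaf), i, leaf.name) for i, leaf in enumerate(leaves)]
    heapq.heapify(heap)
    return [heapq.heappop(heap)[2] for _ in range(len(heap))]


leaves = [
    Leaf("background-task", priority=0, last_access=3.0),
    Leaf("agent-turn", priority=10, last_access=1.0),
    Leaf("old-background", priority=0, last_access=2.0),
]
print(eviction_order(leaves, "lru"))
# ['agent-turn', 'old-background', 'background-task']
print(eviction_order(leaves, "priority"))
# ['old-background', 'background-task', 'agent-turn']
```

Under `lru`, the idle agent turn is evicted first even though it carries the highest priority; under `priority`, it survives until both background leaves are gone. The sketch ignores real-cache details such as nodes that are still in use by running requests.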
You can switch to priority-based eviction so that low-priority cache entries are evicted before high-priority ones:

```bash
python -m dynamo.sglang \
  --model-path \
  --radix-eviction-policy priority \
  ...
```

| Flag | Values | Default | Description |
|------|--------|---------|-------------|
| `--radix-eviction-policy` | `lru`, `priority` | `lru` | Eviction strategy for the GPU radix cache. `priority` uses a heap ordered by the request's priority value. |

This does **not** require HiCache; on its own, the flag controls eviction from the GPU radix tree only. When the GPU KV cache is full:

- **`lru`**: Evicts the least recently used leaf nodes first.
- **`priority`**: Evicts the lowest-priority leaf nodes first. Nodes with equal priority fall back to LRU ordering.

#### Interaction with HiCache

When both `--radix-eviction-policy priority` and `--enable-hierarchical-cache` are set, priority affects eviction at both tiers:

| Event | Behavior |
|-------|----------|
| **GPU full** | Low-priority nodes are evicted (demoted to host) first. With `write_through`, all nodes survive on host, so priority only affects demotion order. |
| **Host full** | Low-priority nodes are deleted from host first. High-priority nodes survive longer. Pinned nodes are skipped entirely. |

The practical impact depends on your write policy. With `write_through`, GPU eviction is only a demotion; the actual deletion happens at host eviction, which is where priority ordering matters most.

## How Agent Hints Map to SGLang

Dynamo's `nvext.agent_hints` fields are consumed by the router and forwarded to SGLang workers. The table below shows how each hint interacts with the router and the SGLang engine:

| Agent Hint | Router Behavior | SGLang Engine Behavior |
|------------|-----------------|------------------------|
| `priority` | Raises router queue priority when `--router-queue-threshold` is set. | Orders the engine queue when `--enable-priority-scheduling` is set. Also affects radix cache eviction order when `--radix-eviction-policy priority` is set. |
| `osl` | Output block tracking for routing decisions (requires `--router-track-output-blocks`). | No direct engine effect. |
| `speculative_prefill` | After the response completes, sends a `max_tokens=1` prefill request to warm the KV cache for the predicted next turn. | SGLang processes the prefill request normally, populating the radix cache. |

### Example: Agentic Request with Hints

The following streaming request sets all three hints through the OpenAI client's `extra_body`:

```python
from openai import OpenAI

# Point the OpenAI-compatible client at the Dynamo frontend.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="Qwen/Qwen3-14B-FP8",
    messages=[
        {"role": "system", "content": "You are a coding assistant."},
        {"role": "user", "content": "Write a Python function to parse CSV files."},
    ],
    stream=True,
    # extra_body fields are merged into the JSON request body, so "nvext" reaches Dynamo.
    extra_body={
        "nvext": {
            "agent_hints": {
                "priority": 10,
                "speculative_prefill": True,
                "osl": 512,
            }
        }
    },
)

for chunk in response:
    # Some stream chunks may carry no choices, so guard before indexing.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()
```

## See Also

- **[NVIDIA Request Extensions (nvext)](../../components/frontend/nvext.md)**: Full `nvext` field reference, including agent hints
- **[Router Guide](../../components/router/router-guide.md)**: Router configuration and CLI arguments
- **[SGLang HiCache](../../integrations/sglang-hicache.md)**: Enabling hierarchical KV cache