Unverified Commit 200a3c0b authored by Minglei Zhu's avatar Minglei Zhu Committed by GitHub
Browse files

[Documentation] add doc for deterministic inference (#11956)

parent 77258ce0
# Deterministic Inference
## Why Deterministic Inference Matters
Deterministic inference ensures consistent LLM outputs across runs, which is critical for:
- **Reinforcement Learning**: Ensures consistent logprobs across runs, reducing stochastic noise and making RL training more stable, reproducible, and debuggable.
- **Testing & Debugging**: Enables reproducible validation
- **Production**: Improves reliability and user experience
Even with `temperature=0`, standard LLM inference can produce different outputs due to dynamic batching and varying reduction orders in GPU kernels.
## The Root Cause of Non-Determinism
The main source is **varying batch sizes**. Different batch sizes cause GPU kernels to split reduction operations differently, leading to different addition orders. Due to floating-point non-associativity (`(a + b) + c ≠ a + (b + c)`), this produces different results even for identical inputs.
## SGLang's Solution
Building on [Thinking Machines Lab's batch-invariant operators](https://github.com/thinking-machines-lab/batch_invariant_ops), SGLang achieves fully deterministic inference while maintaining compatibility with chunked prefill, CUDA graphs, radix cache, and non-greedy sampling. The development roadmap for deterministic inference features can be found in this [issue](https://github.com/sgl-project/sglang/issues/10278).
### Supported Backends
Deterministic inference is only supported with the following three attention backends: **FlashInfer**, **FlashAttention 3 (FA3)**, and **Triton**.
The following table shows feature compatibility for deterministic inference across different attention backends:
| Attention Backend | CUDA Graph | Chunked Prefill | Radix Cache | Non-greedy Sampling (Temp > 0) |
|-------------------|------------|-----------------|-------------|---------------------|
| **FlashInfer** | ✅ Yes | ✅ Yes | ❌ No | ✅ Yes |
| **FlashAttention 3 (FA3)** | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
| **Triton** | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
## Usage
### Basic Usage
Enable deterministic inference by adding the `--enable-deterministic-inference` flag:
```bash
python3 -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
--attention-backend fa3 \
--enable-deterministic-inference
```
### Server Arguments
| Argument | Type/Default | Description |
|----------|--------------|-------------|
| `--enable-deterministic-inference` | flag; default: disabled | Enable deterministic inference with batch-invariant operations |
| `--attention-backend` | string; default: fa3 | Choose attention backend (flashinfer, fa3, or triton) |
### Example Configurations
#### Qwen3-8B
```bash
python3 -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
--attention-backend flashinfer \
--enable-deterministic-inference
```
#### Llama Models
```bash
python3 -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8B-Instruct \
--attention-backend fa3 \
--enable-deterministic-inference
```
#### Qwen3-30B-A3B (MoE Model)
```bash
python3 -m sglang.launch_server \
--model-path Qwen/Qwen3-30B-A3B \
--attention-backend fa3 \
--enable-deterministic-inference
```
### Deterministic Inference with Non-Greedy Sampling (Temperature > 0)
SGLang supports deterministic inference even with non-greedy sampling by using sampling seeds. This is particularly useful for reinforcement learning scenarios like GRPO (Group Relative Policy Optimization) where you need multiple diverse but reproducible responses.
#### Default Behavior
By default, SGLang uses a sampling seed of `42` for reproducible sampling:
```python
import requests
response = requests.post(
"http://localhost:30000/generate",
json={
"text": "Tell me a joke",
"sampling_params": {
"temperature": 0.8, # Non-greedy sampling
"max_new_tokens": 128,
},
},
)
print(response.json())
# This will always produce the same response across runs
```
#### Generating Multiple Reproducible Responses
To sample different responses from the same prompt while maintaining reproducibility (e.g., for GRPO training), provide different sampling seeds in your requests:
```python
import requests
# Prepare a list of sampling seeds for different responses
sampling_seeds = [42, 43, 44, 45, 46]
responses = []
for seed in sampling_seeds:
response = requests.post(
"http://localhost:30000/generate",
json={
"text": "Tell me a joke",
"sampling_params": {
"temperature": 0.8,
"max_new_tokens": 128,
"sampling_seed": seed, # Specify sampling seed
},
},
)
responses.append(response.json())
# Each seed will produce a different but reproducible response
# Using the same seed will always produce the same response
```
This approach ensures that:
- Different seeds produce diverse responses
- The same seed always produces the same response across different runs
- Results are reproducible for debugging and evaluation
## Verification
Run deterministic tests to verify consistent outputs:
```bash
# Single test: same prompt, varying batch sizes
python3 -m sglang.test.test_deterministic --test-mode single --n-trials 50
# Prefix test: prompts with different prefix lengths
python3 -m sglang.test.test_deterministic --test-mode prefix --n-trials 50
# Radix Cache Consistency mode: test radix cache determinism (cached vs uncached prefill)
python3 -m sglang.test.test_deterministic --test-mode radix_cache
```
Expected result: All tests should show `Unique samples: 1` (perfectly deterministic).
......@@ -50,6 +50,7 @@ Its core features include:
advanced_features/pd_multiplexing.md
advanced_features/vlm_query.ipynb
advanced_features/router.md
advanced_features/deterministic_inference.md
advanced_features/observability.md
.. toctree::
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment