Unverified Commit a80bcb5a authored by yinghui's avatar yinghui Committed by GitHub
Browse files

Add env var to disable FA4 warmup (#12430)

parent f7f9e41b
......@@ -28,6 +28,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_SKIP_P2P_CHECK` | Skip P2P (peer-to-peer) access check | `false` |
| `SGL_CHUNKED_PREFIX_CACHE_THRESHOLD` | Sets the threshold for enabling chunked prefix caching | `8192` |
| `SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION` | Enable RoPE fusion in Fused Multi-Layer Attention | `1` |
| `SGLANG_DISABLE_FA4_WARMUP` | Disable Flash Attention 4 warmup passes (set to `1`, `true`, `yes`, or `on` to disable) | `false` |
## DeepGEMM Configuration (Advanced Optimization)
......
......@@ -8,6 +8,7 @@ import copy
import gc
import logging
import math
import os
from typing import Callable, Optional, Tuple
logger = logging.getLogger(__name__)
......@@ -416,6 +417,15 @@ def warmup_flash_attn(f):
- Executes sequentially to minimize peak GPU mem
- Does not modify user tensors (clones)
"""
disable_warmup = os.getenv("SGLANG_DISABLE_FA4_WARMUP", "").lower() in (
"1",
"true",
"yes",
"on",
)
if disable_warmup:
return f
done = False
def _clone_args(args, kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment