Unverified Commit acc2327b authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Move deep gemm related arguments to `sglang.srt.environ` (#11547)

parent bfadb5ea
...@@ -144,7 +144,7 @@ With data parallelism attention enabled, we have achieved up to **1.9x** decodin ...@@ -144,7 +144,7 @@ With data parallelism attention enabled, we have achieved up to **1.9x** decodin
- **DeepGEMM**: The [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) kernel library optimized for FP8 matrix multiplications. - **DeepGEMM**: The [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) kernel library optimized for FP8 matrix multiplications.
**Usage**: The activation and weight optimization above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGL_ENABLE_JIT_DEEPGEMM=0`. **Usage**: The activation and weight optimization above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`.
Before serving the DeepSeek model, precompile the DeepGEMM kernels using: Before serving the DeepSeek model, precompile the DeepGEMM kernels using:
```bash ```bash
......
...@@ -32,9 +32,9 @@ SGLang supports various environment variables that can be used to configure its ...@@ -32,9 +32,9 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value | | Environment Variable | Description | Default Value |
| --- | --- | --- | | --- | --- | --- |
| `SGL_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels | `"true"` | | `SGLANG_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels | `"true"` |
| `SGL_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` | | `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` |
| `SGL_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` | | `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |
| `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` | | `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` |
| `SGL_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` | | `SGL_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` |
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` | | `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
......
...@@ -80,7 +80,7 @@ spec: ...@@ -80,7 +80,7 @@ spec:
value: "true" value: "true"
- name: SGLANG_MOONCAKE_TRANS_THREAD - name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16" value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: NCCL_IB_HCA - name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6 value: ^=mlx5_0,mlx5_5,mlx5_6
...@@ -217,7 +217,7 @@ spec: ...@@ -217,7 +217,7 @@ spec:
value: "5" value: "5"
- name: SGLANG_MOONCAKE_TRANS_THREAD - name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16" value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: NCCL_IB_HCA - name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6 value: ^=mlx5_0,mlx5_5,mlx5_6
......
...@@ -71,7 +71,7 @@ spec: ...@@ -71,7 +71,7 @@ spec:
value: "1" value: "1"
- name: SGLANG_SET_CPU_AFFINITY - name: SGLANG_SET_CPU_AFFINITY
value: "true" value: "true"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: NCCL_IB_QPS_PER_CONNECTION - name: NCCL_IB_QPS_PER_CONNECTION
value: "8" value: "8"
...@@ -224,7 +224,7 @@ spec: ...@@ -224,7 +224,7 @@ spec:
value: "0" value: "0"
- name: SGLANG_MOONCAKE_TRANS_THREAD - name: SGLANG_MOONCAKE_TRANS_THREAD
value: "8" value: "8"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD - name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD
value: "0" value: "0"
......
...@@ -98,7 +98,7 @@ spec: ...@@ -98,7 +98,7 @@ spec:
value: "1" value: "1"
- name: SGLANG_SET_CPU_AFFINITY - name: SGLANG_SET_CPU_AFFINITY
value: "true" value: "true"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: NCCL_IB_QPS_PER_CONNECTION - name: NCCL_IB_QPS_PER_CONNECTION
value: "8" value: "8"
...@@ -257,7 +257,7 @@ spec: ...@@ -257,7 +257,7 @@ spec:
value: "0" value: "0"
- name: SGLANG_MOONCAKE_TRANS_THREAD - name: SGLANG_MOONCAKE_TRANS_THREAD
value: "8" value: "8"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD - name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD
value: "0" value: "0"
...@@ -421,7 +421,7 @@ spec: ...@@ -421,7 +421,7 @@ spec:
value: "true" value: "true"
- name: SGLANG_MOONCAKE_TRANS_THREAD - name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16" value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: NCCL_IB_HCA - name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6 value: ^=mlx5_0,mlx5_5,mlx5_6
...@@ -560,7 +560,7 @@ spec: ...@@ -560,7 +560,7 @@ spec:
value: "5" value: "5"
- name: SGLANG_MOONCAKE_TRANS_THREAD - name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16" value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM - name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1" value: "1"
- name: NCCL_IB_HCA - name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6 value: ^=mlx5_0,mlx5_5,mlx5_6
......
...@@ -19,6 +19,7 @@ import requests ...@@ -19,6 +19,7 @@ import requests
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.environ import envs
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup ...@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup
multiprocessing.set_start_method("spawn", force=True) multiprocessing.set_start_method("spawn", force=True)
# Reduce warning # Reduce warning
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1" envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
# Force enable deep gemm # Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1" envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case # Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0" os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
......
...@@ -180,6 +180,7 @@ class Envs: ...@@ -180,6 +180,7 @@ class Envs:
SGLANG_EXPERT_LOCATION_UPDATER_CANARY = EnvBool(False) SGLANG_EXPERT_LOCATION_UPDATER_CANARY = EnvBool(False)
SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS = EnvBool(False) SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS = EnvBool(False)
SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False) SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False)
SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR = EnvStr("/tmp")
# TBO # TBO
SGLANG_TBO_DEBUG = EnvBool(False) SGLANG_TBO_DEBUG = EnvBool(False)
......
...@@ -16,21 +16,20 @@ from __future__ import annotations ...@@ -16,21 +16,20 @@ from __future__ import annotations
import logging import logging
import math import math
import os
import time import time
from abc import ABC from abc import ABC
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
import einops import einops
import torch import torch
import torch.distributed import torch.distributed
from sglang.srt.environ import envs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var, is_npu from sglang.srt.utils import Withable, is_npu
_is_npu = is_npu() _is_npu = is_npu()
...@@ -839,7 +838,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin): ...@@ -839,7 +838,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
def _dump_to_file(name, data): def _dump_to_file(name, data):
save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp")) save_dir = envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get()
path_output = save_dir / name path_output = save_dir / name
logger.info(f"Write expert distribution to {path_output}") logger.info(f"Write expert distribution to {path_output}")
if not save_dir.exists(): if not save_dir.exists():
......
...@@ -7,11 +7,12 @@ from typing import Dict, List, Tuple ...@@ -7,11 +7,12 @@ from typing import Dict, List, Tuple
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from sglang.srt.environ import envs
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
ENABLE_JIT_DEEPGEMM, ENABLE_JIT_DEEPGEMM,
) )
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var from sglang.srt.utils import ceil_div, get_bool_env_var
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -20,12 +21,9 @@ if ENABLE_JIT_DEEPGEMM: ...@@ -20,12 +21,9 @@ if ENABLE_JIT_DEEPGEMM:
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var( _ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
)
_DO_COMPILE_ALL = True _DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true") _IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false") _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
# Force redirect deep_gemm cache_dir # Force redirect deep_gemm cache_dir
......
import logging import logging
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, is_blackwell
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -15,7 +16,7 @@ def _compute_enable_deep_gemm(): ...@@ -15,7 +16,7 @@ def _compute_enable_deep_gemm():
except ImportError: except ImportError:
return False return False
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true") return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm() ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
......
...@@ -5,6 +5,7 @@ from pathlib import Path ...@@ -5,6 +5,7 @@ from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
import sglang as sgl import sglang as sgl
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -23,6 +24,10 @@ class _BaseTestDynamicEPLB(CustomTestCase): ...@@ -23,6 +24,10 @@ class _BaseTestDynamicEPLB(CustomTestCase):
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
with (
envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False),
envs.SGLANG_EXPERT_LOCATION_UPDATER_CANARY.override(True),
):
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
...@@ -55,11 +60,6 @@ class _BaseTestDynamicEPLB(CustomTestCase): ...@@ -55,11 +60,6 @@ class _BaseTestDynamicEPLB(CustomTestCase):
"static", "static",
*cls.extra_args, *cls.extra_args,
], ],
env={
"SGL_ENABLE_JIT_DEEPGEMM": "0",
"SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
**os.environ,
},
) )
@classmethod @classmethod
...@@ -89,7 +89,7 @@ class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB): ...@@ -89,7 +89,7 @@ class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
class TestStaticEPLB(CustomTestCase): class TestStaticEPLB(CustomTestCase):
def test_save_expert_distribution_and_init_expert_location(self): def test_save_expert_distribution_and_init_expert_location(self):
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0" envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
engine_kwargs = dict( engine_kwargs = dict(
...@@ -108,7 +108,7 @@ class TestStaticEPLB(CustomTestCase): ...@@ -108,7 +108,7 @@ class TestStaticEPLB(CustomTestCase):
) )
print(f"Action: start engine") print(f"Action: start engine")
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.set(tmp_dir)
engine = sgl.Engine( engine = sgl.Engine(
**engine_kwargs, **engine_kwargs,
disable_overlap_schedule=True, disable_overlap_schedule=True,
......
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -55,6 +56,7 @@ class TestDPAttn(unittest.TestCase): ...@@ -55,6 +56,7 @@ class TestDPAttn(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False):
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
...@@ -92,10 +94,6 @@ class TestDPAttn(unittest.TestCase): ...@@ -92,10 +94,6 @@ class TestDPAttn(unittest.TestCase):
} }
), ),
], ],
env={
"SGL_ENABLE_JIT_DEEPGEMM": "0",
**os.environ,
},
) )
@classmethod @classmethod
......
import os
import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -18,8 +17,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): ...@@ -18,8 +17,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
def setUpClass(cls): def setUpClass(cls):
super().setUpClass() super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
...@@ -90,8 +88,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): ...@@ -90,8 +88,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
def setUpClass(cls): def setUpClass(cls):
super().setUpClass() super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
...@@ -162,8 +159,7 @@ class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase): ...@@ -162,8 +159,7 @@ class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase):
def setUpClass(cls): def setUpClass(cls):
super().setUpClass() super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
...@@ -234,8 +230,7 @@ class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase): ...@@ -234,8 +230,7 @@ class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase):
def setUpClass(cls): def setUpClass(cls):
super().setUpClass() super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -16,8 +17,7 @@ class TestDisaggregationDPAttention(TestDisaggregationBase): ...@@ -16,8 +17,7 @@ class TestDisaggregationDPAttention(TestDisaggregationBase):
def setUpClass(cls): def setUpClass(cls):
super().setUpClass() super().setUpClass()
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
......
...@@ -6,9 +6,9 @@ from pathlib import Path ...@@ -6,9 +6,9 @@ from pathlib import Path
import requests import requests
import torch import torch
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase, CustomTestCase,
...@@ -32,7 +32,7 @@ class TestExpertDistribution(CustomTestCase): ...@@ -32,7 +32,7 @@ class TestExpertDistribution(CustomTestCase):
def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1): def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
"""Test expert distribution record endpoints""" """Test expert distribution record endpoints"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.set(tmp_dir)
process = popen_launch_server( process = popen_launch_server(
model_path, model_path,
......
import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests import requests
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -77,8 +77,10 @@ class BaseFlashAttentionTest(CustomTestCase): ...@@ -77,8 +77,10 @@ class BaseFlashAttentionTest(CustomTestCase):
def setUpClass(cls): def setUpClass(cls):
# disable deep gemm precompile to make launch server faster # disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster # please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" with (
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.override(False),
envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False),
):
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
......
...@@ -4,6 +4,7 @@ from types import SimpleNamespace ...@@ -4,6 +4,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -49,8 +50,10 @@ class TestHybridAttnBackendBase(CustomTestCase): ...@@ -49,8 +50,10 @@ class TestHybridAttnBackendBase(CustomTestCase):
def setUpClass(cls): def setUpClass(cls):
# disable deep gemm precompile to make launch server faster # disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster # please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" with (
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.override(False),
envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False),
):
if cls.speculative_decode: if cls.speculative_decode:
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
else: else:
......
import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests import requests
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -47,8 +47,8 @@ class TestNgramSpeculativeDecodingBase(CustomTestCase): ...@@ -47,8 +47,8 @@ class TestNgramSpeculativeDecodingBase(CustomTestCase):
def setUpClass(cls): def setUpClass(cls):
# disable deep gemm precompile to make launch server faster # disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster # please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False)
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
model = cls.model model = cls.model
cls.process = popen_launch_server( cls.process = popen_launch_server(
model, model,
......
import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests import requests
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -55,8 +55,8 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase): ...@@ -55,8 +55,8 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
def setUpClass(cls): def setUpClass(cls):
# disable deep gemm precompile to make launch server faster # disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster # please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False)
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
model = cls.model model = cls.model
cls.process = popen_launch_server( cls.process = popen_launch_server(
model, model,
......
import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests import requests
from sglang.srt.environ import envs
from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.two_batch_overlap import ( from sglang.srt.two_batch_overlap import (
compute_split_seq_index, compute_split_seq_index,
...@@ -25,6 +25,7 @@ class TestTwoBatchOverlap(unittest.TestCase): ...@@ -25,6 +25,7 @@ class TestTwoBatchOverlap(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False):
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
...@@ -43,7 +44,6 @@ class TestTwoBatchOverlap(unittest.TestCase): ...@@ -43,7 +44,6 @@ class TestTwoBatchOverlap(unittest.TestCase):
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph "--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap", "--enable-two-batch-overlap",
], ],
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
) )
@classmethod @classmethod
...@@ -126,6 +126,7 @@ class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap): ...@@ -126,6 +126,7 @@ class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap):
cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-1234" cls.api_key = "sk-1234"
with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False):
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
...@@ -144,7 +145,6 @@ class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap): ...@@ -144,7 +145,6 @@ class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap):
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph "--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap", "--enable-two-batch-overlap",
], ],
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment