Unverified Commit acc2327b authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Move deep gemm related arguments to `sglang.srt.environ` (#11547)

parent bfadb5ea
......@@ -144,7 +144,7 @@ With data parallelism attention enabled, we have achieved up to **1.9x** decodin
- **DeepGEMM**: The [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) kernel library optimized for FP8 matrix multiplications.
**Usage**: The activation and weight optimization above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGL_ENABLE_JIT_DEEPGEMM=0`.
**Usage**: The activation and weight optimization above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`.
Before serving the DeepSeek model, precompile the DeepGEMM kernels using:
```bash
......
......@@ -32,9 +32,9 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGL_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels | `"true"` |
| `SGL_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` |
| `SGL_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |
| `SGLANG_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels | `"true"` |
| `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` |
| `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |
| `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` |
| `SGL_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` |
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
......
......@@ -80,7 +80,7 @@ spec:
value: "true"
- name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6
......@@ -217,7 +217,7 @@ spec:
value: "5"
- name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6
......
......@@ -71,7 +71,7 @@ spec:
value: "1"
- name: SGLANG_SET_CPU_AFFINITY
value: "true"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: NCCL_IB_QPS_PER_CONNECTION
value: "8"
......@@ -224,7 +224,7 @@ spec:
value: "0"
- name: SGLANG_MOONCAKE_TRANS_THREAD
value: "8"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD
value: "0"
......
......@@ -98,7 +98,7 @@ spec:
value: "1"
- name: SGLANG_SET_CPU_AFFINITY
value: "true"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: NCCL_IB_QPS_PER_CONNECTION
value: "8"
......@@ -257,7 +257,7 @@ spec:
value: "0"
- name: SGLANG_MOONCAKE_TRANS_THREAD
value: "8"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD
value: "0"
......@@ -421,7 +421,7 @@ spec:
value: "true"
- name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6
......@@ -560,7 +560,7 @@ spec:
value: "5"
- name: SGLANG_MOONCAKE_TRANS_THREAD
value: "16"
- name: SGL_ENABLE_JIT_DEEPGEMM
- name: SGLANG_ENABLE_JIT_DEEPGEMM
value: "1"
- name: NCCL_IB_HCA
value: ^=mlx5_0,mlx5_5,mlx5_6
......
......@@ -19,6 +19,7 @@ import requests
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.environ import envs
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs
......@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup
multiprocessing.set_start_method("spawn", force=True)
# Reduce warning
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
......
......@@ -180,6 +180,7 @@ class Envs:
SGLANG_EXPERT_LOCATION_UPDATER_CANARY = EnvBool(False)
SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS = EnvBool(False)
SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False)
SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR = EnvStr("/tmp")
# TBO
SGLANG_TBO_DEBUG = EnvBool(False)
......
......@@ -16,21 +16,20 @@ from __future__ import annotations
import logging
import math
import os
import time
from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
import torch.distributed
from sglang.srt.environ import envs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
from sglang.srt.utils import Withable, is_npu
_is_npu = is_npu()
......@@ -839,7 +838,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
def _dump_to_file(name, data):
save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
save_dir = envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get()
path_output = save_dir / name
logger.info(f"Write expert distribution to {path_output}")
if not save_dir.exists():
......
......@@ -7,11 +7,12 @@ from typing import Dict, List, Tuple
import torch
from tqdm import tqdm
from sglang.srt.environ import envs
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
from sglang.srt.utils import ceil_div, get_bool_env_var
logger = logging.getLogger(__name__)
......@@ -20,12 +21,9 @@ if ENABLE_JIT_DEEPGEMM:
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
)
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
_DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
# Force redirect deep_gemm cache_dir
......
import logging
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, is_blackwell
logger = logging.getLogger(__name__)
......@@ -15,7 +16,7 @@ def _compute_enable_deep_gemm():
except ImportError:
return False
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
......
......@@ -5,6 +5,7 @@ from pathlib import Path
from types import SimpleNamespace
import sglang as sgl
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
......@@ -23,44 +24,43 @@ class _BaseTestDynamicEPLB(CustomTestCase):
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph",
"--enable-eplb",
"--ep-num-redundant-experts",
"4",
"--eplb-rebalance-num-iterations",
"50",
"--expert-distribution-recorder-buffer-size",
"50",
# TODO pr-chain: enable later
# "--enable-expert-distribution-metrics",
# TODO auto determine these flags
"--expert-distribution-recorder-mode",
"stat",
"--ep-dispatch-algorithm",
"static",
*cls.extra_args,
],
env={
"SGL_ENABLE_JIT_DEEPGEMM": "0",
"SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
**os.environ,
},
)
with (
envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False),
envs.SGLANG_EXPERT_LOCATION_UPDATER_CANARY.override(True),
):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph",
"--enable-eplb",
"--ep-num-redundant-experts",
"4",
"--eplb-rebalance-num-iterations",
"50",
"--expert-distribution-recorder-buffer-size",
"50",
# TODO pr-chain: enable later
# "--enable-expert-distribution-metrics",
# TODO auto determine these flags
"--expert-distribution-recorder-mode",
"stat",
"--ep-dispatch-algorithm",
"static",
*cls.extra_args,
],
)
@classmethod
def tearDownClass(cls):
......@@ -89,7 +89,7 @@ class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
class TestStaticEPLB(CustomTestCase):
def test_save_expert_distribution_and_init_expert_location(self):
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
with tempfile.TemporaryDirectory() as tmp_dir:
engine_kwargs = dict(
......@@ -108,7 +108,7 @@ class TestStaticEPLB(CustomTestCase):
)
print(f"Action: start engine")
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.set(tmp_dir)
engine = sgl.Engine(
**engine_kwargs,
disable_overlap_schedule=True,
......
......@@ -3,6 +3,7 @@ import os
import unittest
from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
......@@ -55,48 +56,45 @@ class TestDPAttn(unittest.TestCase):
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph",
# Test custom config
"--deepep-config",
json.dumps(
{
"normal_dispatch": {
"num_sms": 20,
"num_max_nvl_chunked_send_tokens": 16,
"num_max_nvl_chunked_recv_tokens": 256,
"num_max_rdma_chunked_send_tokens": 6,
"num_max_rdma_chunked_recv_tokens": 128,
},
"normal_combine": {
"num_sms": 20,
"num_max_nvl_chunked_send_tokens": 6,
"num_max_nvl_chunked_recv_tokens": 256,
"num_max_rdma_chunked_send_tokens": 6,
"num_max_rdma_chunked_recv_tokens": 128,
},
}
),
],
env={
"SGL_ENABLE_JIT_DEEPGEMM": "0",
**os.environ,
},
)
with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph",
# Test custom config
"--deepep-config",
json.dumps(
{
"normal_dispatch": {
"num_sms": 20,
"num_max_nvl_chunked_send_tokens": 16,
"num_max_nvl_chunked_recv_tokens": 256,
"num_max_rdma_chunked_send_tokens": 6,
"num_max_rdma_chunked_recv_tokens": 128,
},
"normal_combine": {
"num_sms": 20,
"num_max_nvl_chunked_send_tokens": 6,
"num_max_nvl_chunked_recv_tokens": 256,
"num_max_rdma_chunked_send_tokens": 6,
"num_max_rdma_chunked_recv_tokens": 128,
},
}
),
],
)
@classmethod
def tearDownClass(cls):
......
import os
import time
import unittest
from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
......@@ -18,8 +17,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
......@@ -90,8 +88,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
......@@ -162,8 +159,7 @@ class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase):
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......@@ -234,8 +230,7 @@ class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase):
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......
......@@ -2,6 +2,7 @@ import os
import unittest
from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
......@@ -16,8 +17,7 @@ class TestDisaggregationDPAttention(TestDisaggregationBase):
def setUpClass(cls):
super().setUpClass()
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
......
......@@ -6,9 +6,9 @@ from pathlib import Path
import requests
import torch
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
......@@ -32,7 +32,7 @@ class TestExpertDistribution(CustomTestCase):
def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
"""Test expert distribution record endpoints"""
with tempfile.TemporaryDirectory() as tmp_dir:
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.set(tmp_dir)
process = popen_launch_server(
model_path,
......
import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
......@@ -77,14 +77,16 @@ class BaseFlashAttentionTest(CustomTestCase):
def setUpClass(cls):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(),
)
with (
envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.override(False),
envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False),
):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(),
)
@classmethod
def tearDownClass(cls):
......
......@@ -4,6 +4,7 @@ from types import SimpleNamespace
import requests
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
......@@ -49,18 +50,20 @@ class TestHybridAttnBackendBase(CustomTestCase):
def setUpClass(cls):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
if cls.speculative_decode:
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
else:
model = cls.model
cls.process = popen_launch_server(
model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(),
)
with (
envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.override(False),
envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False),
):
if cls.speculative_decode:
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
else:
model = cls.model
cls.process = popen_launch_server(
model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(),
)
@classmethod
def tearDownClass(cls):
......
import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
......@@ -47,8 +47,8 @@ class TestNgramSpeculativeDecodingBase(CustomTestCase):
def setUpClass(cls):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False)
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
model = cls.model
cls.process = popen_launch_server(
model,
......
import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
......@@ -55,8 +55,8 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
def setUpClass(cls):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False)
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
model = cls.model
cls.process = popen_launch_server(
model,
......
import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.environ import envs
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.two_batch_overlap import (
compute_split_seq_index,
......@@ -25,26 +25,26 @@ class TestTwoBatchOverlap(unittest.TestCase):
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap",
],
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
)
with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap",
],
)
@classmethod
def tearDownClass(cls):
......@@ -126,26 +126,26 @@ class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap):
cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-1234"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap",
],
env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
)
with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False):
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--dp",
"2",
"--enable-dp-attention",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"normal",
"--disable-cuda-graph", # DeepEP normal does not support CUDA Graph
"--enable-two-batch-overlap",
],
)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment