Unverified commit cd7e1bd5, authored by Lianmin Zheng, committed by GitHub

Sync code and test CI; rename some env vars (#11686)

parent 729b7edf
@@ -82,7 +82,7 @@ SGLang supports various environment variables that can be used to configure its
 | Environment Variable | Description | Default Value |
 | --- | --- | --- |
 | `SGLANG_IS_IN_CI` | Indicates if running in CI environment | `false` |
-| `SGLANG_AMD_CI` | Indicates running in AMD CI environment | `0` |
+| `SGLANG_IS_IN_CI_AMD` | Indicates running in AMD CI environment | `0` |
 | `SGLANG_TEST_RETRACT` | Enable retract decode testing | `false` |
 | `SGLANG_RECORD_STEP_TIME` | Record step time for profiling | `false` |
 | `SGLANG_TEST_REQUEST_TIME_STATS` | Test request time statistics | `false` |
...
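As an aside (not part of the diff): after the rename, an AMD CI run would set the new variable name. A minimal standalone sketch, using the `1` value that the CI scripts in this commit use:

```python
import os

# Hypothetical snippet: mark the process as running in (AMD) CI
# before importing sglang, using the renamed variable.
os.environ["SGLANG_IS_IN_CI"] = "1"
os.environ["SGLANG_IS_IN_CI_AMD"] = "1"
```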
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0

 logger = logging.getLogger(__name__)
@@ -301,11 +301,11 @@ class CustomAllreduce:
         if _is_hip:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
             handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
-            logger.info("Registering %d cuda graph addresses", len(offset))
+            log_info_on_rank0(logger, f"Registering {len(offset)} cuda graph addresses")
             ops.register_graph_buffers(self._ptr, handles, offsets)
         else:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-            logger.info("Registering %d cuda graph addresses", len(offset))
+            log_info_on_rank0(logger, f"Registering {len(offset)} cuda graph addresses")
             # We cannot directly use `dist.all_gather_object` here
             # because it is incompatible with `gloo` backend under inference mode.
             # see https://github.com/pytorch/pytorch/issues/126032 for details.
...
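`log_info_on_rank0` is imported from `sglang.srt.utils`; its implementation is not shown in this diff. A minimal sketch of what such a helper could look like, assuming it gates on the default process-group rank (the real utility may differ):

```python
import logging

import torch.distributed as dist


def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
    # Emit the message only from rank 0 so multi-GPU runs do not
    # print one copy of the line per process.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)
```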
@@ -113,7 +113,7 @@ class Envs:
     # Test & Debug
     SGLANG_IS_IN_CI = EnvBool(False)
-    SGLANG_AMD_CI = EnvBool(False)
+    SGLANG_IS_IN_CI_AMD = EnvBool(False)
     SGLANG_TEST_RETRACT = EnvBool(False)
     SGLANG_SET_CPU_AFFINITY = EnvBool(False)
     SGLANG_PROFILE_WITH_STACK = EnvBool(True)
@@ -197,12 +197,12 @@ class Envs:
     # sgl-kernel
     SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)

-    # vLLM dependencies
+    # vLLM dependencies (TODO: they have been deprecated, we can remove them safely)
     USE_VLLM_CUSTOM_ALLREDUCE = EnvBool(False)
     USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False)
     USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False)
-    RETURN_ORIGINAL_LOGPROB = EnvBool(False)
+    SGLANG_RETURN_ORIGINAL_LOGPROB = EnvBool(False)
     SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN = EnvBool(False)
     SGLANG_MOE_PADDING = EnvBool(False)
     SGLANG_CUTLASS_MOE = EnvBool(False)
...
@@ -65,7 +65,7 @@ class LogitsProcessorOutput:
     hidden_states: Optional[torch.Tensor] = None

     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    # he log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
+    # The log probs of output tokens. If SGLANG_RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, the log probs after applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
...
@@ -13,6 +13,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     is_dp_attention_enabled,
 )
+from sglang.srt.utils import log_info_on_rank0

 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
@@ -159,8 +160,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
 def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
-        logger.warning(
-            "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
+        log_info_on_rank0(
+            logger,
+            "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected",
         )
         MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
...
@@ -27,7 +27,7 @@ if is_cuda():
 logger = logging.getLogger(__name__)

 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
-RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")


 class Sampler(nn.Module):
@@ -99,7 +99,7 @@ class Sampler(nn.Module):
             )

         # If requested, cache probabilities from original logits before temperature scaling.
-        if return_logprob and RETURN_ORIGINAL_LOGPROB:
+        if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
             probs_without_temp_scaling = torch.softmax(logits, dim=-1)

         # Post process logits
@@ -149,7 +149,7 @@ class Sampler(nn.Module):
         if return_logprob:
             # clamp to avoid -inf
-            if RETURN_ORIGINAL_LOGPROB:
+            if SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.log(probs_without_temp_scaling).clamp(
                     min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
...
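For intuition only (not part of the diff), a small standalone snippet contrasting log probs taken from the original logits with log probs taken after temperature scaling, which is the distinction `SGLANG_RETURN_ORIGINAL_LOGPROB` toggles:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
temperature = 0.7

# Log probs of the original (unscaled) logits -- what the flag returns.
logprobs_original = torch.log_softmax(logits, dim=-1)
# Log probs after temperature scaling -- the default behavior.
logprobs_scaled = torch.log_softmax(logits / temperature, dim=-1)

print(logprobs_original)
print(logprobs_scaled)
```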
@@ -286,8 +286,6 @@ class ModelRunner:
         self.forward_pass_id = 0

         # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))

         if server_args.show_time_cost:
             enable_show_time_cost()
@@ -577,8 +575,9 @@ class ModelRunner:
                 server_args.attention_backend = "ascend"
             else:
                 server_args.attention_backend = "triton"
-            logger.info(
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default."
+            log_info_on_rank0(
+                logger,
+                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
...
@@ -38,7 +38,7 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
-from sglang.srt.utils import find_local_repo_dir, print_warning_once
+from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
 from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)
@@ -429,7 +429,7 @@ def download_weights_from_hf(
             allow_patterns = [pattern]
             break

-    logger.info("Using model weights format %s", allow_patterns)
+    log_info_on_rank0(logger, f"Using model weights format {allow_patterns}")
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):
...
@@ -2484,13 +2484,7 @@ class ServerArgs:
             default=ServerArgs.mamba_full_memory_ratio,
             help="The ratio of mamba state memory to full kv cache memory.",
         )
-        # Args for multi-item-scoring
-        parser.add_argument(
-            "--multi-item-scoring-delimiter",
-            type=int,
-            default=ServerArgs.multi_item_scoring_delimiter,
-            help="Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query.",
-        )
         # Hierarchical cache
         parser.add_argument(
             "--enable-hierarchical-cache",
@@ -2636,6 +2630,14 @@ class ServerArgs:
             help="Mode of offloading.",
         )
+        # Args for multi-item-scoring
+        parser.add_argument(
+            "--multi-item-scoring-delimiter",
+            type=int,
+            default=ServerArgs.multi_item_scoring_delimiter,
+            help="Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query.",
+        )
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
...
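To illustrate the `--multi-item-scoring-delimiter` help text, here is a hypothetical standalone helper (not SGLang code) that combines a query and several items into one token sequence with a delimiter token ID; the exact layout, including whether a delimiter trails the last item, is an assumption:

```python
from typing import List


def build_multi_item_sequence(
    query_ids: List[int], items_ids: List[List[int]], delimiter_id: int
) -> List[int]:
    # Query<delimiter>Item1<delimiter>Item2...
    seq = list(query_ids)
    for item in items_ids:
        seq.append(delimiter_id)
        seq.extend(item)
    return seq


# Tiny made-up token IDs; 9999 stands in for the delimiter token ID.
print(build_multi_item_sequence([1, 2, 3], [[10, 11], [20, 21]], 9999))
```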
@@ -64,7 +64,7 @@ if is_cuda():
     from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)
-RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")


 @contextmanager
@@ -741,7 +741,7 @@ class EAGLEWorker(TpModelWorker):
             # acceptance indices are the indices in a "flattened" batch.
             # dividing it to num_draft_tokens will yield the actual batch index.
             temperatures = temperatures[accepted_indices // num_draft_tokens]
-            if RETURN_ORIGINAL_LOGPROB:
+            if SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.nn.functional.log_softmax(
                     logits_output.next_token_logits, dim=-1
                 )
...
 from abc import ABC, abstractmethod
 from enum import IntEnum, auto
+from functools import lru_cache
 from typing import List, Tuple

 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -27,6 +28,7 @@ class SpeculativeAlgorithm(IntEnum):
     def is_ngram(self):
         return self == SpeculativeAlgorithm.NGRAM

+    @lru_cache(maxsize=None)
     @staticmethod
     def from_string(name: str):
         name_map = {
...
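As background (not part of the diff), the `@lru_cache` added above memoizes the string-to-enum lookup so repeated calls with the same name reuse the first result. A tiny standalone sketch of the same pattern with illustrative names:

```python
from enum import IntEnum, auto
from functools import lru_cache


class Algo(IntEnum):
    NONE = auto()
    EAGLE = auto()


@lru_cache(maxsize=None)
def algo_from_string(name: str) -> Algo:
    # Computed on a cache miss; identical later calls return the cached enum.
    return {"none": Algo.NONE, "eagle": Algo.EAGLE}[name]


print(algo_from_string("eagle"))
print(algo_from_string("eagle"))  # served from the cache
```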
@@ -15,7 +15,7 @@ if is_cuda():
     from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)
-RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")


 @contextmanager
...
@@ -3,6 +3,8 @@ Run one test prompt.
 Usage:
 python3 -m sglang.test.send_one
+python3 -m sglang.test.send_one --profile --profile-steps 5
+python3 -m sglang.test.send_one --profile --profile-by-stage
 """

 import argparse
@@ -11,6 +13,8 @@ import json
 import requests

+from sglang.profiler import run_profile
+

 @dataclasses.dataclass
 class BenchArgs:
@@ -29,6 +33,9 @@ class BenchArgs:
     image: bool = False
     many_images: bool = False
     stream: bool = False
+    profile: bool = False
+    profile_steps: int = 3
+    profile_by_stage: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -51,6 +58,11 @@ class BenchArgs:
         parser.add_argument("--image", action="store_true")
         parser.add_argument("--many-images", action="store_true")
         parser.add_argument("--stream", action="store_true")
+        parser.add_argument("--profile", action="store_true")
+        parser.add_argument(
+            "--profile-steps", type=int, default=BenchArgs.profile_steps
+        )
+        parser.add_argument("--profile-by-stage", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -59,6 +71,8 @@ class BenchArgs:
 def send_one_prompt(args):
+    base_url = f"http://{args.host}:{args.port}"
+
     if args.image:
         args.prompt = (
             "Human: Describe this image in a very short sentence.\n\nAssistant:"
@@ -108,8 +122,20 @@ def send_one_prompt(args):
         "stream": args.stream,
     }

+    # Run profiler if requested
+    if args.profile:
+        print(f"Running profiler with {args.profile_steps} steps...")
+        run_profile(
+            base_url,
+            args.profile_steps,
+            ["CPU", "GPU"],
+            None,
+            None,
+            args.profile_by_stage,
+        )
+
     response = requests.post(
-        f"http://{args.host}:{args.port}/generate",
+        f"{base_url}/generate",
         json=json_data,
         stream=args.stream,
     )
...
@@ -126,7 +126,7 @@ def is_in_ci():
 def is_in_amd_ci():
     """Return whether it is in an AMD CI runner."""
-    return get_bool_env_var("SGLANG_AMD_CI")
+    return get_bool_env_var("SGLANG_IS_IN_CI_AMD")


 def _use_cached_default_models(model_repo: str):
...
@@ -15,7 +15,7 @@ fi
 WORKDIR="/sglang-checkout/test/srt"

 declare -A ENV_MAP=(
-  [SGLANG_AMD_CI]=1
+  [SGLANG_IS_IN_CI_AMD]=1
   [SGLANG_IS_IN_CI]=1
   [SGLANG_USE_AITER]=1
 )
...
@@ -13,6 +13,7 @@ echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-}"
 # Clear torch compilation cache
 python3 -c 'import os, shutil, tempfile, getpass; cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") or os.path.join(tempfile.gettempdir(), "torchinductor_" + getpass.getuser()); shutil.rmtree(cache_dir, ignore_errors=True)'
+rm -rf /root/.cache/flashinfer

 # Install apt packages
 apt install -y git libnuma-dev
...
@@ -125,8 +125,8 @@ class TestOriginalLogprob(unittest.TestCase):
         vocab_size = self.tokenizer.vocab_size

         for env_val in ["True", "False"]:
-            with self.subTest(return_original_logprob=env_val):
-                os.environ["RETURN_ORIGINAL_LOGPROB"] = env_val
+            with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val):
+                os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val

                 # ----- SGLang side -----
                 sgl_engine = sgl.Engine(
...