Unverified Commit 1dda8c5e authored by Lianmin Zheng, committed by GitHub

Return more info for computing average acceptance length (#3152)

parent 7e097613
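The new counters exist so that the average acceptance length of speculative decoding (output tokens produced per verification forward pass) can be computed from the per-request statistics this commit returns. A minimal sketch of that computation follows; the list names mirror the fields added below, but the helper itself is illustrative and not part of the commit.

# Minimal sketch: average acceptance length from the per-request counters
# added in this commit. The helper itself is illustrative.
def average_acceptance_length(completion_tokens, spec_verify_ct):
    total_tokens = sum(completion_tokens)      # output tokens across requests
    total_verify_passes = sum(spec_verify_ct)  # verification forward passes
    if total_verify_passes == 0:
        return 0.0  # speculative decoding disabled or nothing recorded
    # Each verify pass emits at least one token, so the ratio is >= 1
    # whenever speculative decoding is active.
    return total_tokens / total_verify_passes

# Example: 120 output tokens over 40 verify passes -> 3.0 tokens per pass
print(average_acceptance_length([50, 40, 30], [16, 14, 10]))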
@@ -57,6 +57,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     kill_process_tree,
+    launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
@@ -400,14 +401,16 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return
+            return None, None
+
+        launch_dummy_health_check_server(server_args.host, server_args.port)
 
         for proc in scheduler_procs:
             proc.join()
             logger.error(
                 f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
             )
 
-        return
+        return None, None
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
...
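For context on the `return None, None` changes above: `_launch_subprocesses` is annotated to return a tuple, so a bare `return` would break callers that unpack the result, and non-zero-rank nodes that block here now also answer health probes. A hedged sketch of the caller-side contract (names are illustrative, not the exact entrypoint code):

# Hedged sketch of the caller contract after this change; not the exact
# entrypoint code. `_launch_subprocesses` now returns a 2-tuple on every path,
# so unpacking never fails.
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
if tokenizer_manager is None:
    # Non-zero-rank node (or unblocked Engine child): only the dummy
    # health-check server launched above answers on server_args.host / port.
    pass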
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
 def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
+    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
         ],
         tp_rank,
         torch.distributed.get_backend(tp_group.device_group),
-        False,
+        SYNC_TOKEN_IDS_ACROSS_TP,
         False,
         False,
         False,
...
@@ -201,6 +201,7 @@ class DetokenizerManager:
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
+                spec_verify_ct=recv_obj.spec_verify_ct,
                 input_token_logprobs_val=recv_obj.input_token_logprobs_val,
                 input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
                 output_token_logprobs_val=recv_obj.output_token_logprobs_val,
...
@@ -354,10 +354,13 @@ class BatchTokenIDOut:
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     no_stop_trim: List[bool]
+
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
+
     # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -382,6 +385,7 @@ class BatchStrOut:
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
     # Logprobs
     input_token_logprobs_val: List[float]
...
@@ -252,7 +252,6 @@ class Req:
         # Sampling info
         self.sampling_params = sampling_params
-        self.lora_path = lora_path
         self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
@@ -300,7 +299,7 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
 
-        # Logprobs (return value)
+        # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,10 +328,15 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
 
-        # The number of cached tokens, that were already cached in the KV cache
+        # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
         self.already_computed = 0
 
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+
+        self.lora_path = lora_path
 
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
...
@@ -281,6 +281,7 @@ class Scheduler:
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
@@ -408,6 +409,11 @@ class Scheduler:
             },
         )
 
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -1371,6 +1377,7 @@ class Scheduler:
             prompt_tokens = []
             completion_tokens = []
             cached_tokens = []
+            spec_verify_ct = []
 
             if return_logprob:
                 input_token_logprobs_val = []
@@ -1424,6 +1431,9 @@ class Scheduler:
                 completion_tokens.append(len(req.output_ids))
                 cached_tokens.append(req.cached_tokens)
 
+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
                 if return_logprob:
                     input_token_logprobs_val.append(req.input_token_logprobs_val)
                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1461,7 @@ class Scheduler:
                     prompt_tokens,
                     completion_tokens,
                     cached_tokens,
+                    spec_verify_ct,
                     input_token_logprobs_val,
                     input_token_logprobs_idx,
                     output_token_logprobs_val,
...
@@ -785,6 +785,9 @@ class TokenizerManager:
                 i,
             )
 
+            if self.server_args.speculative_algorithm:
+                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -809,6 +812,7 @@ class TokenizerManager:
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
                 }
+
             state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
...
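With `spec_verify_ct` surfaced in `meta_info`, a client can derive the per-request acceptance length straight from a generate response. A usage sketch, assuming a server launched with a speculative algorithm and an illustrative host/port:

# Usage sketch (illustrative host/port): read the new meta_info field from a
# /generate response and compute the per-request acceptance length.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 64},
    },
)
meta = resp.json()["meta_info"]
if meta.get("spec_verify_ct", 0) > 0:
    print("acceptance length:", meta["completion_tokens"] / meta["spec_verify_ct"])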
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if batch_size == 1:
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
@@ -55,14 +55,14 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                     sub._forward_method = sub.forward_native
             setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse, batch_size)
+            _to_torch(sub, reverse, num_tokens)
 
 
 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-    batch_size: int,
+    num_tokens: int,
     tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
@@ -70,7 +70,7 @@ def patch_model(
     try:
         if enable_compile:
-            _to_torch(model, reverse=False, batch_size=batch_size)
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
         yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True, batch_size=batch_size)
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
             tp_group.ca_comm = backup_ca_comm
@@ -283,8 +283,8 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
-                bs,
-                self.model_runner.tp_group,
+                num_tokens=bs * self.num_tokens_per_bs,
+                tp_group=self.model_runner.tp_group,
             ) as forward:
                 (
                     graph,
...
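The rename from `batch_size` to `num_tokens` matters because, with speculative decoding, each captured graph runs on several tokens per request, so the torch.compile decision is keyed on the token count `bs * self.num_tokens_per_bs` rather than the batch size. A small illustration; the value of `num_tokens_per_bs` below is made up:

# Illustrative only: a CUDA-graph batch of `bs` requests processes
# bs * num_tokens_per_bs tokens per forward pass under speculative decoding.
num_tokens_per_bs = 4  # example value: tokens verified per request per pass
for bs in (1, 2, 8):
    print(f"bs={bs} -> num_tokens={bs * num_tokens_per_bs}")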
@@ -603,6 +603,7 @@ class EagleVerifyInput(SpecInfo):
                 if not req.finished():
                     new_accept_index.extend(new_accept_index_)
                     unfinished_index.append(i)
+                req.spec_verify_ct += 1
         accept_length = (accept_index != -1).sum(dim=1) - 1
         accept_index = accept_index[accept_index != -1]
...
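To relate `req.spec_verify_ct` to the acceptance computation shown above: the counter increments once per verification pass, while `accept_length` counts the draft tokens accepted in that pass. A toy evaluation of the `accept_length` line with made-up indices:

# Toy example of the accept_length computation above; index values are made up.
# -1 marks draft positions rejected during verification.
import torch

accept_index = torch.tensor(
    [
        [5, 6, 7, -1],    # request 0: 2 draft tokens accepted this pass
        [9, -1, -1, -1],  # request 1: no draft tokens accepted this pass
    ]
)
accept_length = (accept_index != -1).sum(dim=1) - 1
print(accept_length)  # tensor([2, 0])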
@@ -14,6 +14,7 @@
 """Common utilities."""
 
 import base64
+import ctypes
 import dataclasses
 import io
 import ipaddress
@@ -29,6 +30,7 @@ import shutil
 import signal
 import socket
 import subprocess
+import sys
 import tempfile
 import time
 import warnings
@@ -59,7 +61,6 @@ from triton.runtime.cache import (
     default_dump_dir,
     default_override_dir,
 )
-from uvicorn.config import LOGGING_CONFIG
 
 logger = logging.getLogger(__name__)
@@ -1366,7 +1367,33 @@ def nullable_str(val: str):
     return val
 
 
+def pyspy_dump_schedulers():
+    """Run a py-spy dump on all schedulers in the local node."""
+    try:
+        pid = psutil.Process().pid
+        # Command to run py-spy with the PID
+        cmd = f"py-spy dump --pid {pid}"
+        result = subprocess.run(
+            cmd, shell=True, capture_output=True, text=True, check=True
+        )
+        logger.info(f"Profile for PID {pid}:\n{result.stdout}")
+    except subprocess.CalledProcessError as e:
+        logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
+
+
+def kill_itself_when_parent_died():
+    if sys.platform == "linux":
+        # SIGKILL this process when the parent worker manager dies
+        PR_SET_PDEATHSIG = 1
+        libc = ctypes.CDLL("libc.so.6")
+        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+    else:
+        logger.warning("kill_itself_when_parent_died is only supported on Linux.")
+
+
 def set_uvicorn_logging_configs():
+    from uvicorn.config import LOGGING_CONFIG
+
     LOGGING_CONFIG["formatters"]["default"][
         "fmt"
     ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -1449,3 +1476,28 @@ def rank0_print(msg: str):
     if get_tensor_model_parallel_rank() == 0:
         print(msg, flush=True)
+
+
+def launch_dummy_health_check_server(host, port):
+    import uvicorn
+    from fastapi import FastAPI, Response
+
+    app = FastAPI()
+
+    @app.get("/health")
+    async def health():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    @app.get("/health_generate")
+    async def health_generate():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        timeout_keep_alive=5,
+        loop="uvloop",
+    )
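Both endpoints of `launch_dummy_health_check_server` return an empty 200 response, so an external monitor can probe a non-zero-rank node the same way it probes the main server. A quick check with an illustrative host and port:

# Quick probe of the dummy health-check server (illustrative host/port).
# Both routes simply return HTTP 200 with no body.
import requests

for path in ("/health", "/health_generate"):
    r = requests.get(f"http://127.0.0.1:30000{path}", timeout=5)
    print(path, r.status_code)  # expected: 200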