Unverified Commit 1dda8c5e authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Return more infos for computing average acceptance length (#3152)

parent 7e097613
......@@ -57,6 +57,7 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
kill_process_tree,
launch_dummy_health_check_server,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
set_prometheus_multiproc_dir,
......@@ -400,14 +401,16 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
# When using `Engine` as a Python API, we don't want to block here.
return
return None, None
launch_dummy_health_check_server(server_args.host, server_args.port)
for proc in scheduler_procs:
proc.join()
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
return
return None, None
# Launch detokenizer process
detoken_proc = mp.Process(
......
......@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
......@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
],
tp_rank,
torch.distributed.get_backend(tp_group.device_group),
False,
SYNC_TOKEN_IDS_ACROSS_TP,
False,
False,
False,
......
......@@ -201,6 +201,7 @@ class DetokenizerManager:
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
spec_verify_ct=recv_obj.spec_verify_ct,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
......
......@@ -354,10 +354,13 @@ class BatchTokenIDOut:
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
no_stop_trim: List[bool]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
# Logprobs
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
......@@ -382,6 +385,7 @@ class BatchStrOut:
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
# Logprobs
input_token_logprobs_val: List[float]
......
......@@ -252,7 +252,6 @@ class Req:
# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path
self.custom_logit_processor = custom_logit_processor
# Memory pool info
......@@ -300,7 +299,7 @@ class Req:
self.logprob_start_len = 0
self.top_logprobs_num = top_logprobs_num
# Logprobs (return value)
# Logprobs (return values)
self.input_token_logprobs_val: Optional[List[float]] = None
self.input_token_logprobs_idx: Optional[List[int]] = None
self.input_top_logprobs_val: Optional[List[float]] = None
......@@ -329,10 +328,15 @@ class Req:
# Constrained decoding
self.grammar: Optional[BaseGrammarObject] = None
# The number of cached tokens, that were already cached in the KV cache
# The number of cached tokens that were already cached in the KV cache
self.cached_tokens = 0
self.already_computed = 0
# The number of verification forward passes in the speculative decoding.
# This is used to compute the average acceptance length per request.
self.spec_verify_ct = 0
self.lora_path = lora_path
def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
self.image_inputs = image_inputs
......
......@@ -281,6 +281,7 @@ class Scheduler:
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}"
......@@ -408,6 +409,11 @@ class Scheduler:
},
)
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
......@@ -1371,6 +1377,7 @@ class Scheduler:
prompt_tokens = []
completion_tokens = []
cached_tokens = []
spec_verify_ct = []
if return_logprob:
input_token_logprobs_val = []
......@@ -1424,6 +1431,9 @@ class Scheduler:
completion_tokens.append(len(req.output_ids))
cached_tokens.append(req.cached_tokens)
if not self.spec_algorithm.is_none():
spec_verify_ct.append(req.spec_verify_ct)
if return_logprob:
input_token_logprobs_val.append(req.input_token_logprobs_val)
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
......@@ -1451,6 +1461,7 @@ class Scheduler:
prompt_tokens,
completion_tokens,
cached_tokens,
spec_verify_ct,
input_token_logprobs_val,
input_token_logprobs_idx,
output_token_logprobs_val,
......
......@@ -785,6 +785,9 @@ class TokenizerManager:
i,
)
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
if not isinstance(recv_obj, BatchEmbeddingOut):
meta_info.update(
{
......@@ -809,6 +812,7 @@ class TokenizerManager:
"embedding": recv_obj.embeddings[i],
"meta_info": meta_info,
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reasons[i] is not None
state.event.set()
......
......@@ -38,7 +38,7 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
if reverse:
......@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
else:
# NOTE: Temporarily workaround MoE
if "FusedMoE" in sub.__class__.__name__:
if batch_size == 1:
if num_tokens == 1:
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs =1
sub._forward_method = fused_moe_forward_native
......@@ -55,14 +55,14 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
sub._forward_method = sub.forward_native
setattr(sub, "is_torch_compile", True)
if isinstance(sub, torch.nn.Module):
_to_torch(sub, reverse, batch_size)
_to_torch(sub, reverse, num_tokens)
@contextmanager
def patch_model(
model: torch.nn.Module,
enable_compile: bool,
batch_size: int,
num_tokens: int,
tp_group: GroupCoordinator,
):
"""Patch the model to make it compatible with with torch.compile"""
......@@ -70,7 +70,7 @@ def patch_model(
try:
if enable_compile:
_to_torch(model, reverse=False, batch_size=batch_size)
_to_torch(model, reverse=False, num_tokens=num_tokens)
backup_ca_comm = tp_group.ca_comm
# Use custom-allreduce here.
# We found the custom allreduce is much faster than the built-in allreduce in torch,
......@@ -85,7 +85,7 @@ def patch_model(
yield model.forward
finally:
if enable_compile:
_to_torch(model, reverse=True, batch_size=batch_size)
_to_torch(model, reverse=True, num_tokens=num_tokens)
tp_group.ca_comm = backup_ca_comm
......@@ -283,8 +283,8 @@ class CudaGraphRunner:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
bs,
self.model_runner.tp_group,
num_tokens=bs * self.num_tokens_per_bs,
tp_group=self.model_runner.tp_group,
) as forward:
(
graph,
......
......@@ -603,6 +603,7 @@ class EagleVerifyInput(SpecInfo):
if not req.finished():
new_accept_index.extend(new_accept_index_)
unfinished_index.append(i)
req.spec_verify_ct += 1
accept_length = (accept_index != -1).sum(dim=1) - 1
accept_index = accept_index[accept_index != -1]
......
......@@ -14,6 +14,7 @@
"""Common utilities."""
import base64
import ctypes
import dataclasses
import io
import ipaddress
......@@ -29,6 +30,7 @@ import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import time
import warnings
......@@ -59,7 +61,6 @@ from triton.runtime.cache import (
default_dump_dir,
default_override_dir,
)
from uvicorn.config import LOGGING_CONFIG
logger = logging.getLogger(__name__)
......@@ -1366,7 +1367,33 @@ def nullable_str(val: str):
return val
def pyspy_dump_schedulers():
"""py-spy dump on all scheduler in a local node."""
try:
pid = psutil.Process().pid
# Command to run py-spy with the PID
cmd = f"py-spy dump --pid {pid}"
result = subprocess.run(
cmd, shell=True, capture_output=True, text=True, check=True
)
logger.info(f"Profile for PID {pid}:\n{result.stdout}")
except subprocess.CalledProcessError as e:
logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
def kill_itself_when_parent_died():
if sys.platform == "linux":
# sigkill this process when parent worker manager dies
PR_SET_PDEATHSIG = 1
libc = ctypes.CDLL("libc.so.6")
libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
else:
logger.warninig("kill_itself_when_parent_died is only supported in linux.")
def set_uvicorn_logging_configs():
from uvicorn.config import LOGGING_CONFIG
LOGGING_CONFIG["formatters"]["default"][
"fmt"
] = "[%(asctime)s] %(levelprefix)s %(message)s"
......@@ -1449,3 +1476,28 @@ def rank0_print(msg: str):
if get_tensor_model_parallel_rank() == 0:
print(msg, flush=True)
def launch_dummy_health_check_server(host, port):
import uvicorn
from fastapi import FastAPI, Response
app = FastAPI()
@app.get("/health")
async def health():
"""Check the health of the http server."""
return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
"""Check the health of the http server."""
return Response(status_code=200)
uvicorn.run(
app,
host=host,
port=port,
timeout_keep_alive=5,
loop="uvloop",
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment