Return more infos for computing average acceptance length (#3152)

Unverified commit 1dda8c5e, authored by Lianmin Zheng on Jan 26, 2025; committed by GitHub on Jan 26, 2025. Parent: 7e097613.
Showing 10 changed files with 97 additions and 15 deletions.
python/sglang/srt/entrypoints/engine.py (+5, -2)
python/sglang/srt/layers/dp_attention.py (+3, -1)
python/sglang/srt/managers/detokenizer_manager.py (+1, -0)
python/sglang/srt/managers/io_struct.py (+4, -0)
python/sglang/srt/managers/schedule_batch.py (+7, -3)
python/sglang/srt/managers/scheduler.py (+11, -0)
python/sglang/srt/managers/tokenizer_manager.py (+4, -0)
python/sglang/srt/model_executor/cuda_graph_runner.py (+8, -8)
python/sglang/srt/speculative/eagle_utils.py (+1, -0)
python/sglang/srt/utils.py (+53, -1)
python/sglang/srt/entrypoints/engine.py

```diff
@@ -57,6 +57,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     kill_process_tree,
+    launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
@@ -400,14 +401,16 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return
+            return None, None
+
+        launch_dummy_health_check_server(server_args.host, server_args.port)
 
         for proc in scheduler_procs:
             proc.join()
             logger.error(
                 f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
             )
-        return
+        return None, None
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
```
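The early `return None, None` above is the path taken when sglang is driven through the `Engine` Python API, per the comment. A minimal sketch of that entry point (the `Engine` constructor and `generate` signature are assumed from sglang's public API rather than shown in this diff, and whether `Engine` sets the environment variable itself or expects the caller to is also not shown here):

```python
import os

# Assumption: opting out of blocking in non-zero-rank children, as the
# comment in _launch_subprocesses above describes; shown explicitly only
# for illustration.
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"

import sglang as sgl

# Hypothetical model id; any local path or HF model id would do.
engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")
print(engine.generate("The capital of France is", {"max_new_tokens": 8}))
engine.shutdown()
```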
python/sglang/srt/layers/dp_attention.py

```diff
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
 def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
+    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
         ],
         tp_rank,
         torch.distributed.get_backend(tp_group.device_group),
-        False,
+        SYNC_TOKEN_IDS_ACROSS_TP,
         False,
         False,
         False,
```
python/sglang/srt/managers/detokenizer_manager.py

```diff
@@ -201,6 +201,7 @@ class DetokenizerManager:
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
+                spec_verify_ct=recv_obj.spec_verify_ct,
                 input_token_logprobs_val=recv_obj.input_token_logprobs_val,
                 input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
                 output_token_logprobs_val=recv_obj.output_token_logprobs_val,
```
python/sglang/srt/managers/io_struct.py

```diff
@@ -354,10 +354,13 @@ class BatchTokenIDOut:
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     no_stop_trim: List[bool]
 
+    # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
+    # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -382,6 +385,7 @@ class BatchStrOut:
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
```
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -252,7 +252,6 @@ class Req:
         # Sampling info
         self.sampling_params = sampling_params
-        self.lora_path = lora_path
         self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
@@ -300,7 +299,7 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
 
-        # Logprobs (return value)
+        # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,10 +328,15 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
 
-        # The number of cached tokens, that were already cached in the KV cache
+        # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
         self.already_computed = 0
 
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+
+        self.lora_path = lora_path
+
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
```
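The two new comment lines above spell out the metric this commit is after: with speculative decoding enabled, each request records how many verification forward passes it took (`spec_verify_ct`), and dividing the tokens it produced by that count gives its average acceptance length. A worked sketch of the arithmetic (the counter values are made up for illustration):

```python
# Per-request counters as tracked in Req above (values are illustrative).
completion_tokens = [120, 64, 200]  # output tokens per request
spec_verify_ct = [40, 32, 50]       # verification forward passes per request

# Average acceptance length = tokens emitted per verification pass.
per_request = [c / v for c, v in zip(completion_tokens, spec_verify_ct)]
overall = sum(completion_tokens) / sum(spec_verify_ct)
print(per_request)        # [3.0, 2.0, 4.0]
print(round(overall, 2))  # 3.15
```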
python/sglang/srt/managers/scheduler.py

```diff
@@ -281,6 +281,7 @@ class Scheduler:
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"chunked_prefill_size={server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
@@ -408,6 +409,11 @@ class Scheduler:
             },
         )
 
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -1371,6 +1377,7 @@ class Scheduler:
             prompt_tokens = []
             completion_tokens = []
             cached_tokens = []
+            spec_verify_ct = []
             if return_logprob:
                 input_token_logprobs_val = []
@@ -1424,6 +1431,9 @@ class Scheduler:
                 completion_tokens.append(len(req.output_ids))
                 cached_tokens.append(req.cached_tokens)
 
+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
                 if return_logprob:
                     input_token_logprobs_val.append(req.input_token_logprobs_val)
                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1461,7 @@ class Scheduler:
                     prompt_tokens,
                     completion_tokens,
                     cached_tokens,
+                    spec_verify_ct,
                     input_token_logprobs_val,
                     input_token_logprobs_idx,
                     output_token_logprobs_val,
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -785,6 +785,9 @@ class TokenizerManager:
                 i,
             )
 
+            if self.server_args.speculative_algorithm:
+                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -809,6 +812,7 @@ class TokenizerManager:
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
                 }
 
             state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
```
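Since `spec_verify_ct` now lands in `meta_info` whenever a speculative algorithm is configured, a client can compute the acceptance length from an ordinary generate response. A hedged sketch against the HTTP `/generate` endpoint (the response shape beyond the new field follows sglang's usual `meta_info` layout, which this diff only partially shows; host and port are illustrative):

```python
import requests

# Assumes a local sglang server started with a speculative algorithm
# (e.g. EAGLE) listening on port 30000.
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 64},
    },
).json()

meta = resp["meta_info"]
# completion_tokens and the new spec_verify_ct are both reported per request.
avg_accept_len = meta["completion_tokens"] / meta["spec_verify_ct"]
print(f"average acceptance length: {avg_accept_len:.2f}")
```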
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if batch_size == 1:
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs = 1
                         sub._forward_method = fused_moe_forward_native
@@ -55,14 +55,14 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse, batch_size)
+            _to_torch(sub, reverse, num_tokens)
 
 
 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-    batch_size: int,
+    num_tokens: int,
     tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with torch.compile"""
@@ -70,7 +70,7 @@ def patch_model(
     try:
         if enable_compile:
-            _to_torch(model, reverse=False, batch_size=batch_size)
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True, batch_size=batch_size)
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
             tp_group.ca_comm = backup_ca_comm
@@ -283,8 +283,8 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
-                bs,
-                self.model_runner.tp_group,
+                num_tokens=bs * self.num_tokens_per_bs,
+                tp_group=self.model_runner.tp_group,
             ) as forward:
                 (
                     graph,
```
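The `batch_size` to `num_tokens` rename matters precisely because of speculative decoding: the caller now passes `bs * self.num_tokens_per_bs`, since each request in a captured batch can carry several draft tokens per forward pass, and both the graph capture and the FusedMoE `torch.compile` special case key off token count rather than request count. A small sketch of the relationship (the concrete numbers are illustrative):

```python
# Illustrative numbers only.
bs = 4                 # requests in the captured CUDA graph batch
num_tokens_per_bs = 4  # tokens each request contributes per step,
                       # e.g. draft tokens verified speculatively

num_tokens = bs * num_tokens_per_bs  # what patch_model now receives
print(num_tokens)  # 16; without speculation num_tokens_per_bs == 1
```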
python/sglang/srt/speculative/eagle_utils.py

```diff
@@ -603,6 +603,7 @@ class EagleVerifyInput(SpecInfo):
                 if not req.finished():
                     new_accept_index.extend(new_accept_index_)
                     unfinished_index.append(i)
+                req.spec_verify_ct += 1
 
         accept_length = (accept_index != -1).sum(dim=1) - 1
         accept_index = accept_index[accept_index != -1]
```
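The `accept_length` expression above is easy to check on a toy tensor: each row of `accept_index` holds the kept token slots for one request with `-1` padding, so counting the non-`-1` entries and subtracting the one token every verify pass emits anyway yields the accepted draft tokens. A small sketch:

```python
import torch

# Toy verify output: one row per request; -1 marks rejected/padded slots.
accept_index = torch.tensor(
    [[10, 11, 12, -1],   # 3 tokens kept -> 2 accepted draft tokens
     [20, -1, -1, -1]]   # only the guaranteed token kept -> 0 accepted
)

# Same expression as in the diff above.
accept_length = (accept_index != -1).sum(dim=1) - 1
print(accept_length)  # tensor([2, 0])
```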
python/sglang/srt/utils.py

```diff
@@ -14,6 +14,7 @@
 """Common utilities."""
 
 import base64
+import ctypes
 import dataclasses
 import io
 import ipaddress
@@ -29,6 +30,7 @@ import shutil
 import signal
 import socket
+import subprocess
 import sys
 import tempfile
 import time
 import warnings
@@ -59,7 +61,6 @@ from triton.runtime.cache import (
     default_dump_dir,
     default_override_dir,
 )
-from uvicorn.config import LOGGING_CONFIG
 
 
 logger = logging.getLogger(__name__)
@@ -1366,7 +1367,33 @@ def nullable_str(val: str):
         return val
 
 
+def pyspy_dump_schedulers():
+    """py-spy dump on all scheduler in a local node."""
+    try:
+        pid = psutil.Process().pid
+        # Command to run py-spy with the PID
+        cmd = f"py-spy dump --pid {pid}"
+        result = subprocess.run(
+            cmd, shell=True, capture_output=True, text=True, check=True
+        )
+        logger.info(f"Profile for PID {pid}:\n{result.stdout}")
+    except subprocess.CalledProcessError as e:
+        logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
+
+
+def kill_itself_when_parent_died():
+    if sys.platform == "linux":
+        # sigkill this process when parent worker manager dies
+        PR_SET_PDEATHSIG = 1
+        libc = ctypes.CDLL("libc.so.6")
+        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+    else:
+        logger.warning("kill_itself_when_parent_died is only supported in linux.")
+
+
 def set_uvicorn_logging_configs():
     from uvicorn.config import LOGGING_CONFIG
 
     LOGGING_CONFIG["formatters"]["default"]["fmt"] = "[%(asctime)s] %(levelprefix)s %(message)s"
```
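Both new helpers are meant to run inside worker processes: `pyspy_dump_schedulers` shells out to `py-spy dump` against the current PID, and `kill_itself_when_parent_died` uses the Linux `PR_SET_PDEATHSIG` prctl so a child is SIGKILLed when its parent exits. A minimal usage sketch for the latter (the worker body is a stand-in, not sglang's real scheduler loop):

```python
import multiprocessing as mp
import time

from sglang.srt.utils import kill_itself_when_parent_died

def worker():
    # Linux only: ask the kernel to SIGKILL this child if the parent dies,
    # so a scheduler process cannot outlive its launcher.
    kill_itself_when_parent_died()
    while True:
        time.sleep(1)  # stand-in for the real work loop

if __name__ == "__main__":
    mp.Process(target=worker).start()
```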
```diff
@@ -1449,3 +1476,28 @@ def rank0_print(msg: str):
     if get_tensor_model_parallel_rank() == 0:
         print(msg, flush=True)
 
 
+def launch_dummy_health_check_server(host, port):
+    import uvicorn
+    from fastapi import FastAPI, Response
+
+    app = FastAPI()
+
+    @app.get("/health")
+    async def health():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    @app.get("/health_generate")
+    async def health_generate():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        timeout_keep_alive=5,
+        loop="uvloop",
+    )
```
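Note that `launch_dummy_health_check_server` blocks in `uvicorn.run`; in engine.py above it keeps a node alive serving `/health` and `/health_generate` while the schedulers run in other processes. A hedged sketch of probing it (host and port are illustrative, and `loop="uvloop"` means uvloop must be installed):

```python
import multiprocessing as mp
import time

import requests

from sglang.srt.utils import launch_dummy_health_check_server

if __name__ == "__main__":
    # uvicorn.run blocks, so host the server in its own process.
    p = mp.Process(
        target=launch_dummy_health_check_server, args=("127.0.0.1", 30000)
    )
    p.start()
    time.sleep(2)  # give uvicorn a moment to bind
    print(requests.get("http://127.0.0.1:30000/health").status_code)  # 200
    p.terminate()
```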