change / sglang · Commits

Commit bc6915e3 (unverified)
Improve type annotation and styles (#2926)

Authored Jan 16, 2025 by Lianmin Zheng; committed via GitHub on Jan 16, 2025.
Parent: a883f079

Showing 7 changed files with 78 additions and 26 deletions (+78 −26).
benchmark/tree_of_thought_deep/bench_sglang.py      +1  −0
python/sglang/srt/managers/schedule_batch.py        +8  −6
python/sglang/srt/managers/scheduler.py             +62 −15
python/sglang/srt/model_executor/model_runner.py    +4  −3
python/sglang/srt/openai_api/protocol.py            +2  −0
python/sglang/srt/server.py                         +0  −1
python/sglang/srt/server_args.py                    +1  −1
benchmark/tree_of_thought_deep/bench_sglang.py (view file @ bc6915e3)

@@ -103,6 +103,7 @@ def tree_search(s, question, num_branches):
 def main(args):
     lines = read_jsonl(args.data_path)
+    lines = list(lines)

     # Construct prompts
     num_branches = 2
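read_jsonl typically returns a lazy iterator, so the added list() call materializes the records up front, which allows len() and repeated iteration over the data. A minimal sketch of the pattern (read_jsonl and the file name here are illustrative stand-ins, not code from this commit):

import json

def read_jsonl(path):
    # Lazy generator: yields one parsed record per line.
    with open(path) as f:
        for line in f:
            yield json.loads(line)

lines = read_jsonl("questions.jsonl")  # hypothetical data file
lines = list(lines)  # materialize so len() and re-iteration work
print(len(lines))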
python/sglang/srt/managers/schedule_batch.py (view file @ bc6915e3)

@@ -226,8 +226,9 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-        self.output_ids = []  # Each decode stage's output ids
-        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds

@@ -265,6 +266,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens. Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None

@@ -280,10 +282,10 @@ class Req:
         self.top_logprobs_num = top_logprobs_num

         # Logprobs (return value)
-        self.input_token_logprobs_val = None
-        self.input_token_logprobs_idx = None
-        self.input_top_logprobs_val = None
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
             self.output_token_logprobs_val = []
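The last hunk replaces bare None initializations with explicit Optional annotations on attributes that are filled in later. A minimal sketch of why this helps (class and field names are illustrative, not from the commit): a type checker such as mypy can now require a None guard before the value is used.

from typing import List, Optional

class Request:
    def __init__(self) -> None:
        # Populated later, once logprobs are computed.
        self.logprobs_val: Optional[List[float]] = None

    def top_logprob(self) -> float:
        # The checker forces this guard before touching the list.
        if self.logprobs_val is None:
            raise ValueError("logprobs not computed yet")
        return max(self.logprobs_val)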
python/sglang/srt/managers/scheduler.py (view file @ bc6915e3)

@@ -22,8 +22,9 @@ import time
 import warnings
 from collections import deque
 from concurrent import futures
+from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


+@dataclass
+class GenerationBatchResult:
+    logits_output: LogitsProcessorOutput
+    next_token_ids: List[int]
+    bid: int
+
+
+@dataclass
+class EmbeddingBatchResult:
+    embeddings: torch.Tensor
+    bid: int
+
+
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
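These two dataclasses give run_batch's return values named, typed fields in place of the anonymous tuples used before. A minimal sketch of the refactor pattern (simplified field types; not the scheduler's actual code):

from dataclasses import dataclass
from typing import Any, List, Tuple

logits = object()  # stand-in for a LogitsProcessorOutput

# Before: a positional tuple; callers index it (e.g. result[-1] for bid)
# and must remember the field order.
def run_batch_old() -> Tuple[Any, List[int], int]:
    return logits, [1, 2, 3], 42

# After: named fields; result.bid is self-documenting and type-checked.
@dataclass
class BatchResult:
    logits_output: Any
    next_token_ids: List[int]
    bid: int

def run_batch_new() -> BatchResult:
    return BatchResult(logits_output=logits, next_token_ids=[1, 2, 3], bid=42)

print(run_batch_new().bid)  # 42, no positional indexing required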
@@ -411,16 +425,16 @@ class Scheduler:
         self.watchdog_last_time = time.time()

         while True:
+            current = time.time()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
-                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
                         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = time.time()
-            time.sleep(self.watchdog_timeout / 2)
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)

         # Wait sometimes so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
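This watchdog loop declares the scheduler hung when forward_ct stops advancing for longer than watchdog_timeout; the patch also caches a single time.time() reading per iteration in current. A self-contained sketch of the same pattern, where a bare counter and a print stand in for the scheduler state and the SIGQUIT signal:

import threading
import time

forward_ct = 0          # incremented by the worker on every forward pass
WATCHDOG_TIMEOUT = 2.0  # seconds without progress before giving up

def watchdog() -> None:
    last_ct = forward_ct
    last_time = time.time()
    while True:
        # Read the clock once per iteration, as the patch above does.
        current = time.time()
        if forward_ct == last_ct:
            if current > last_time + WATCHDOG_TIMEOUT:
                print("watchdog timeout: no forward progress")
                break
        else:
            last_ct = forward_ct
            last_time = current
        time.sleep(WATCHDOG_TIMEOUT / 2)

t = threading.Thread(target=watchdog)
t.start()
t.join()  # with forward_ct never advancing, this fires after ~2 s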
@@ -1018,7 +1032,9 @@ class Scheduler:
         batch.prepare_for_decode()
         return batch

-    def run_batch(self, batch: ScheduleBatch):
+    def run_batch(
+        self, batch: ScheduleBatch
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
@@ -1040,15 +1056,26 @@ class Scheduler:
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"
             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids, model_worker_batch.bid
+            ret = GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                bid=model_worker_batch.bid,
+            )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = embeddings, model_worker_batch.bid
+            ret = EmbeddingBatchResult(
+                embeddings=embeddings, bid=model_worker_batch.bid
+            )
         return ret

-    def process_batch_result(self, batch: ScheduleBatch, result):
+    def process_batch_result(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
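With the Union return annotation, downstream code can dispatch on the concrete result class rather than guessing at tuple shapes. A hedged usage sketch (simplified field types; not code from this commit):

from dataclasses import dataclass
from typing import Any, List, Union

@dataclass
class GenerationBatchResult:
    logits_output: Any
    next_token_ids: List[int]
    bid: int

@dataclass
class EmbeddingBatchResult:
    embeddings: Any
    bid: int

def handle(result: Union[GenerationBatchResult, EmbeddingBatchResult]) -> None:
    # isinstance narrows the Union, so attribute access is type-safe.
    if isinstance(result, GenerationBatchResult):
        print("generated:", result.next_token_ids, "bid:", result.bid)
    else:
        print("embeddings:", result.embeddings, "bid:", result.bid)

handle(GenerationBatchResult(logits_output=None, next_token_ids=[1, 2], bid=7))
handle(EmbeddingBatchResult(embeddings=[0.1, 0.2], bid=8))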
@@ -1057,17 +1084,29 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result[-1])
+                self.tp_worker.resolve_batch_result(result.bid)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

-    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         skip_stream_req = None

         if self.is_generation:
-            logits_output, next_token_ids, bid = result
+            (
+                logits_output,
+                next_token_ids,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.bid,
+            )

             if self.enable_overlap:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1125,7 +1164,7 @@ class Scheduler:
                     batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
-            embeddings, bid = result
+            embeddings, bid = result.embeddings, result.bid
             embeddings = embeddings.tolist()

         # Check finish conditions
@@ -1149,8 +1188,16 @@ class Scheduler:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

-    def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids, bid = result
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
python/sglang/srt/model_executor/model_runner.py (view file @ bc6915e3)

@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
+    get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput

@@ -532,7 +533,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
-                self.model_config.get_num_kv_heads(self.tp_size)
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
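cell_size here is the per-token KV cache footprint: KV heads on this attention tensor-parallel rank × head dimension × layer count × 2 (one slot each for K and V), times the dtype width. A back-of-the-envelope sketch with assumed Llama-like settings (the numbers are illustrative, not from this commit):

# Hypothetical config values, not taken from the diff.
num_kv_heads_per_rank = 8    # model_config.get_num_kv_heads(get_attention_tp_size())
head_dim = 128
num_hidden_layers = 32
dtype_bytes = 2              # fp16/bf16

cell_size = num_kv_heads_per_rank * head_dim * num_hidden_layers * 2 * dtype_bytes
print(f"{cell_size} bytes per cached token")  # 131072 bytes = 128 KiB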
@@ -626,7 +627,7 @@ class ModelRunner:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
@@ -637,7 +638,7 @@ class ModelRunner:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
python/sglang/srt/openai_api/protocol.py (view file @ bc6915e3)

@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


 class CompletionResponseChoice(BaseModel):

@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


 class FunctionResponse(BaseModel):
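session_params is added as a free-form Optional[Dict] on both request models, so clients of the OpenAI-compatible API can attach session state to a request. A hedged sketch of a request carrying it (the endpoint, port, and the keys inside session_params are assumptions, not documented by this commit):

import requests  # assumes a locally running sglang server

payload = {
    "model": "default",
    "prompt": "Hello",
    "max_tokens": 16,
    # Opaque dict: the server validates it only as Optional[Dict].
    "session_params": {"id": "example-session"},
}
resp = requests.post("http://localhost:30000/v1/completions", json=payload)
print(resp.json())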
python/sglang/srt/server.py (view file @ bc6915e3)

@@ -842,7 +842,6 @@ class Engine:
        generator = ret.body_iterator

        async def generator_wrapper():
-
            offset = 0

            while True:
python/sglang/srt/server_args.py (view file @ bc6915e3)

@@ -239,8 +239,8 @@ class ServerArgs:
         # Others
         if self.enable_dp_attention:
-            assert self.tp_size % self.dp_size == 0
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             logger.warning(
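One behavioral consequence worth noting: the assertion now runs after dp_size has been overwritten with tp_size, so it appears to check tp_size % tp_size == 0, which always holds. A minimal standalone sketch of the two orderings (hypothetical values, not the ServerArgs code):

tp_size, dp_size = 8, 3  # hypothetical user-provided values

# Old order: validates the user-provided dp_size first.
# assert tp_size % dp_size == 0   # 8 % 3 != 0 -> AssertionError
# dp_size = tp_size

# New order: dp_size is forced to tp_size before the check,
# so the assertion is trivially satisfied.
dp_size = tp_size
assert tp_size % dp_size == 0  # 8 % 8 == 0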