sglang · Commit bc6915e3

Improve type annotation and styles (#2926)

Authored Jan 16, 2025 by Lianmin Zheng; committed via GitHub on Jan 16, 2025 (unverified signature). Parent: a883f079.
Showing 7 changed files with 78 additions and 26 deletions (+78, -26):
benchmark/tree_of_thought_deep/bench_sglang.py    +1  -0
python/sglang/srt/managers/schedule_batch.py      +8  -6
python/sglang/srt/managers/scheduler.py           +62 -15
python/sglang/srt/model_executor/model_runner.py  +4  -3
python/sglang/srt/openai_api/protocol.py          +2  -0
python/sglang/srt/server.py                       +0  -1
python/sglang/srt/server_args.py                  +1  -1
benchmark/tree_of_thought_deep/bench_sglang.py

```diff
@@ -103,6 +103,7 @@ def tree_search(s, question, num_branches):
 def main(args):
     lines = read_jsonl(args.data_path)
+    lines = list(lines)
 
     # Construct prompts
     num_branches = 2
```
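The single addition here materializes the iterator returned by `read_jsonl`. A plausible motivation, sketched below with a stand-in `read_jsonl` (not the benchmark's actual helper): a generator can be consumed only once and supports neither `len()` nor slicing, both of which prompt-construction code typically needs.

```python
import json

def read_jsonl(path):
    # Stand-in generator: yields one parsed record per line.
    with open(path) as f:
        for line in f:
            yield json.loads(line)

lines = read_jsonl("questions.jsonl")  # hypothetical file name
lines = list(lines)                    # materialize once up front
print(len(lines))                      # len() now works ...
subset = lines[:10]                    # ... and so does slicing
```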
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -226,8 +226,9 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-        self.output_ids = []  # Each decode stage's output ids
-        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        # Each decode stage's output ids
+        self.output_ids = []
+        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids. Updated if chunked.
         self.session_id = session_id
         self.input_embeds = input_embeds
@@ -265,6 +266,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None
@@ -280,10 +282,10 @@ class Req:
         self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
-        self.input_token_logprobs_val = None
-        self.input_token_logprobs_idx = None
-        self.input_top_logprobs_val = None
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
             self.output_token_logprobs_val = []
```
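The third hunk only adds annotations, but they matter for tooling: an attribute initialized to `None` is otherwise inferred as having type `None`, so every later assignment of a real list looks like a type error. A minimal sketch of the same pattern (illustrative class, not `Req` itself):

```python
from typing import List, Optional

class LogprobStore:
    def __init__(self) -> None:
        # Annotated so type checkers know the eventual element type,
        # not just the None initializer.
        self.token_logprobs_val: Optional[List[float]] = None
        self.token_logprobs_idx: Optional[List[int]] = None

store = LogprobStore()
store.token_logprobs_val = [-0.25, -1.7]  # accepted by a type checker
```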
python/sglang/srt/managers/scheduler.py

```diff
@@ -22,8 +22,9 @@ import time
 import warnings
 from collections import deque
 from concurrent import futures
+from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import psutil
 import setproctitle
```
```diff
@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
+@dataclass
+class GenerationBatchResult:
+    logits_output: LogitsProcessorOutput
+    next_token_ids: List[int]
+    bid: int
+
+
+@dataclass
+class EmbeddingBatchResult:
+    embeddings: torch.Tensor
+    bid: int
+
+
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
```
```diff
@@ -411,16 +425,16 @@
         self.watchdog_last_time = time.time()
 
         while True:
-
+            current = time.time()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
-                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
                         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = time.time()
-            time.sleep(self.watchdog_timeout / 2)
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
 
         # Wait sometimes so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
```
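Two small behavior changes ride along with the typing work in this hunk: the loop takes one `time.time()` snapshot per iteration and reuses it for both the timeout comparison and the bookkeeping, and the sleep interval switches to floor division, so `7 // 2` sleeps 3 whole seconds rather than 3.5. A reduced sketch of the snapshot idiom (simplified, not the scheduler's full watchdog):

```python
import time

class Watchdog:
    def __init__(self, timeout: float) -> None:
        self.timeout = timeout
        self.last_time = time.time()

    def check(self, made_progress: bool) -> bool:
        """Return True if no progress happened within the timeout."""
        current = time.time()         # single snapshot per iteration
        if made_progress:
            self.last_time = current  # same instant, no second clock read
            return False
        return current > self.last_time + self.timeout
```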
```diff
@@ -1018,7 +1032,9 @@
         batch.prepare_for_decode()
         return batch
 
-    def run_batch(self, batch: ScheduleBatch):
+    def run_batch(
+        self, batch: ScheduleBatch
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
```
```diff
@@ -1040,15 +1056,26 @@
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"
             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids, model_worker_batch.bid
+            ret = GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                bid=model_worker_batch.bid,
+            )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = embeddings, model_worker_batch.bid
+            ret = EmbeddingBatchResult(
+                embeddings=embeddings,
+                bid=model_worker_batch.bid,
+            )
         return ret
 
-    def process_batch_result(self, batch: ScheduleBatch, result):
+    def process_batch_result(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
```
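With the `Union[...]` annotations in place, a caller that receives either variant can narrow it explicitly. The scheduler itself branches on `batch.forward_mode`, but `isinstance` narrowing is the generic pattern; a hedged sketch using stubs that mirror the new dataclasses:

```python
from dataclasses import dataclass
from typing import Any, List, Union

@dataclass
class GenerationBatchResult:  # stub mirroring the new dataclass
    logits_output: Any
    next_token_ids: List[int]
    bid: int

@dataclass
class EmbeddingBatchResult:   # stub mirroring the new dataclass
    embeddings: Any
    bid: int

def describe(result: Union[GenerationBatchResult, EmbeddingBatchResult]) -> str:
    if isinstance(result, GenerationBatchResult):
        return f"generation batch {result.bid}: {len(result.next_token_ids)} tokens"
    return f"embedding batch {result.bid}"

print(describe(EmbeddingBatchResult(embeddings=None, bid=5)))
```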
```diff
@@ -1057,17 +1084,29 @@
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result[-1])
+                self.tp_worker.resolve_batch_result(result.bid)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
-    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         skip_stream_req = None
 
         if self.is_generation:
-            logits_output, next_token_ids, bid = result
+            (
+                logits_output,
+                next_token_ids,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.bid,
+            )
 
             if self.enable_overlap:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
```
```diff
@@ -1125,7 +1164,7 @@
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
-            embeddings, bid = result
+            embeddings, bid = result.embeddings, result.bid
             embeddings = embeddings.tolist()
 
         # Check finish conditions
```
```diff
@@ -1149,8 +1188,16 @@
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
-    def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids, bid = result
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
+    get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -532,7 +533,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
-                self.model_config.get_num_kv_heads(self.tp_size)
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
@@ -626,7 +627,7 @@ class ModelRunner:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
@@ -637,7 +638,7 @@ class ModelRunner:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
```
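All three hunks make the same substitution: KV-cache sizing now derives the per-rank head count from the attention tensor-parallel size (`get_attention_tp_size()`) rather than the global `self.tp_size`; with DP attention enabled, the two can differ. A back-of-envelope sketch of the `cell_size` expression, assuming `get_num_kv_heads(tp)` shards total KV heads evenly across `tp` ranks and using illustrative model numbers:

```python
# Hypothetical model dimensions, not taken from the diff.
total_kv_heads = 8
head_dim = 128
num_hidden_layers = 32
dtype_size = 2            # bytes per element for fp16/bf16

def get_num_kv_heads(tp: int) -> int:
    # Assumed behavior: KV heads are sharded evenly across tp ranks.
    return max(1, total_kv_heads // tp)

attention_tp_size = 2     # can be smaller than tp_size under DP attention
cell_size = (
    get_num_kv_heads(attention_tp_size)
    * head_dim
    * num_hidden_layers
    * 2                   # K and V
    * dtype_size
)
print(cell_size)          # per-token KV-cache bytes on one rank
```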
python/sglang/srt/openai_api/protocol.py

```diff
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None
 
 
 class FunctionResponse(BaseModel):
```
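`session_params` is a new optional pass-through dict on both request models. A hedged example of supplying it via the OpenAI-compatible completions endpoint (the URL uses SGLang's default port; the dict's keys are illustrative and depend on SGLang's session feature, not this diff):

```python
import requests

payload = {
    "model": "default",
    "prompt": "Hello",
    "max_tokens": 16,
    # New in this commit: forwarded to the session mechanism.
    # The keys shown here are hypothetical.
    "session_params": {"id": "my-session"},
}
resp = requests.post("http://localhost:30000/v1/completions", json=payload)
print(resp.json())
```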
python/sglang/srt/server.py

```diff
@@ -842,7 +842,6 @@ class Engine:
             generator = ret.body_iterator
 
             async def generator_wrapper():
                 offset = 0
-
                 while True:
```
python/sglang/srt/server_args.py

```diff
@@ -239,8 +239,8 @@ class ServerArgs:
         # Others
         if self.enable_dp_attention:
-            assert self.tp_size % self.dp_size == 0
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             logger.warning(
```