change / sglang / Commits / f86c1e61

Unverified commit f86c1e61, authored Sep 29, 2024 by Lianmin Zheng, committed via GitHub on Sep 29, 2024.
Parent: acaffd23

Move scheduler code from tp_worker.py to scheduler.py (#1538)

Showing 8 changed files with 933 additions and 870 deletions (+933, -870).
Files changed:
- python/sglang/bench_latency.py                    (+12,  -4)
- python/sglang/srt/managers/io_struct.py            (+3,  -8)
- python/sglang/srt/managers/schedule_batch.py       (+5,  -4)
- python/sglang/srt/managers/scheduler.py          (+879,  -8)
- python/sglang/srt/managers/scheduler_policy.py     (+2,  -2)
- python/sglang/srt/managers/tp_worker.py           (+21, -832)
- python/sglang/srt/mem_cache/memory_pool.py         (+2,  -2)
- python/sglang/srt/model_executor/model_runner.py   (+9, -10)
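This commit splits the old ModelTpServer into two pieces: a Scheduler in scheduler.py that owns the request queue, caches, and the event loop, and a slimmer ModelTpWorker in tp_worker.py that runs the model. A minimal sketch of how the pieces fit together after the change, based only on the run_scheduler_process and event_loop code in the diffs below; the launch_scheduler wrapper name is hypothetical:

from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.server_args import PortArgs, ServerArgs


def launch_scheduler(server_args: ServerArgs, port_args: PortArgs, gpu_id: int, tp_rank: int):
    # The Scheduler builds its own ModelTpWorker internally and pulls token/memory
    # limits from it via get_token_and_memory_info() (see the scheduler.py diff).
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
    # The loop: receive requests (rank 0), broadcast to all TP ranks, process them,
    # run one prefill/decode step, and ship results to the detokenizer.
    scheduler.event_loop()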
python/sglang/bench_latency.py  (View file @ f86c1e61)

@@ -167,9 +167,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len

         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)

@@ -199,9 +203,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     reqs = []
     for i in range(len(input_ids)):
-        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
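With this change, sampling_params is a constructor argument of Req rather than an attribute patched on afterwards. A hedged usage sketch; the concrete prompt, token ids, and SamplingParams values below are illustrative only:

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.sampling.sampling_params import SamplingParams

# Illustrative values, not taken from the benchmark script.
sampling_params = SamplingParams(temperature=0.0, max_new_tokens=16)
req = Req(
    rid=0,
    origin_input_text="Hello",
    origin_input_ids=[1, 2, 3],
    sampling_params=sampling_params,  # now passed at construction time
)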
python/sglang/srt/managers/io_struct.py  (View file @ f86c1e61)

@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

@@ -53,12 +52,12 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    is_single: bool = True
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

+    # Whether it is a single request or a batch request
+    is_single: bool = True

     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None

@@ -307,10 +306,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-
 @dataclass
 class BatchStrOut:
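For reference, the removed __post_init__ hook had made BatchTokenIDOut deep-copy its meta_info on construction, so later in-place edits by the producer could not leak into an already-emitted batch. A small standalone illustration of that (now dropped) pattern, using generic names rather than the sglang classes:

import copy
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Out:
    meta_info: List[Dict]

    def __post_init__(self):
        # Deep-copy so the producer's later mutations do not affect this object.
        self.meta_info = copy.deepcopy(self.meta_info)


src = [{"finish_reason": None}]
out = Out(src)
src[0]["finish_reason"] = "stop"
assert out.meta_info[0]["finish_reason"] is None  # the stored copy is isolated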
python/sglang/srt/managers/schedule_batch.py  (View file @ f86c1e61)

@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -143,6 +144,7 @@ class Req:
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info

@@ -152,6 +154,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.sampling_params = sampling_params
         self.lora_path = lora_path

         # Memory info

@@ -160,6 +164,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False

         # For incremental decoding
         # ----- | --------- read_ids -------|

@@ -187,10 +192,6 @@ class Req:
         self.extend_input_len = 0
         self.last_node = None

-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
python/sglang/srt/managers/scheduler.py  (View file @ f86c1e61)

@@ -15,18 +15,62 @@ limitations under the License.
 """A scheduler that manages a tensor parallel GPU worker."""

+import json
 import logging
 import multiprocessing
+import os
+import time
+import warnings
+from typing import List, Optional, Union

+import torch
 import zmq

-from sglang.srt.managers.tp_worker import ModelTpServer
+from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOut,
+    BatchTokenIDOut,
+    FlushCacheReq,
+    TokenizedEmbeddingReqInput,
+    TokenizedGenerateReqInput,
+    TokenizedRewardReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
+)
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    ImageInputs,
+    Req,
+    ScheduleBatch,
+)
+from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
+from sglang.srt.managers.tp_worker import ModelTpWorker
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import broadcast_pyobj, configure_logger, kill_parent_process
+from sglang.srt.utils import (
+    broadcast_pyobj,
+    configure_logger,
+    is_generation_model,
+    is_multimodal_model,
+    kill_parent_process,
+    set_random_seed,
+    suppress_other_loggers,
+)
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"


 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""

@@ -39,8 +83,13 @@ class Scheduler:
         tp_rank: int,
     ):
         # Parse args
+        self.server_args = server_args
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
+        self.schedule_policy = server_args.schedule_policy
+        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        self.lora_paths = server_args.lora_paths
+        self.max_loras_per_batch = server_args.max_loras_per_batch

         # Init inter-process communication
         context = zmq.Context(2)

@@ -54,30 +103,146 @@ class Scheduler:
                 f"tcp://127.0.0.1:{port_args.detokenizer_port}"
             )
         else:
-            self.send_to_detokenizer = None
+            self.recv_from_tokenizer = self.send_to_detokenizer = None

+        # Init tokenizer
+        self.model_config = ModelConfig(
+            server_args.model_path,
+            server_args.trust_remote_code,
+            context_length=server_args.context_length,
+            model_override_args=json.loads(server_args.json_model_override_args),
+        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
+        else:
+            if is_multimodal_model(self.model_config.hf_config.architectures):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
+        )

-        # Launch a tp server
-        self.tp_server = ModelTpServer(
+        # Launch a tensor parallel worker
+        self.tp_worker = ModelTpWorker(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
             nccl_port=port_args.nccl_ports[0],
         )
-        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group

+        # Get token and memory info from the tp worker
+        (
+            self.max_total_num_tokens,
+            self.max_prefill_tokens,
+            self.max_running_requests,
+            self.max_req_input_len,
+            self.random_seed,
+        ) = self.tp_worker.get_token_and_memory_info()
+        set_random_seed(self.random_seed)
+
+        # Print debug info
+        logger.info(
+            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_prefill_tokens={self.max_prefill_tokens}, "
+            f"max_running_requests={self.max_running_requests}, "
+            f"context_len={self.model_config.context_len}"
+        )
+
+        # Init cache
+        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
+        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool=self.token_to_kv_pool,
+            )
+        else:
+            self.tree_cache = RadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool=self.token_to_kv_pool,
+                disable=server_args.disable_radix_cache,
+            )
+        self.tree_cache_metrics = {"total": 0, "hit": 0}
+        self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
+
+        # Init running status
+        self.waiting_queue: List[Req] = []
+        self.running_batch: ScheduleBatch = None
+        self.out_pyobjs = []
+        self.decode_forward_ct = 0
+        self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+
+        # Init chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
+
+        # Init the FSM cache for constrained generation
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+            )
+        self.jump_forward_cache = JumpForwardCache()
+
+        # Init new token estimation
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio = self.min_new_token_ratio
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.do_not_get_new_batch = False

     def event_loop(self):
         while True:
+            # Receive requests
             if self.tp_rank == 0:
                 recv_reqs = self.recv_requests_from_zmq()
             else:
                 recv_reqs = None

+            # Process requests
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
-            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
+            self.process_requests(recv_reqs)

+            # Forward
+            self.forward_step()
+
+            # Send results
             if self.tp_rank == 0:
-                for obj in out_pyobjs:
+                for obj in self.out_pyobjs:
                     self.send_to_detokenizer.send_pyobj(obj)
+                self.out_pyobjs = []

     def recv_requests_from_zmq(self):
         recv_reqs = []
...
@@ -91,6 +256,711 @@ class Scheduler:
         return recv_reqs

+    def process_requests(self, recv_reqs: List):
+        for recv_req in recv_reqs:
+            if isinstance(recv_req, TokenizedGenerateReqInput):
+                self.handle_generate_request(recv_req)
+                self.do_not_get_new_batch = False
+            elif isinstance(
+                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
+            ):
+                self.handle_embedding_request(recv_req)
+                self.do_not_get_new_batch = False
+            elif isinstance(recv_req, FlushCacheReq):
+                self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
+            elif isinstance(recv_req, UpdateWeightReqInput):
+                success, message = self.update_weights(recv_req)
+                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+            else:
+                raise ValueError(f"Invalid request: {recv_req}")
+
+    @torch.inference_mode()
+    def forward_step(self):
+        if self.do_not_get_new_batch and self.current_inflight_req is None:
+            new_batch = None
+        else:
+            new_batch = self.get_new_prefill_batch()
+        self.do_not_get_new_batch = False
+
+        if new_batch is not None:
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
+
+            if not new_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = new_batch
+                else:
+                    self.running_batch.merge(new_batch)
+        else:
+            # Run a decode batch
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    self.num_generated_tokens += len(self.running_batch.reqs)
+                    self.forward_decode_batch(self.running_batch)
+
+                    # Print stats
+                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                        self.print_decode_stats()
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream:
+                        break
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+    def print_decode_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+            warnings.warn(
+                "Warning: "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
+                f"total slots={self.req_to_token_pool.size}\n"
+                "Memory pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+    def handle_generate_request(
+        self,
+        recv_req: TokenizedGenerateReqInput,
+    ):
+        req = Req(
+            recv_req.rid,
+            recv_req.input_text,
+            recv_req.input_ids,
+            recv_req.sampling_params,
+            lora_path=recv_req.lora_path,
+        )
+        req.tokenizer = self.tokenizer
+        ...
+        self.waiting_queue.append(req)
+
+    def handle_embedding_request(
+        self,
+        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
+    ):
+        req = Req(
+            recv_req.rid,
+            recv_req.input_text,
+            recv_req.input_ids,
+            recv_req.sampling_params,
+        )
+        req.tokenizer = self.tokenizer
+
+        # Truncate prompts that are too long
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warning(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        self.waiting_queue.append(req)
+
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
+        running_bs = (
+            len(self.running_batch.reqs) if self.running_batch is not None else 0
+        )
+        if running_bs >= self.max_running_requests:
+            return None
+
+        # Get priority queue
+        prefix_computed = self.policy.calc_priority(self.waiting_queue)
+
+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.running_batch,
+            self.new_token_ratio,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
+            num_mixed_running,
+        )
+        ...
+        # Return the new batch
+        new_batch = ScheduleBatch.init_new(
+            can_run_list,
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+        )
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+        return new_batch
+
+    def forward_prefill_batch(self, batch: ScheduleBatch):
+        # Build batch tensors
+        batch.prepare_for_extend(self.model_config.vocab_size)
+
+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
+        if self.is_generation:
+            # Forward and sample the next tokens
+            if batch.extend_num_tokens != 0:
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    batch
+                )
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
+            ...
+        else:
+            assert batch.extend_num_tokens != 0
+            embeddings = self.tp_worker.forward_batch_embedding(batch)
+            ...
+
+        self.handle_finished_requests(batch)
+
+    def add_logprob_return_values(
+        self,
+        i: int,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitsProcessorOutput,
+    ):
+        """Attach logprobs to the return values."""
+        req.output_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+        ...
+        return num_input_logprobs
+
+    def forward_decode_batch(self, batch: ScheduleBatch):
+        # Check if decode out of memory
+        if not batch.check_decode_mem():
+            old_ratio = self.new_token_ratio
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+            logger.info(
+                "Decode out of memory happened. "
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+            )
+            self.waiting_queue.extend(retracted_reqs)
+        else:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_decay,
+                self.min_new_token_ratio,
+            )
+
+        # Check for jump-forward
+        if not self.disable_regex_jump_forward:
+            jump_forward_reqs = batch.check_for_jump_forward(self.tp_worker.model_runner)
+            self.waiting_queue.extend(jump_forward_reqs)
+            if batch.is_empty():
+                return
+
+        # Update batch tensors
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        batch.prepare_for_decode()
+
+        # Forward and sample the next tokens
+        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(next_token_ids)
+        ...
+        self.handle_finished_requests(batch)
+
+    def handle_finished_requests(self, batch: ScheduleBatch):
+        output_rids = []
+        output_meta_info = []
+        output_finished_reason: List[BaseFinishReason] = []
+        if self.is_generation:
+            output_vids = []
+            decoded_texts = []
+            output_read_ids = []
+            output_read_offsets = []
+            output_skip_special_tokens = []
+            output_spaces_between_special_tokens = []
+        else:
+            # for embedding model
+            output_embeddings = []
+        unfinished_indices = []
+        ...
+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)
+
+    def flush_cache(self):
+        if len(self.waiting_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+            if_success = True
+        else:
+            logging.warning(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.waiting_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+            if_success = False
+        return if_success
+
+    def abort_request(self, recv_req: AbortReq):
+        # Delete requests in the waiting queue
+        to_del = None
+        for i, req in enumerate(self.waiting_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.waiting_queue[to_del]
+
+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished_reason = FINISH_ABORT()
+                    break
+
+    def update_weights(self, recv_req: UpdateWeightReqInput):
+        success, message = self.tp_worker.update_weights(recv_req)
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+

 def run_scheduler_process(
     server_args: ServerArgs,
...
@@ -100,6 +970,7 @@ def run_scheduler_process(
     pipe_writer: multiprocessing.connection.Connection,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
+    suppress_other_loggers()

     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
...
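After the split, the Scheduler talks to the GPU through a small ModelTpWorker surface. The sketch below collects the worker calls that are visible in this diff; the wrapper function name is hypothetical and the scheduler/batch arguments are assumed to already exist:

def run_one_generation_step(scheduler, batch):
    # Done once at startup in Scheduler.__init__: static limits plus the shared seed.
    (
        max_total_num_tokens,
        max_prefill_tokens,
        max_running_requests,
        max_req_input_len,
        random_seed,
    ) = scheduler.tp_worker.get_token_and_memory_info()

    # Per-batch work: the forward pass and sampling now sit behind one worker call.
    logits_output, next_token_ids = scheduler.tp_worker.forward_batch_generation(batch)
    return logits_output, next_token_ids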
python/sglang/srt/managers/policy_scheduler.py → python/sglang/srt/managers/scheduler_policy.py  (View file @ f86c1e61)

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Request policy scheduler"""
+"""Request scheduler policy"""

 import os
 import random

@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


-class PolicyScheduler:
+class SchedulerPolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
...
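The module and the class are renamed together (policy_scheduler.PolicyScheduler becomes scheduler_policy.SchedulerPolicy), so callers need both the new import path and the new class name. A minimal sketch, assuming server_args and tree_cache are the objects the scheduler already holds:

# Before: from sglang.srt.managers.policy_scheduler import PolicyScheduler
from sglang.srt.managers.scheduler_policy import SchedulerPolicy


def build_policy(server_args, tree_cache):
    # Constructor signature is unchanged; only the module and class names moved.
    return SchedulerPolicy(server_args.schedule_policy, tree_cache)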
python/sglang/srt/managers/tp_worker.py
View file @
f86c1e61
...
@@ -17,58 +17,18 @@ limitations under the License.
...
@@ -17,58 +17,18 @@ limitations under the License.
import
json
import
json
import
logging
import
logging
import
os
import
time
import
warnings
from
typing
import
List
,
Optional
,
Union
import
torch
from
sglang.global_config
import
global_config
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.constrained.fsm_cache
import
FSMCache
from
sglang.srt.constrained.jump_forward
import
JumpForwardCache
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.io_struct
import
UpdateWeightReqInput
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
FlushCacheReq
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
TokenizedRewardReqInput
,
UpdateWeightReqInput
,
UpdateWeightReqOutput
,
)
from
sglang.srt.managers.policy_scheduler
import
PolicyScheduler
,
PrefillAdder
from
sglang.srt.managers.schedule_batch
import
(
FINISH_ABORT
,
BaseFinishReason
,
ImageInputs
,
Req
,
ScheduleBatch
,
)
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
broadcast_pyobj
,
is_multimodal_model
,
set_random_seed
broadcast_pyobj
,
is_multimodal_model
,
set_random_seed
,
suppress_other_loggers
,
)
from
sglang.utils
import
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Crash on warning if we are running CI tests
class
ModelTpWorker
:
crash_on_warning
=
os
.
getenv
(
"SGLANG_IS_IN_CI"
,
"false"
)
==
"true"
class
ModelTpServer
:
def
__init__
(
def
__init__
(
self
,
self
,
gpu_id
:
int
,
gpu_id
:
int
,
...
@@ -76,17 +36,8 @@ class ModelTpServer:
...
@@ -76,17 +36,8 @@ class ModelTpServer:
server_args
:
ServerArgs
,
server_args
:
ServerArgs
,
nccl_port
:
int
,
nccl_port
:
int
,
):
):
suppress_other_loggers
()
# Parse args
# Parse arguments
self
.
gpu_id
=
gpu_id
self
.
tp_rank
=
tp_rank
self
.
tp_rank
=
tp_rank
self
.
tp_size
=
server_args
.
tp_size
self
.
dp_size
=
server_args
.
dp_size
self
.
schedule_policy
=
server_args
.
schedule_policy
self
.
disable_regex_jump_forward
=
server_args
.
disable_regex_jump_forward
self
.
lora_paths
=
server_args
.
lora_paths
self
.
max_loras_per_batch
=
server_args
.
max_loras_per_batch
# Init model and tokenizer
# Init model and tokenizer
self
.
model_config
=
ModelConfig
(
self
.
model_config
=
ModelConfig
(
...
@@ -120,6 +71,8 @@ class ModelTpServer:
...
@@ -120,6 +71,8 @@ class ModelTpServer:
tokenizer_mode
=
server_args
.
tokenizer_mode
,
tokenizer_mode
=
server_args
.
tokenizer_mode
,
trust_remote_code
=
server_args
.
trust_remote_code
,
trust_remote_code
=
server_args
.
trust_remote_code
,
)
)
# Profile number of tokens
self
.
max_total_num_tokens
=
self
.
model_runner
.
max_total_num_tokens
self
.
max_total_num_tokens
=
self
.
model_runner
.
max_total_num_tokens
self
.
max_prefill_tokens
=
server_args
.
max_prefill_tokens
self
.
max_prefill_tokens
=
server_args
.
max_prefill_tokens
self
.
max_running_requests
=
min
(
self
.
max_running_requests
=
min
(
...
@@ -136,798 +89,34 @@ class ModelTpServer:
...
@@ -136,798 +89,34 @@ class ModelTpServer:
)
)
# Sync random seed across TP workers
# Sync random seed across TP workers
se
rver_args
.
random_seed
=
broadcast_pyobj
(
se
lf
.
random_seed
=
broadcast_pyobj
(
[
server_args
.
random_seed
],
[
server_args
.
random_seed
],
self
.
tp_rank
,
self
.
tp_rank
,
self
.
model_runner
.
tp_group
.
cpu_group
,
self
.
model_runner
.
tp_group
.
cpu_group
,
)[
0
]
)[
0
]
set_random_seed
(
server_args
.
random_seed
)
set_random_seed
(
self
.
random_seed
)
# Print debug info
logger
.
info
(
f
"max_total_num_tokens=
{
self
.
max_total_num_tokens
}
, "
f
"max_prefill_tokens=
{
self
.
max_prefill_tokens
}
, "
f
"max_running_requests=
{
self
.
max_running_requests
}
, "
f
"context_len=
{
self
.
model_config
.
context_len
}
"
)
# Init cache
if
(
server_args
.
chunked_prefill_size
is
not
None
and
server_args
.
disable_radix_cache
):
self
.
tree_cache
=
ChunkCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
)
else
:
self
.
tree_cache
=
RadixCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
disable
=
server_args
.
disable_radix_cache
,
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
PolicyScheduler
(
self
.
schedule_policy
,
self
.
tree_cache
)
self
.
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
self
.
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
# Init running status
self
.
waiting_queue
:
List
[
Req
]
=
[]
self
.
running_batch
:
ScheduleBatch
=
None
self
.
out_pyobjs
=
[]
self
.
decode_forward_ct
=
0
self
.
stream_interval
=
server_args
.
stream_interval
self
.
num_generated_tokens
=
0
self
.
last_stats_tic
=
time
.
time
()
# Init chunked prefill
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
self
.
current_inflight_req
=
None
self
.
is_mixed_chunk
=
(
self
.
chunked_prefill_size
is
not
None
and
server_args
.
enable_mixed_chunk
)
# Init the FSM cache for constrained generation
if
not
server_args
.
skip_tokenizer_init
:
self
.
regex_fsm_cache
=
FSMCache
(
server_args
.
tokenizer_path
,
{
"tokenizer_mode"
:
server_args
.
tokenizer_mode
,
"trust_remote_code"
:
server_args
.
trust_remote_code
,
},
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
,
constrained_json_whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
)
self
.
jump_forward_cache
=
JumpForwardCache
()
# Init new token estimation
assert
(
server_args
.
schedule_conservativeness
>=
0
),
"Invalid schedule_conservativeness"
self
.
min_new_token_ratio
=
min
(
global_config
.
base_min_new_token_ratio
*
server_args
.
schedule_conservativeness
,
1.0
,
)
self
.
new_token_ratio
=
self
.
min_new_token_ratio
self
.
new_token_ratio_decay
=
global_config
.
new_token_ratio_decay
self
.
do_not_get_new_batch
=
False
@
torch
.
inference_mode
()
def
exposed_step
(
self
,
recv_reqs
:
List
):
try
:
# Recv requests
for
recv_req
in
recv_reqs
:
if
isinstance
(
recv_req
,
TokenizedGenerateReqInput
):
self
.
handle_generate_request
(
recv_req
)
self
.
do_not_get_new_batch
=
False
elif
isinstance
(
recv_req
,
(
TokenizedEmbeddingReqInput
,
TokenizedRewardReqInput
)
):
self
.
handle_embedding_request
(
recv_req
)
self
.
do_not_get_new_batch
=
False
elif
isinstance
(
recv_req
,
FlushCacheReq
):
self
.
flush_cache
()
elif
isinstance
(
recv_req
,
AbortReq
):
self
.
abort_request
(
recv_req
)
elif
isinstance
(
recv_req
,
UpdateWeightReqInput
):
success
,
message
=
self
.
update_weights
(
recv_req
)
self
.
out_pyobjs
.
append
(
UpdateWeightReqOutput
(
success
,
message
))
else
:
raise
ValueError
(
f
"Invalid request:
{
recv_req
}
"
)
# Forward
self
.
forward_step
()
except
Exception
:
logger
.
error
(
"Exception in ModelTpServer:
\n
"
+
get_exception_traceback
())
raise
# Return results
ret
=
self
.
out_pyobjs
self
.
out_pyobjs
=
[]
return
ret
def
forward_step
(
self
):
if
self
.
do_not_get_new_batch
and
self
.
current_inflight_req
is
None
:
new_batch
=
None
else
:
new_batch
=
self
.
get_new_prefill_batch
()
self
.
do_not_get_new_batch
=
False
if
new_batch
is
not
None
:
# Run a new prefill batch
self
.
forward_prefill_batch
(
new_batch
)
if
not
new_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
self
.
running_batch
=
new_batch
else
:
self
.
running_batch
.
merge
(
new_batch
)
else
:
# Run a decode batch
if
self
.
running_batch
is
not
None
:
# Run a few decode batches continuously for reducing overhead
for
_
in
range
(
global_config
.
num_continue_decode_steps
):
self
.
num_generated_tokens
+=
len
(
self
.
running_batch
.
reqs
)
self
.
forward_decode_batch
(
self
.
running_batch
)
# Print stats
if
self
.
tp_rank
==
0
and
self
.
decode_forward_ct
%
40
==
0
:
self
.
print_decode_stats
()
if
self
.
running_batch
.
is_empty
():
self
.
running_batch
=
None
break
if
self
.
out_pyobjs
and
self
.
running_batch
.
has_stream
:
break
else
:
self
.
check_memory
()
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
def
print_decode_stats
(
self
):
num_used
=
self
.
max_total_num_tokens
-
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
throughput
=
self
.
num_generated_tokens
/
(
time
.
time
()
-
self
.
last_stats_tic
)
self
.
num_generated_tokens
=
0
self
.
last_stats_tic
=
time
.
time
()
logger
.
info
(
f
"Decode batch. "
f
"#running-req:
{
len
(
self
.
running_batch
.
reqs
)
}
, "
f
"#token:
{
num_used
}
, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"gen throughput (token/s):
{
throughput
:.
2
f
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
}
"
)
def
check_memory
(
self
):
available_size
=
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
if
available_size
!=
self
.
max_total_num_tokens
:
warnings
.
warn
(
"Warning: "
f
"available_size=
{
available_size
}
, max_total_num_tokens=
{
self
.
max_total_num_tokens
}
\n
"
"KV cache pool leak detected!"
)
exit
(
1
)
if
crash_on_warning
else
None
if
len
(
self
.
req_to_token_pool
.
free_slots
)
!=
self
.
req_to_token_pool
.
size
:
warnings
.
warn
(
"Warning: "
f
"available req slots=
{
len
(
self
.
req_to_token_pool
.
free_slots
)
}
, "
f
"total slots=
{
self
.
req_to_token_pool
.
size
}
\n
"
"Memory pool leak detected!"
)
exit
(
1
)
if
crash_on_warning
else
None
def
handle_generate_request
(
self
,
recv_req
:
TokenizedGenerateReqInput
,
):
if
isinstance
(
recv_req
,
TokenizedGenerateReqInput
):
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_ids
,
lora_path
=
recv_req
.
lora_path
,
)
else
:
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_ids
)
req
.
tokenizer
=
self
.
tokenizer
req
.
sampling_params
=
recv_req
.
sampling_params
# Image inputs
if
recv_req
.
image_inputs
is
not
None
:
req
.
image_inputs
=
ImageInputs
.
from_dict
(
recv_req
.
image_inputs
,
self
.
model_config
.
vocab_size
)
req
.
origin_input_ids
=
self
.
model_runner
.
model
.
pad_input_ids
(
req
.
origin_input_ids_unpadded
,
req
.
image_inputs
)
req
.
return_logprob
=
recv_req
.
return_logprob
req
.
top_logprobs_num
=
recv_req
.
top_logprobs_num
req
.
stream
=
recv_req
.
stream
req
.
logprob_start_len
=
recv_req
.
logprob_start_len
if
req
.
logprob_start_len
==
-
1
:
# By default, only return the logprobs for output tokens
req
.
logprob_start_len
=
len
(
recv_req
.
input_ids
)
-
1
# Init regex FSM
if
(
req
.
sampling_params
.
json_schema
is
not
None
or
req
.
sampling_params
.
regex
is
not
None
):
if
req
.
sampling_params
.
json_schema
is
not
None
:
req
.
regex_fsm
,
computed_regex_string
=
self
.
regex_fsm_cache
.
query
(
(
"json"
,
req
.
sampling_params
.
json_schema
)
)
elif
req
.
sampling_params
.
regex
is
not
None
:
req
.
regex_fsm
,
computed_regex_string
=
self
.
regex_fsm_cache
.
query
(
(
"regex"
,
req
.
sampling_params
.
regex
)
)
if
not
self
.
disable_regex_jump_forward
:
req
.
jump_forward_map
=
self
.
jump_forward_cache
.
query
(
computed_regex_string
)
# Truncate prompts that are too long
if
len
(
req
.
origin_input_ids
)
>=
self
.
max_req_input_len
:
logger
.
warning
(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
req
.
origin_input_ids
=
req
.
origin_input_ids
[:
self
.
max_req_input_len
]
req
.
sampling_params
.
max_new_tokens
=
min
(
(
req
.
sampling_params
.
max_new_tokens
if
req
.
sampling_params
.
max_new_tokens
is
not
None
else
1
<<
30
),
self
.
max_req_input_len
-
1
-
len
(
req
.
origin_input_ids
),
)
self
.
waiting_queue
.
append
(
req
)
def
handle_embedding_request
(
self
,
recv_req
:
Union
[
TokenizedEmbeddingReqInput
,
TokenizedRewardReqInput
],
):
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_ids
)
req
.
tokenizer
=
self
.
tokenizer
req
.
sampling_params
=
recv_req
.
sampling_params
# Truncate prompts that are too long
if
len
(
req
.
origin_input_ids
)
>=
self
.
max_req_input_len
:
logger
.
warning
(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
req
.
origin_input_ids
=
req
.
origin_input_ids
[:
self
.
max_req_input_len
]
self
.
waiting_queue
.
append
(
req
)
def
get_token_and_memory_info
(
self
):
return
(
def
get_new_prefill_batch
(
self
)
->
Optional
[
ScheduleBatch
]:
self
.
max_total_num_tokens
,
running_bs
=
(
len
(
self
.
running_batch
.
reqs
)
if
self
.
running_batch
is
not
None
else
0
)
if
running_bs
>=
self
.
max_running_requests
:
return
None
# Get priority queue
prefix_computed
=
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
num_mixed_running
=
running_bs
if
self
.
is_mixed_chunk
else
0
adder
=
PrefillAdder
(
self
.
tree_cache
,
self
.
running_batch
,
self
.
new_token_ratio
,
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
(),
self
.
max_prefill_tokens
,
self
.
max_prefill_tokens
,
self
.
chunked_prefill_size
,
self
.
max_running_requests
,
num_mixed_running
,
self
.
max_req_input_len
,
)
self
.
random_seed
,
has_inflight
=
self
.
current_inflight_req
is
not
None
if
self
.
current_inflight_req
is
not
None
:
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
)
if
self
.
lora_paths
is
not
None
:
lora_set
=
(
set
([
req
.
lora_path
for
req
in
self
.
running_batch
.
reqs
])
if
self
.
running_batch
is
not
None
else
set
([])
)
for
req
in
self
.
waiting_queue
:
if
(
self
.
lora_paths
is
not
None
and
len
(
lora_set
|
set
([
req
.
lora_path
for
req
in
adder
.
can_run_list
])
|
set
([
req
.
lora_path
])
)
>
self
.
max_loras_per_batch
):
break
if
adder
.
no_remaining_tokens
():
break
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
res
=
adder
.
add_one_req
(
req
)
if
(
not
res
or
running_bs
+
len
(
adder
.
can_run_list
)
>=
self
.
max_running_requests
):
break
can_run_list
=
adder
.
can_run_list
if
adder
.
new_inflight_req
is
not
None
:
assert
self
.
current_inflight_req
is
None
self
.
current_inflight_req
=
adder
.
new_inflight_req
if
len
(
can_run_list
)
==
0
:
return
None
# Print stats
if
self
.
tp_rank
==
0
:
if
isinstance
(
self
.
tree_cache
,
RadixCache
):
self
.
tree_cache_metrics
[
"total"
]
+=
(
adder
.
log_input_tokens
+
adder
.
log_hit_tokens
)
/
10
**
9
self
.
tree_cache_metrics
[
"hit"
]
+=
(
adder
.
log_hit_tokens
)
/
10
**
9
tree_cache_hit_rate
=
(
self
.
tree_cache_metrics
[
"hit"
]
/
self
.
tree_cache_metrics
[
"total"
]
)
else
:
tree_cache_hit_rate
=
0.0
num_used
=
self
.
max_total_num_tokens
-
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
if
num_mixed_running
>
0
:
logger
.
info
(
f
"Prefill batch"
f
"(mixed #running-req:
{
num_mixed_running
}
). "
f
"#new-seq:
{
len
(
can_run_list
)
}
, "
f
"#new-token:
{
adder
.
log_input_tokens
}
, "
f
"#cached-token:
{
adder
.
log_hit_tokens
}
, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
-
len
(
can_run_list
)
+
has_inflight
}
"
)
else
:
logger
.
info
(
f
"Prefill batch. "
f
"#new-seq:
{
len
(
can_run_list
)
}
, "
f
"#new-token:
{
adder
.
log_input_tokens
}
, "
f
"#cached-token:
{
adder
.
log_hit_tokens
}
, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#running-req:
{
running_bs
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
-
len
(
can_run_list
)
+
has_inflight
}
"
)
# Return the new batch
new_batch
=
ScheduleBatch
.
init_new
(
can_run_list
,
self
.
req_to_token_pool
,
self
.
token_to_kv_pool
,
self
.
tree_cache
,
)
)
self
.
waiting_queue
=
[
x
for
x
in
self
.
waiting_queue
if
x
not
in
can_run_list
]
return
new_batch
    def forward_prefill_batch(self, batch: ScheduleBatch):
        # Build batch tensors
        batch.prepare_for_extend(self.model_config.vocab_size)

        decoding_reqs = []
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.prepare_for_decode()
            batch.mix_with_running(self.running_batch)
            decoding_reqs = self.running_batch.reqs
            self.running_batch = None

        if self.model_runner.is_generation:
            # Forward and sample the next tokens
            if batch.extend_num_tokens != 0:
                logits_output = self.model_runner.forward(batch)
                next_token_ids = self.model_runner.sample(logits_output, batch)

                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    next_token_ids
                )

                # Move logprobs to cpu
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=next_token_ids.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )

                next_token_ids = next_token_ids.tolist()
            else:
                if self.tokenizer is None:
                    next_token_ids = []
                    for req in batch.reqs:
                        next_token_ids.append(
                            next(iter(req.sampling_params.stop_token_ids))
                        )
                else:
                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, next_token_ids[i]
                    )

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif req not in decoding_reqs:
                    # To reduce overhead, only cache prefill reqs
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

                if req.return_logprob:
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )
        else:
            assert batch.extend_num_tokens != 0
            logits_output = self.model_runner.forward(batch)
            embeddings = logits_output.embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                req.embedding = embeddings[i]
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

        self.handle_finished_requests(batch)
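The "Move logprobs to cpu" step above relies on advanced indexing: row i of next_token_logprobs is indexed by the token actually sampled for request i. A small standalone torch example of that gather pattern follows; the tensors are synthetic stand-ins, not sglang objects.

    import torch

    # Synthetic stand-ins: 4 requests, vocabulary of 10 tokens.
    logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)   # [batch, vocab]
    next_token_ids = torch.tensor([3, 7, 0, 9])

    # Same pattern as above: pick each row's logprob at the sampled token id.
    chosen = logprobs[torch.arange(len(next_token_ids)), next_token_ids]
    print(chosen.tolist())  # one float per request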
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):
                # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs
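The slicing in add_logprob_return_values is easier to follow with concrete numbers. The standalone illustration below uses plain lists instead of sglang's Req objects, and all values are made up: with an extend length of 6, a logprob start offset of 2, and no re-computed decode tokens, three prompt logprobs are taken from the flat buffer and paired with the last prompt token ids (the very first prompt token has no logprob and is padded with None when logprob_start_len == 0).

    # Standalone illustration of the prompt-logprob slicing (synthetic values).
    fill_ids = [11, 12, 13, 14, 15, 16]        # prompt token ids for one request
    extend_input_len = len(fill_ids)            # 6
    extend_logprob_start_len = 2
    last_update_decode_tokens = 0
    pt = 0                                      # offset of this request in the flat buffer

    num_input_logprobs = extend_input_len - extend_logprob_start_len  # 4
    flat_input_token_logprobs = [-0.5, -1.0, -1.5, -2.0, -2.5]        # made-up values

    input_token_logprobs = flat_input_token_logprobs[
        pt : pt + num_input_logprobs - 1 - last_update_decode_tokens
    ]                                            # 3 logprobs
    input_token_ids = fill_ids[
        len(fill_ids) - num_input_logprobs + 1 : len(fill_ids) - last_update_decode_tokens
    ]                                            # the 3 token ids they belong to
    print(list(zip(input_token_logprobs, input_token_ids)))
    # [(-0.5, 14), (-1.0, 15), (-1.5, 16)]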
    def forward_decode_batch(self, batch: ScheduleBatch):
        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # Check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward and sample the next tokens
        logits_output = self.model_runner.forward(batch)
        next_token_ids = self.model_runner.sample(logits_output, batch)
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
        )

        # Move logprobs to cpu
        if logits_output.next_token_logprobs is not None:
            next_token_logprobs = logits_output.next_token_logprobs[
                torch.arange(len(next_token_ids), device=next_token_ids.device),
                next_token_ids,
            ].tolist()

        next_token_ids = next_token_ids.tolist()

        # Check finish condition
        has_finished = False
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.regex_fsm is not None:
                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, next_token_id
                )

            if req.finished():
                self.tree_cache.cache_finished_req(req)
                has_finished = True

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        if not has_finished:
            self.do_not_get_new_batch = True

        self.handle_finished_requests(batch)

    # The thin forward wrappers added to tp_worker.py (TpModelWorker) by this commit:
    def forward_batch_generation(self, batch):
        logits_output = self.model_runner.forward(batch)
        next_token_ids = self.model_runner.sample(logits_output, batch)
        return logits_output, next_token_ids

    def forward_batch_embedding(self, batch):
        logits_output = self.model_runner.forward(batch)
        embeddings = logits_output.embeddings.tolist()
        return embeddings
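When decode runs out of KV memory, the scheduler retracts requests and raises new_token_ratio; on every successful step it decays the ratio back toward its minimum. A toy sketch of that decay schedule follows; the constants are illustrative, not sglang's defaults.

    def decay_new_token_ratio(ratio: float, decay: float, min_ratio: float) -> float:
        """One successful decode step: relax the ratio toward its floor."""
        return max(ratio - decay, min_ratio)

    ratio = 0.7                     # illustrative starting point after a retraction
    for _ in range(5):
        ratio = decay_new_token_ratio(ratio, decay=0.05, min_ratio=0.4)
        print(round(ratio, 2))      # 0.65, 0.6, 0.55, 0.5, 0.45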
    def handle_finished_requests(self, batch: ScheduleBatch):
        output_rids = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.model_runner.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
        else:
            # for embedding model
            output_embeddings = []
        unfinished_indices = []

        for i, req in enumerate(batch.reqs):
            if not req.finished() and req is not self.current_inflight_req:
                unfinished_indices.append(i)

            if req.finished() or (
                req.stream
                and (
                    self.decode_forward_ct % self.stream_interval == 0
                    or len(req.output_ids) == 1
                )
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.model_runner.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:
                    # for embedding model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.model_runner.is_generation:
                self.out_pyobjs.append(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                    )
                )
            else:
                # for embedding model
                self.out_pyobjs.append(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

        # Remove finished reqs: update batch tensors
        batch.filter_batch(unfinished_indices)
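A request's output is shipped to the detokenizer either when it finishes or, for streaming requests, every stream_interval decode steps (and once right after the first decoded token). A compact restatement of that predicate, using plain arguments rather than Req fields, is sketched below.

    def should_emit(finished: bool, stream: bool, decode_forward_ct: int,
                    stream_interval: int, num_output_ids: int) -> bool:
        """Mirror of the condition deciding when a request's output is sent out."""
        return finished or (
            stream
            and (decode_forward_ct % stream_interval == 0 or num_output_ids == 1)
        )

    # The first decoded token streams immediately; later ones only on the interval.
    print(should_emit(False, True, decode_forward_ct=7, stream_interval=8, num_output_ids=1))  # True
    print(should_emit(False, True, decode_forward_ct=7, stream_interval=8, num_output_ids=5))  # False
    print(should_emit(False, True, decode_forward_ct=8, stream_interval=8, num_output_ids=5))  # True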
    def flush_cache(self):
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success
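flush_cache only succeeds when nothing is queued or running; otherwise it logs a warning and returns False. The one-liner below restates that precondition with plain counts instead of the queue and batch objects; the function name is illustrative.

    def can_flush(num_queued: int, num_running: int) -> bool:
        """The cache may be flushed only when no request is queued or running."""
        return num_queued == 0 and num_running == 0

    print(can_flush(0, 0))  # True  -> safe to reset tree cache and KV pools
    print(can_flush(3, 0))  # False -> pending requests still reference cached tokens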
    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break
    # Old version removed from tp_worker.py (this logic moves to the Scheduler):
    def update_weights(self, recv_req):
        success, message = self.model_runner.update_weights(
            recv_req.model_path, recv_req.load_format
        )
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    # New version kept in tp_worker.py (TpModelWorker):
    def update_weights(self, recv_req: UpdateWeightReqInput):
        success, message = self.model_runner.update_weights(
            recv_req.model_path, recv_req.load_format
        )
        return success, message
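After the move, the weight update splits into two layers: the worker only reloads weights, while the scheduler side is expected to flush the KV cache on success, as the removed code did. A hedged sketch of how a caller might combine the two pieces follows; the method names mirror the diff, but the wrapper function and its wiring are illustrative assumptions, not sglang's actual API.

    def update_weights_and_flush(worker, scheduler, model_path: str, load_format: str):
        # Worker side: just reload the weights.
        success, message = worker.model_runner.update_weights(model_path, load_format)
        if success:
            # Scheduler side: the KV cache refers to the old weights, so flush it.
            assert scheduler.flush_cache(), "Cache flush failed after updating weights"
        return success, message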
python/sglang/srt/mem_cache/memory_pool.py
View file @ f86c1e61
...
@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)

 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""

-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )

     def alloc(self, need_size: int) -> List[int]:
...
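With the new device argument, the same pool can be placed on CUDA or another device. The snippet below is a standalone copy of the class as shown in the hunk, instantiated on CPU for illustration; alloc is omitted because its body is outside this hunk.

    import torch

    class ReqToTokenPool:
        """Standalone copy of the class above, for illustration only."""

        def __init__(self, size: int, max_context_len: int, device: str):
            self.size = size
            self.free_slots = list(range(size))
            self.req_to_token = torch.empty(
                (size, max_context_len), dtype=torch.int32, device=device
            )

    pool = ReqToTokenPool(size=8, max_context_len=16, device="cpu")  # "cuda" on GPU hosts
    print(pool.req_to_token.shape, len(pool.free_slots))  # torch.Size([8, 16]) 8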
python/sglang/srt/model_executor/model_runner.py
View file @ f86c1e61
...
@@ -87,6 +87,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )

+        # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
...
@@ -94,6 +95,13 @@ class ModelRunner:
             logger.info("MLA optimization is tunred on. Use triton backend.")
             self.server_args.attention_backend = "triton"

+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
...
@@ -104,14 +112,6 @@ class ModelRunner:
             }
         )

-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
...
@@ -400,8 +400,7 @@ class ModelRunner:
         )
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
...