change / sglang · Commits

Commit f86c1e61 (unverified) · authored Sep 29, 2024 by Lianmin Zheng, committed by GitHub on Sep 29, 2024

Move scheduler code from tp_worker.py to scheduler.py (#1538)

Parent: acaffd23
Showing 8 changed files with 933 additions and 870 deletions (+933 −870):

  python/sglang/bench_latency.py                      +12    −4
  python/sglang/srt/managers/io_struct.py              +3    −8
  python/sglang/srt/managers/schedule_batch.py         +5    −4
  python/sglang/srt/managers/scheduler.py            +879    −8
  python/sglang/srt/managers/scheduler_policy.py       +2    −2
  python/sglang/srt/managers/tp_worker.py             +21  −832
  python/sglang/srt/mem_cache/memory_pool.py           +2    −2
  python/sglang/srt/model_executor/model_runner.py     +9   −10
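At a glance, the commit splits the old `ModelTpServer` into two objects: a CPU-side `Scheduler` that owns the event loop, and a GPU-side `ModelTpWorker` that only runs the model. A condensed sketch of the resulting control flow (names come from the diffs below; method bodies are elided, so this is an orientation aid, not the real code):

```python
# Condensed sketch of the post-refactor split (bodies elided); the real
# definitions are in the scheduler.py and tp_worker.py diffs below.

class ModelTpWorker:
    """GPU side: owns the ModelRunner, does forward passes and sampling."""

    def get_token_and_memory_info(self): ...
    def forward_batch_generation(self, batch): ...
    def forward_batch_embedding(self, batch): ...

class Scheduler:
    """CPU side: owns the event loop, waiting queue, caches, and ZMQ IPC."""

    def event_loop(self):
        while True:
            recv_reqs = self.recv_requests_from_zmq()  # rank 0 only, then broadcast
            self.process_requests(recv_reqs)           # enqueue/abort/flush/update
            self.forward_step()                        # prefill or decode via tp_worker
            # results accumulated in self.out_pyobjs go to the detokenizer
```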
python/sglang/bench_latency.py

```diff
@@ -167,9 +167,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len

         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -199,9 +203,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     reqs = []
     for i in range(len(input_ids)):
-        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
```
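For context, a hypothetical call to the helper changed above. The function name and signature come from the hunk header; the argument values, and the assumption that the helper returns the `reqs` list it builds, are mine:

```python
from sglang.bench_latency import prepare_synthetic_inputs_for_latency_test

# batch_size/input_len values are illustrative. The helper builds one Req per
# synthetic prompt; after this commit it passes sampling_params at Req
# construction time instead of patching req.sampling_params afterwards.
reqs = prepare_synthetic_inputs_for_latency_test(batch_size=8, input_len=128)
```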
python/sglang/srt/managers/io_struct.py

```diff
@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
@@ -53,12 +52,12 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    is_single: bool = True
-
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -307,10 +306,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-

 @dataclass
 class BatchStrOut:
```
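The reordering above moves `is_single` below the LoRA fields and documents it. As a side note, a small sketch of the two shapes the `lora_path` annotation (`Optional[Union[List[Optional[str]], Optional[str]]]`) admits; the field values here are illustrative, not from the commit:

```python
from sglang.srt.managers.io_struct import GenerateReqInput

# One adapter path for a single request...
single = GenerateReqInput(text="hello", lora_path="adapters/my-lora")

# ...or a per-request list (None = no adapter) for a batch request.
batch = GenerateReqInput(
    text=["a", "b", "c"],
    lora_path=["adapters/my-lora", None, "adapters/my-lora"],
)
```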
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -143,6 +144,7 @@ class Req:
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
@@ -152,6 +154,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.sampling_params = sampling_params
+        self.lora_path = lora_path

         # Memory info
@@ -160,6 +164,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -187,10 +192,6 @@ class Req:
         self.extend_input_len = 0
         self.last_node = None

-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
```
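With this change, sampling parameters become part of `Req` construction rather than a field patched on afterwards. A minimal sketch of the new constructor (argument values are illustrative):

```python
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.sampling.sampling_params import SamplingParams

req = Req(
    rid="req-0",
    origin_input_text="Hello",
    origin_input_ids=[9906],
    sampling_params=SamplingParams(max_new_tokens=16),
)
assert req.sampling_params is not None  # no longer assigned after __init__
```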
python/sglang/srt/managers/scheduler.py

```diff
@@ -15,18 +15,62 @@ limitations under the License.

 """A scheduler that manages a tensor parallel GPU worker."""

+import json
 import logging
 import multiprocessing
 import os
+import time
+import warnings
+from typing import List, Optional, Union

+import torch
 import zmq

-from sglang.srt.managers.tp_worker import ModelTpServer
+from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOut,
+    BatchTokenIDOut,
+    FlushCacheReq,
+    TokenizedEmbeddingReqInput,
+    TokenizedGenerateReqInput,
+    TokenizedRewardReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
+)
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    ImageInputs,
+    Req,
+    ScheduleBatch,
+)
+from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
+from sglang.srt.managers.tp_worker import ModelTpWorker
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import broadcast_pyobj, configure_logger, kill_parent_process
+from sglang.srt.utils import (
+    broadcast_pyobj,
+    configure_logger,
+    is_generation_model,
+    is_multimodal_model,
+    kill_parent_process,
+    set_random_seed,
+    suppress_other_loggers,
+)
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+

 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
```
```diff
@@ -39,8 +83,13 @@ class Scheduler:
         tp_rank: int,
     ):
         # Parse args
+        self.server_args = server_args
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
+        self.schedule_policy = server_args.schedule_policy
+        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        self.lora_paths = server_args.lora_paths
+        self.max_loras_per_batch = server_args.max_loras_per_batch

         # Init inter-process communication
         context = zmq.Context(2)
@@ -54,30 +103,146 @@ class Scheduler:
                 f"tcp://127.0.0.1:{port_args.detokenizer_port}"
             )
         else:
-            self.send_to_detokenizer = None
+            self.recv_from_tokenizer = self.send_to_detokenizer = None

+        # Init tokenizer
+        self.model_config = ModelConfig(
+            server_args.model_path,
+            server_args.trust_remote_code,
+            context_length=server_args.context_length,
+            model_override_args=json.loads(server_args.json_model_override_args),
+        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
+        else:
+            if is_multimodal_model(self.model_config.hf_config.architectures):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
+        )

-        # Launch a tp server
-        self.tp_server = ModelTpServer(
+        # Launch a tensor parallel worker
+        self.tp_worker = ModelTpWorker(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
             nccl_port=port_args.nccl_ports[0],
         )
-        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group

+        # Get token and memory info from the tp worker
+        (
+            self.max_total_num_tokens,
+            self.max_prefill_tokens,
+            self.max_running_requests,
+            self.max_req_input_len,
+            self.random_seed,
+        ) = self.tp_worker.get_token_and_memory_info()
+        set_random_seed(self.random_seed)
+
+        # Print debug info
+        logger.info(
+            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_prefill_tokens={self.max_prefill_tokens}, "
+            f"max_running_requests={self.max_running_requests}, "
+            f"context_len={self.model_config.context_len}"
+        )
+
+        # Init cache
+        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
+        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool=self.token_to_kv_pool,
+            )
+        else:
+            self.tree_cache = RadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool=self.token_to_kv_pool,
+                disable=server_args.disable_radix_cache,
+            )
+        self.tree_cache_metrics = {"total": 0, "hit": 0}
+        self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
+
+        # Init running status
+        self.waiting_queue: List[Req] = []
+        self.running_batch: ScheduleBatch = None
+        self.out_pyobjs = []
+        self.decode_forward_ct = 0
+        self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+
+        # Init chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
+
+        # Init the FSM cache for constrained generation
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+            )
+        self.jump_forward_cache = JumpForwardCache()
+
+        # Init new token estimation
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio = self.min_new_token_ratio
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.do_not_get_new_batch = False
+
     def event_loop(self):
         while True:
             # Receive requests
             if self.tp_rank == 0:
                 recv_reqs = self.recv_requests_from_zmq()
             else:
                 recv_reqs = None

             # Process requests
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
-            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
+            self.process_requests(recv_reqs)
+
+            # Forward
+            self.forward_step()

             # Send results
             if self.tp_rank == 0:
-                for obj in out_pyobjs:
+                for obj in self.out_pyobjs:
                     self.send_to_detokenizer.send_pyobj(obj)
+                self.out_pyobjs = []

     def recv_requests_from_zmq(self):
         recv_reqs = []
```
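The receive path in `event_loop` above follows a rank-0-reads, everyone-broadcasts pattern. Distilled into a standalone helper for clarity (a restatement of the diff, not a new API):

```python
from sglang.srt.utils import broadcast_pyobj

def receive_requests(scheduler):
    # Only TP rank 0 reads the ZMQ socket; the result is then broadcast over
    # the CPU process group so every rank makes identical scheduling decisions.
    if scheduler.tp_rank == 0:
        recv_reqs = scheduler.recv_requests_from_zmq()
    else:
        recv_reqs = None
    return broadcast_pyobj(recv_reqs, scheduler.tp_rank, scheduler.tp_cpu_group)
```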
```diff
@@ -91,6 +256,711 @@ class Scheduler:
         return recv_reqs
```

Everything in this hunk after `return recv_reqs` is newly added; it is the logic moved out of `ModelTpServer.exposed_step` and its helpers in tp_worker.py:

```python
    def process_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
                self.do_not_get_new_batch = False
            elif isinstance(
                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
            ):
                self.handle_embedding_request(recv_req)
                self.do_not_get_new_batch = False
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    @torch.inference_mode()
    def forward_step(self):
        if self.do_not_get_new_batch and self.current_inflight_req is None:
            new_batch = None
        else:
            new_batch = self.get_new_prefill_batch()
        self.do_not_get_new_batch = False

        if new_batch is not None:
            # Run a new prefill batch
            self.forward_prefill_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run a decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(global_config.num_continue_decode_steps):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    # Print stats
                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
                        self.print_decode_stats()

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    if self.out_pyobjs and self.running_batch.has_stream:
                        break
            else:
                self.check_memory()
                self.new_token_ratio = global_config.init_new_token_ratio

    def print_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
        logger.info(
            f"Decode batch. "
            f"#running-req: {len(self.running_batch.reqs)}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )
            exit(1) if crash_on_warning else None

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(
                "Warning: "
                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                f"total slots={self.req_to_token_pool.size}\n"
                "Memory pool leak detected!"
            )
            exit(1) if crash_on_warning else None

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Init regex FSM
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            if req.sampling_params.json_schema is not None:
                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
                    ("json", req.sampling_params.json_schema)
                )
            elif req.sampling_params.regex is not None:
                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
                    ("regex", req.sampling_params.regex)
                )
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    computed_regex_string
                )

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_input_len - 1 - len(req.origin_input_ids),
        )

        self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
        running_bs = (
            len(self.running_batch.reqs) if self.running_batch is not None else 0
        )
        if running_bs >= self.max_running_requests:
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        num_mixed_running = running_bs if self.is_mixed_chunk else 0

        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            num_mixed_running,
        )

        has_inflight = self.current_inflight_req is not None
        if self.current_inflight_req is not None:
            self.current_inflight_req.init_next_round_input(
                None if prefix_computed else self.tree_cache
            )
            self.current_inflight_req = adder.add_inflight_req(
                self.current_inflight_req
            )

        if self.lora_paths is not None:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        for req in self.waiting_queue:
            if (
                self.lora_paths is not None
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                break

            if adder.no_remaining_tokens():
                break
            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if (
                not res
                or running_bs + len(adder.can_run_list) >= self.max_running_requests
            ):
                break

        can_run_list = adder.can_run_list

        if adder.new_inflight_req is not None:
            assert self.current_inflight_req is None
            self.current_inflight_req = adder.new_inflight_req

        if len(can_run_list) == 0:
            return None

        # Print stats
        if self.tp_rank == 0:
            if isinstance(self.tree_cache, RadixCache):
                self.tree_cache_metrics["total"] += (
                    adder.log_input_tokens + adder.log_hit_tokens
                ) / 10**9
                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
                tree_cache_hit_rate = (
                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
                )
            else:
                tree_cache_hit_rate = 0.0

            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )

            if num_mixed_running > 0:
                logger.info(
                    f"Prefill batch"
                    f"(mixed #running-req: {num_mixed_running}). "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
                )
            else:
                logger.info(
                    f"Prefill batch. "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#running-req: {running_bs}, "
                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
                )

        # Return the new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
        return new_batch

    def forward_prefill_batch(self, batch: ScheduleBatch):
        # Build batch tensors
        batch.prepare_for_extend(self.model_config.vocab_size)

        decoding_reqs = []
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.prepare_for_decode()
            batch.mix_with_running(self.running_batch)
            decoding_reqs = self.running_batch.reqs
            self.running_batch = None

        if self.is_generation:
            # Forward and sample the next tokens
            if batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    batch
                )
                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    next_token_ids
                )

                # Move logprobs to cpu
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(
                                len(next_token_ids), device=next_token_ids.device
                            ),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )

                next_token_ids = next_token_ids.tolist()
            else:
                if self.tokenizer is None:
                    next_token_ids = []
                    for req in batch.reqs:
                        next_token_ids.append(
                            next(iter(req.sampling_params.stop_token_ids))
                        )
                else:
                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, next_token_ids[i]
                    )

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif req not in decoding_reqs:
                    # To reduce overhead, only cache prefill reqs
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

                if req.return_logprob:
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )
        else:
            assert batch.extend_num_tokens != 0
            embeddings = self.tp_worker.forward_batch_embedding(batch)

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                req.embedding = embeddings[i]
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

        self.handle_finished_requests(batch)

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def forward_decode_batch(self, batch: ScheduleBatch):
        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_regex_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.tp_worker.model_runner)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward and sample the next tokens
        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
        )

        # Move logprobs to cpu
        if logits_output.next_token_logprobs is not None:
            next_token_logprobs = logits_output.next_token_logprobs[
                torch.arange(len(next_token_ids), device=next_token_ids.device),
                next_token_ids,
            ].tolist()

        next_token_ids = next_token_ids.tolist()

        # Check finish condition
        has_finished = False
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.regex_fsm is not None:
                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, next_token_id
                )

            if req.finished():
                self.tree_cache.cache_finished_req(req)
                has_finished = True

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(
                        logits_output.output_top_logprobs[i]
                    )

        if not has_finished:
            self.do_not_get_new_batch = True

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: ScheduleBatch):
        output_rids = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
        else:  # for embedding model
            output_embeddings = []
        unfinished_indices = []

        for i, req in enumerate(batch.reqs):
            if not req.finished() and req is not self.current_inflight_req:
                unfinished_indices.append(i)

            if req.finished() or (
                req.stream
                and (
                    self.decode_forward_ct % self.stream_interval == 0
                    or len(req.output_ids) == 1
                )
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # for embedding model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.out_pyobjs.append(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                    )
                )
            else:  # for embedding model
                self.out_pyobjs.append(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

        # Remove finished reqs: update batch tensors
        batch.filter_batch(unfinished_indices)

    def flush_cache(self):
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message


def run_scheduler_process(
    server_args: ServerArgs,
    # (middle parameters collapsed in the diff view; see the next hunk)
```
```diff
@@ -100,6 +970,7 @@ def run_scheduler_process(
     pipe_writer: multiprocessing.connection.Connection,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
+    suppress_other_loggers()

     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
```
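`run_scheduler_process` is the new per-TP-rank entry point. A hedged launch sketch: the parameters collapsed out of the diff view are assumed to be `port_args`, `gpu_id`, and `tp_rank`, matching the `Scheduler(server_args, port_args, gpu_id, tp_rank)` call inside, and the values come from the launching server process:

```python
import multiprocessing

# server_args/port_args/gpu_id/tp_rank are provided by the launching server.
reader, writer = multiprocessing.Pipe(duplex=False)
proc = multiprocessing.Process(
    target=run_scheduler_process,
    args=(server_args, port_args, gpu_id, tp_rank, writer),
)
proc.start()
```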
python/sglang/srt/managers/policy_scheduler.py → python/sglang/srt/managers/scheduler_policy.py

```diff
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Request policy scheduler"""
+"""Request scheduler policy"""

 import os
 import random
@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode

 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


-class PolicyScheduler:
+class SchedulerPolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
```
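After the rename, call sites construct the policy the way scheduler.py does in this commit:

```python
from sglang.srt.managers.scheduler_policy import SchedulerPolicy

policy = SchedulerPolicy(server_args.schedule_policy, tree_cache)
# calc_priority reorders the waiting queue and appears to report whether
# prefix matches were already computed during sorting; the Scheduler uses
# that flag to decide whether init_next_round_input needs the tree cache.
prefix_computed = policy.calc_priority(waiting_queue)
```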
python/sglang/srt/managers/tp_worker.py
View file @
f86c1e61
...
...
@@ -17,58 +17,18 @@ limitations under the License.
import
json
import
logging
import
os
import
time
import
warnings
from
typing
import
List
,
Optional
,
Union
import
torch
from
sglang.global_config
import
global_config
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.constrained.fsm_cache
import
FSMCache
from
sglang.srt.constrained.jump_forward
import
JumpForwardCache
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
FlushCacheReq
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
TokenizedRewardReqInput
,
UpdateWeightReqInput
,
UpdateWeightReqOutput
,
)
from
sglang.srt.managers.policy_scheduler
import
PolicyScheduler
,
PrefillAdder
from
sglang.srt.managers.schedule_batch
import
(
FINISH_ABORT
,
BaseFinishReason
,
ImageInputs
,
Req
,
ScheduleBatch
,
)
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
from
sglang.srt.managers.io_struct
import
UpdateWeightReqInput
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
(
broadcast_pyobj
,
is_multimodal_model
,
set_random_seed
,
suppress_other_loggers
,
)
from
sglang.utils
import
get_exception_traceback
from
sglang.srt.utils
import
broadcast_pyobj
,
is_multimodal_model
,
set_random_seed
logger
=
logging
.
getLogger
(
__name__
)
# Crash on warning if we are running CI tests
crash_on_warning
=
os
.
getenv
(
"SGLANG_IS_IN_CI"
,
"false"
)
==
"true"
class
ModelTpServer
:
class
ModelTpWorker
:
def
__init__
(
self
,
gpu_id
:
int
,
...
...
@@ -76,17 +36,8 @@ class ModelTpServer:
server_args
:
ServerArgs
,
nccl_port
:
int
,
):
suppress_other_loggers
()
# Parse arguments
self
.
gpu_id
=
gpu_id
# Parse args
self
.
tp_rank
=
tp_rank
self
.
tp_size
=
server_args
.
tp_size
self
.
dp_size
=
server_args
.
dp_size
self
.
schedule_policy
=
server_args
.
schedule_policy
self
.
disable_regex_jump_forward
=
server_args
.
disable_regex_jump_forward
self
.
lora_paths
=
server_args
.
lora_paths
self
.
max_loras_per_batch
=
server_args
.
max_loras_per_batch
# Init model and tokenizer
self
.
model_config
=
ModelConfig
(
...
...
@@ -120,6 +71,8 @@ class ModelTpServer:
tokenizer_mode
=
server_args
.
tokenizer_mode
,
trust_remote_code
=
server_args
.
trust_remote_code
,
)
# Profile number of tokens
self
.
max_total_num_tokens
=
self
.
model_runner
.
max_total_num_tokens
self
.
max_prefill_tokens
=
server_args
.
max_prefill_tokens
self
.
max_running_requests
=
min
(
...
...
@@ -136,798 +89,34 @@ class ModelTpServer:
)
# Sync random seed across TP workers
se
rver_args
.
random_seed
=
broadcast_pyobj
(
se
lf
.
random_seed
=
broadcast_pyobj
(
[
server_args
.
random_seed
],
self
.
tp_rank
,
self
.
model_runner
.
tp_group
.
cpu_group
,
)[
0
]
set_random_seed
(
server_args
.
random_seed
)
# Print debug info
logger
.
info
(
f
"max_total_num_tokens=
{
self
.
max_total_num_tokens
}
, "
f
"max_prefill_tokens=
{
self
.
max_prefill_tokens
}
, "
f
"max_running_requests=
{
self
.
max_running_requests
}
, "
f
"context_len=
{
self
.
model_config
.
context_len
}
"
)
# Init cache
if
(
server_args
.
chunked_prefill_size
is
not
None
and
server_args
.
disable_radix_cache
):
self
.
tree_cache
=
ChunkCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
)
else
:
self
.
tree_cache
=
RadixCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
disable
=
server_args
.
disable_radix_cache
,
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
PolicyScheduler
(
self
.
schedule_policy
,
self
.
tree_cache
)
self
.
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
self
.
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
# Init running status
self
.
waiting_queue
:
List
[
Req
]
=
[]
self
.
running_batch
:
ScheduleBatch
=
None
self
.
out_pyobjs
=
[]
self
.
decode_forward_ct
=
0
self
.
stream_interval
=
server_args
.
stream_interval
self
.
num_generated_tokens
=
0
self
.
last_stats_tic
=
time
.
time
()
# Init chunked prefill
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
self
.
current_inflight_req
=
None
self
.
is_mixed_chunk
=
(
self
.
chunked_prefill_size
is
not
None
and
server_args
.
enable_mixed_chunk
)
# Init the FSM cache for constrained generation
if
not
server_args
.
skip_tokenizer_init
:
self
.
regex_fsm_cache
=
FSMCache
(
server_args
.
tokenizer_path
,
{
"tokenizer_mode"
:
server_args
.
tokenizer_mode
,
"trust_remote_code"
:
server_args
.
trust_remote_code
,
},
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
,
constrained_json_whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
)
self
.
jump_forward_cache
=
JumpForwardCache
()
# Init new token estimation
assert
(
server_args
.
schedule_conservativeness
>=
0
),
"Invalid schedule_conservativeness"
self
.
min_new_token_ratio
=
min
(
global_config
.
base_min_new_token_ratio
*
server_args
.
schedule_conservativeness
,
1.0
,
)
self
.
new_token_ratio
=
self
.
min_new_token_ratio
self
.
new_token_ratio_decay
=
global_config
.
new_token_ratio_decay
self
.
do_not_get_new_batch
=
False
@
torch
.
inference_mode
()
def
exposed_step
(
self
,
recv_reqs
:
List
):
try
:
# Recv requests
for
recv_req
in
recv_reqs
:
if
isinstance
(
recv_req
,
TokenizedGenerateReqInput
):
self
.
handle_generate_request
(
recv_req
)
self
.
do_not_get_new_batch
=
False
elif
isinstance
(
recv_req
,
(
TokenizedEmbeddingReqInput
,
TokenizedRewardReqInput
)
):
self
.
handle_embedding_request
(
recv_req
)
self
.
do_not_get_new_batch
=
False
elif
isinstance
(
recv_req
,
FlushCacheReq
):
self
.
flush_cache
()
elif
isinstance
(
recv_req
,
AbortReq
):
self
.
abort_request
(
recv_req
)
elif
isinstance
(
recv_req
,
UpdateWeightReqInput
):
success
,
message
=
self
.
update_weights
(
recv_req
)
self
.
out_pyobjs
.
append
(
UpdateWeightReqOutput
(
success
,
message
))
else
:
raise
ValueError
(
f
"Invalid request:
{
recv_req
}
"
)
# Forward
self
.
forward_step
()
except
Exception
:
logger
.
error
(
"Exception in ModelTpServer:
\n
"
+
get_exception_traceback
())
raise
# Return results
ret
=
self
.
out_pyobjs
self
.
out_pyobjs
=
[]
return
ret
def
forward_step
(
self
):
if
self
.
do_not_get_new_batch
and
self
.
current_inflight_req
is
None
:
new_batch
=
None
else
:
new_batch
=
self
.
get_new_prefill_batch
()
self
.
do_not_get_new_batch
=
False
if
new_batch
is
not
None
:
# Run a new prefill batch
self
.
forward_prefill_batch
(
new_batch
)
if
not
new_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
self
.
running_batch
=
new_batch
else
:
self
.
running_batch
.
merge
(
new_batch
)
else
:
# Run a decode batch
if
self
.
running_batch
is
not
None
:
# Run a few decode batches continuously for reducing overhead
for
_
in
range
(
global_config
.
num_continue_decode_steps
):
self
.
num_generated_tokens
+=
len
(
self
.
running_batch
.
reqs
)
self
.
forward_decode_batch
(
self
.
running_batch
)
# Print stats
if
self
.
tp_rank
==
0
and
self
.
decode_forward_ct
%
40
==
0
:
self
.
print_decode_stats
()
if
self
.
running_batch
.
is_empty
():
self
.
running_batch
=
None
break
if
self
.
out_pyobjs
and
self
.
running_batch
.
has_stream
:
break
else
:
self
.
check_memory
()
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
def
print_decode_stats
(
self
):
num_used
=
self
.
max_total_num_tokens
-
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
throughput
=
self
.
num_generated_tokens
/
(
time
.
time
()
-
self
.
last_stats_tic
)
self
.
num_generated_tokens
=
0
self
.
last_stats_tic
=
time
.
time
()
logger
.
info
(
f
"Decode batch. "
f
"#running-req:
{
len
(
self
.
running_batch
.
reqs
)
}
, "
f
"#token:
{
num_used
}
, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"gen throughput (token/s):
{
throughput
:.
2
f
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
}
"
)
def
check_memory
(
self
):
available_size
=
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
if
available_size
!=
self
.
max_total_num_tokens
:
warnings
.
warn
(
"Warning: "
f
"available_size=
{
available_size
}
, max_total_num_tokens=
{
self
.
max_total_num_tokens
}
\n
"
"KV cache pool leak detected!"
)
exit
(
1
)
if
crash_on_warning
else
None
if
len
(
self
.
req_to_token_pool
.
free_slots
)
!=
self
.
req_to_token_pool
.
size
:
warnings
.
warn
(
"Warning: "
f
"available req slots=
{
len
(
self
.
req_to_token_pool
.
free_slots
)
}
, "
f
"total slots=
{
self
.
req_to_token_pool
.
size
}
\n
"
"Memory pool leak detected!"
)
exit
(
1
)
if
crash_on_warning
else
None
def
handle_generate_request
(
self
,
recv_req
:
TokenizedGenerateReqInput
,
):
if
isinstance
(
recv_req
,
TokenizedGenerateReqInput
):
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_ids
,
lora_path
=
recv_req
.
lora_path
,
)
else
:
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_ids
)
req
.
tokenizer
=
self
.
tokenizer
req
.
sampling_params
=
recv_req
.
sampling_params
# Image inputs
if
recv_req
.
image_inputs
is
not
None
:
req
.
image_inputs
=
ImageInputs
.
from_dict
(
recv_req
.
image_inputs
,
self
.
model_config
.
vocab_size
)
req
.
origin_input_ids
=
self
.
model_runner
.
model
.
pad_input_ids
(
req
.
origin_input_ids_unpadded
,
req
.
image_inputs
)
req
.
return_logprob
=
recv_req
.
return_logprob
req
.
top_logprobs_num
=
recv_req
.
top_logprobs_num
req
.
stream
=
recv_req
.
stream
req
.
logprob_start_len
=
recv_req
.
logprob_start_len
if
req
.
logprob_start_len
==
-
1
:
# By default, only return the logprobs for output tokens
req
.
logprob_start_len
=
len
(
recv_req
.
input_ids
)
-
1
# Init regex FSM
if
(
req
.
sampling_params
.
json_schema
is
not
None
or
req
.
sampling_params
.
regex
is
not
None
):
if
req
.
sampling_params
.
json_schema
is
not
None
:
req
.
regex_fsm
,
computed_regex_string
=
self
.
regex_fsm_cache
.
query
(
(
"json"
,
req
.
sampling_params
.
json_schema
)
)
elif
req
.
sampling_params
.
regex
is
not
None
:
req
.
regex_fsm
,
computed_regex_string
=
self
.
regex_fsm_cache
.
query
(
(
"regex"
,
req
.
sampling_params
.
regex
)
)
if
not
self
.
disable_regex_jump_forward
:
req
.
jump_forward_map
=
self
.
jump_forward_cache
.
query
(
computed_regex_string
)
# Truncate prompts that are too long
if
len
(
req
.
origin_input_ids
)
>=
self
.
max_req_input_len
:
logger
.
warning
(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
req
.
origin_input_ids
=
req
.
origin_input_ids
[:
self
.
max_req_input_len
]
req
.
sampling_params
.
max_new_tokens
=
min
(
(
req
.
sampling_params
.
max_new_tokens
if
req
.
sampling_params
.
max_new_tokens
is
not
None
else
1
<<
30
),
self
.
max_req_input_len
-
1
-
len
(
req
.
origin_input_ids
),
)
self
.
waiting_queue
.
append
(
req
)
def
handle_embedding_request
(
self
,
recv_req
:
Union
[
TokenizedEmbeddingReqInput
,
TokenizedRewardReqInput
],
):
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_ids
)
req
.
tokenizer
=
self
.
tokenizer
req
.
sampling_params
=
recv_req
.
sampling_params
# Truncate prompts that are too long
if
len
(
req
.
origin_input_ids
)
>=
self
.
max_req_input_len
:
logger
.
warning
(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
)
req
.
origin_input_ids
=
req
.
origin_input_ids
[:
self
.
max_req_input_len
]
set_random_seed
(
self
.
random_seed
)
self
.
waiting_queue
.
append
(
req
)
def
get_new_prefill_batch
(
self
)
->
Optional
[
ScheduleBatch
]:
running_bs
=
(
len
(
self
.
running_batch
.
reqs
)
if
self
.
running_batch
is
not
None
else
0
)
if
running_bs
>=
self
.
max_running_requests
:
return
None
# Get priority queue
prefix_computed
=
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
num_mixed_running
=
running_bs
if
self
.
is_mixed_chunk
else
0
adder
=
PrefillAdder
(
self
.
tree_cache
,
self
.
running_batch
,
self
.
new_token_ratio
,
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
(),
def
get_token_and_memory_info
(
self
):
return
(
self
.
max_total_num_tokens
,
self
.
max_prefill_tokens
,
self
.
chunked_prefill_size
,
num_mixed_running
,
)
has_inflight
=
self
.
current_inflight_req
is
not
None
if
self
.
current_inflight_req
is
not
None
:
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
)
if
self
.
lora_paths
is
not
None
:
lora_set
=
(
set
([
req
.
lora_path
for
req
in
self
.
running_batch
.
reqs
])
if
self
.
running_batch
is
not
None
else
set
([])
)
for
req
in
self
.
waiting_queue
:
if
(
self
.
lora_paths
is
not
None
and
len
(
lora_set
|
set
([
req
.
lora_path
for
req
in
adder
.
can_run_list
])
|
set
([
req
.
lora_path
])
)
>
self
.
max_loras_per_batch
):
break
if
adder
.
no_remaining_tokens
():
break
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
res
=
adder
.
add_one_req
(
req
)
if
(
not
res
or
running_bs
+
len
(
adder
.
can_run_list
)
>=
self
.
max_running_requests
):
break
can_run_list
=
adder
.
can_run_list
if
adder
.
new_inflight_req
is
not
None
:
assert
self
.
current_inflight_req
is
None
self
.
current_inflight_req
=
adder
.
new_inflight_req
if
len
(
can_run_list
)
==
0
:
return
None
# Print stats
if
self
.
tp_rank
==
0
:
if
isinstance
(
self
.
tree_cache
,
RadixCache
):
self
.
tree_cache_metrics
[
"total"
]
+=
(
adder
.
log_input_tokens
+
adder
.
log_hit_tokens
)
/
10
**
9
self
.
tree_cache_metrics
[
"hit"
]
+=
(
adder
.
log_hit_tokens
)
/
10
**
9
tree_cache_hit_rate
=
(
self
.
tree_cache_metrics
[
"hit"
]
/
self
.
tree_cache_metrics
[
"total"
]
)
else
:
tree_cache_hit_rate
=
0.0
num_used
=
self
.
max_total_num_tokens
-
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
if
num_mixed_running
>
0
:
logger
.
info
(
f
"Prefill batch"
f
"(mixed #running-req:
{
num_mixed_running
}
). "
f
"#new-seq:
{
len
(
can_run_list
)
}
, "
f
"#new-token:
{
adder
.
log_input_tokens
}
, "
f
"#cached-token:
{
adder
.
log_hit_tokens
}
, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
-
len
(
can_run_list
)
+
has_inflight
}
"
)
else
:
logger
.
info
(
f
"Prefill batch. "
f
"#new-seq:
{
len
(
can_run_list
)
}
, "
f
"#new-token:
{
adder
.
log_input_tokens
}
, "
f
"#cached-token:
{
adder
.
log_hit_tokens
}
, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#running-req:
{
running_bs
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
-
len
(
can_run_list
)
+
has_inflight
}
"
)
# Return the new batch
new_batch
=
ScheduleBatch
.
init_new
(
can_run_list
,
self
.
req_to_token_pool
,
self
.
token_to_kv_pool
,
self
.
tree_cache
,
self
.
max_running_requests
,
self
.
max_req_input_len
,
self
.
random_seed
,
)
self
.
waiting_queue
=
[
x
for
x
in
self
.
waiting_queue
if
x
not
in
can_run_list
]
return
new_batch
def
forward_prefill_batch
(
self
,
batch
:
ScheduleBatch
):
# Build batch tensors
batch
.
prepare_for_extend
(
self
.
model_config
.
vocab_size
)
decoding_reqs
=
[]
if
self
.
is_mixed_chunk
and
self
.
running_batch
is
not
None
:
self
.
running_batch
.
prepare_for_decode
()
batch
.
mix_with_running
(
self
.
running_batch
)
decoding_reqs
=
self
.
running_batch
.
reqs
self
.
running_batch
=
None
if
self
.
model_runner
.
is_generation
:
# Forward and sample the next tokens
if
batch
.
extend_num_tokens
!=
0
:
logits_output
=
self
.
model_runner
.
forward
(
batch
)
next_token_ids
=
self
.
model_runner
.
sample
(
logits_output
,
batch
)
batch
.
sampling_info
.
penalizer_orchestrator
.
cumulate_output_tokens
(
next_token_ids
)
# Move logprobs to cpu
if
logits_output
.
next_token_logprobs
is
not
None
:
logits_output
.
next_token_logprobs
=
(
logits_output
.
next_token_logprobs
[
torch
.
arange
(
len
(
next_token_ids
),
device
=
next_token_ids
.
device
),
next_token_ids
,
].
tolist
()
)
logits_output
.
input_token_logprobs
=
(
logits_output
.
input_token_logprobs
.
tolist
()
)
logits_output
.
normalized_prompt_logprobs
=
(
logits_output
.
normalized_prompt_logprobs
.
tolist
()
)
next_token_ids
=
next_token_ids
.
tolist
()
else
:
if
self
.
tokenizer
is
None
:
next_token_ids
=
[]
for
req
in
batch
.
reqs
:
next_token_ids
.
append
(
next
(
iter
(
req
.
sampling_params
.
stop_token_ids
))
)
else
:
next_token_ids
=
[
self
.
tokenizer
.
eos_token_id
]
*
len
(
batch
.
reqs
)
# Check finish conditions
logprob_pt
=
0
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
is
not
self
.
current_inflight_req
:
# Inflight reqs' prefill is not finished
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
check_finished
()
if
req
.
regex_fsm
is
not
None
:
req
.
regex_fsm_state
=
req
.
regex_fsm
.
get_next_state
(
req
.
regex_fsm_state
,
next_token_ids
[
i
]
)
if
req
.
finished
():
self
.
tree_cache
.
cache_finished_req
(
req
)
elif
req
not
in
decoding_reqs
:
# To reduce overhead, only cache prefill reqs
self
.
tree_cache
.
cache_unfinished_req
(
req
)
if
req
is
self
.
current_inflight_req
:
# Inflight request would get a new req idx
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
if
req
.
return_logprob
:
logprob_pt
+=
self
.
add_logprob_return_values
(
i
,
req
,
logprob_pt
,
next_token_ids
,
logits_output
)
else
:
assert
batch
.
extend_num_tokens
!=
0
logits_output
=
self
.
model_runner
.
forward
(
batch
)
embeddings
=
logits_output
.
embeddings
.
tolist
()
# Check finish conditions
for
i
,
req
in
enumerate
(
batch
.
reqs
):
req
.
embedding
=
embeddings
[
i
]
if
req
is
not
self
.
current_inflight_req
:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
req
.
output_ids
.
append
(
0
)
req
.
check_finished
()
if
req
.
finished
():
self
.
tree_cache
.
cache_finished_req
(
req
)
else
:
self
.
tree_cache
.
cache_unfinished_req
(
req
)
if
req
is
self
.
current_inflight_req
:
# Inflight request would get a new req idx
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
handle_finished_requests
(
batch
)
def
add_logprob_return_values
(
self
,
i
:
int
,
req
:
Req
,
pt
:
int
,
next_token_ids
:
List
[
int
],
output
:
LogitsProcessorOutput
,
):
"""Attach logprobs to the return values."""
req
.
output_token_logprobs
.
append
(
(
output
.
next_token_logprobs
[
i
],
next_token_ids
[
i
])
)
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
num_input_logprobs
=
req
.
extend_input_len
-
req
.
extend_logprob_start_len
if
req
.
normalized_prompt_logprob
is
None
:
req
.
normalized_prompt_logprob
=
output
.
normalized_prompt_logprobs
[
i
]
if
req
.
input_token_logprobs
is
None
:
input_token_logprobs
=
output
.
input_token_logprobs
[
pt
:
pt
+
num_input_logprobs
-
1
-
req
.
last_update_decode_tokens
]
input_token_ids
=
req
.
fill_ids
[
len
(
req
.
fill_ids
)
-
num_input_logprobs
+
1
:
len
(
req
.
fill_ids
)
-
req
.
last_update_decode_tokens
]
req
.
input_token_logprobs
=
list
(
zip
(
input_token_logprobs
,
input_token_ids
))
if
(
req
.
logprob_start_len
==
0
):
# The first token does not have logprob, pad it.
req
.
input_token_logprobs
=
[
(
None
,
req
.
fill_ids
[
0
])
]
+
req
.
input_token_logprobs
if
req
.
last_update_decode_tokens
!=
0
:
# Some decode tokens are re-computed in an extend batch
req
.
output_token_logprobs
.
extend
(
list
(
zip
(
output
.
input_token_logprobs
[
pt
+
num_input_logprobs
-
1
-
req
.
last_update_decode_tokens
:
pt
+
num_input_logprobs
-
1
],
req
.
fill_ids
[
len
(
req
.
fill_ids
)
-
req
.
last_update_decode_tokens
:
len
(
req
.
fill_ids
)
],
)
)
)
if
req
.
top_logprobs_num
>
0
:
if
req
.
input_top_logprobs
is
None
:
req
.
input_top_logprobs
=
output
.
input_top_logprobs
[
i
]
if
req
.
logprob_start_len
==
0
:
req
.
input_top_logprobs
=
[
None
]
+
req
.
input_top_logprobs
if
req
.
last_update_decode_tokens
!=
0
:
req
.
output_top_logprobs
.
extend
(
output
.
input_top_logprobs
[
i
][
-
req
.
last_update_decode_tokens
:]
)
req
.
output_top_logprobs
.
append
(
output
.
output_top_logprobs
[
i
])
return
num_input_logprobs
def
forward_decode_batch
(
self
,
batch
:
ScheduleBatch
):
# Check if decode out of memory
if
not
batch
.
check_decode_mem
():
old_ratio
=
self
.
new_token_ratio
retracted_reqs
,
new_token_ratio
=
batch
.
retract_decode
()
self
.
new_token_ratio
=
new_token_ratio
logger
.
info
(
"Decode out of memory happened. "
f
"#retracted_reqs:
{
len
(
retracted_reqs
)
}
, "
f
"#new_token_ratio:
{
old_ratio
:.
4
f
}
->
{
self
.
new_token_ratio
:.
4
f
}
"
)
self
.
waiting_queue
.
extend
(
retracted_reqs
)
else
:
self
.
new_token_ratio
=
max
(
self
.
new_token_ratio
-
self
.
new_token_ratio_decay
,
self
.
min_new_token_ratio
,
)
if
not
self
.
disable_regex_jump_forward
:
# Check for jump-forward
jump_forward_reqs
=
batch
.
check_for_jump_forward
(
self
.
model_runner
)
self
.
waiting_queue
.
extend
(
jump_forward_reqs
)
if
batch
.
is_empty
():
return
# Update batch tensors
self
.
decode_forward_ct
=
(
self
.
decode_forward_ct
+
1
)
%
(
1
<<
30
)
batch
.
prepare_for_decode
()
        # Forward and sample the next tokens
        logits_output, next_token_ids = self.forward_batch_generation(batch)

        # Move logprobs to cpu
        if logits_output.next_token_logprobs is not None:
            next_token_logprobs = logits_output.next_token_logprobs[
                torch.arange(len(next_token_ids), device=next_token_ids.device),
                next_token_ids,
            ].tolist()

        next_token_ids = next_token_ids.tolist()

        # Check finish condition
        has_finished = False
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.regex_fsm is not None:
                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, next_token_id
                )

            if req.finished():
                self.tree_cache.cache_finished_req(req)
                has_finished = True

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        if not has_finished:
            self.do_not_get_new_batch = True

        self.handle_finished_requests(batch)

    def forward_batch_generation(self, batch):
        logits_output = self.model_runner.forward(batch)
        next_token_ids = self.model_runner.sample(logits_output, batch)
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(next_token_ids)
        return logits_output, next_token_ids
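The "move logprobs to cpu" step above uses advanced indexing to keep only one logprob per row. A sketch with invented shapes (4 requests, vocab size 32000), not taken from the commit:

import torch

logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)
next_token_ids = torch.tensor([5, 17, 9, 3])

# Row i contributes only logprobs[i, next_token_ids[i]], i.e. the logprob of
# the token that was actually sampled for request i.
chosen = logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
print(len(chosen))  # 4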
    def handle_finished_requests(self, batch: ScheduleBatch):
        output_rids = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.model_runner.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
        else:  # for embedding model
            output_embeddings = []
        unfinished_indices = []

        for i, req in enumerate(batch.reqs):
            if not req.finished() and req is not self.current_inflight_req:
                unfinished_indices.append(i)

            if req.finished() or (
                req.stream
                and (
                    self.decode_forward_ct % self.stream_interval == 0
                    or len(req.output_ids) == 1
                )
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.model_runner.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # for embedding model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.model_runner.is_generation:
                self.out_pyobjs.append(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                    )
                )
            else:  # for embedding model
                self.out_pyobjs.append(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

        # Remove finished reqs: update batch tensors
        batch.filter_batch(unfinished_indices)
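The streaming condition above gates how often partial output is pushed to the detokenizer. A sketch of that cadence with an invented stream_interval, not from the commit:

# A streaming request emits an update on its first decoded token and then
# whenever the global decode step counter is a multiple of stream_interval.
stream_interval = 8
emitted = []
for decode_forward_ct in range(1, 25):
    output_ids_len = decode_forward_ct  # assume one token decoded per step
    if decode_forward_ct % stream_interval == 0 or output_ids_len == 1:
        emitted.append(decode_forward_ct)
print(emitted)  # [1, 8, 16, 24]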
    def flush_cache(self):
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success
    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break
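Note the asymmetry above: a queued request can simply be dropped, while a running request is only marked as aborted so the scheduler can release its KV-cache slots on the next step rather than mid-forward. A hedged sketch with plain dicts as stand-ins for Req objects:

waiting_queue = [{"rid": "a"}, {"rid": "b"}]
running_reqs = [{"rid": "c", "finished_reason": None}]

def abort(rid):
    # Phase 1: remove from the waiting queue if not yet scheduled.
    for i, req in enumerate(waiting_queue):
        if req["rid"] == rid:
            del waiting_queue[i]
            return
    # Phase 2: a running request is marked finished, not deleted in place.
    for req in running_reqs:
        if req["rid"] == rid:
            req["finished_reason"] = "ABORT"
            return

abort("a")
abort("c")
print(len(waiting_queue), running_reqs[0]["finished_reason"])  # 1 ABORT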
    def forward_batch_embedding(self, batch):
        logits_output = self.model_runner.forward(batch)
        embeddings = logits_output.embeddings.tolist()
        return embeddings

    def update_weights(self, recv_req: UpdateWeightReqInput):
        success, message = self.model_runner.update_weights(
            recv_req.model_path, recv_req.load_format
        )
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message
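The ordering here matters: the cache is flushed only after a successful load, because KV-cache entries computed with the old weights must not survive a weight swap. A sketch of that contract with hypothetical stubs (StubRunner and the lambda are invented):

class StubRunner:
    def update_weights(self, model_path, load_format):
        return True, f"Loaded {model_path} with format {load_format}."

def update_weights_then_flush(runner, model_path, load_format, flush_cache):
    success, message = runner.update_weights(model_path, load_format)
    if success:
        # Stale KV entries would silently mix outputs from two different models.
        assert flush_cache(), "Cache flush failed after updating weights"
    return success, message

print(update_weights_then_flush(StubRunner(), "my-model", "auto", lambda: True))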
python/sglang/srt/mem_cache/memory_pool.py
View file @ f86c1e61
@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )
 
     def alloc(self, need_size: int) -> List[int]:
...
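The hunk above only threads a device argument through; the pool itself stays a free-list plus a (size, max_context_len) int32 tensor of token locations. A minimal sketch of that idea, simplified to one slot per request and not the library's actual class:

import torch

class TinyReqToTokenPool:
    def __init__(self, size: int, max_context_len: int, device: str):
        self.free_slots = list(range(size))
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )

    def alloc(self) -> int:
        return self.free_slots.pop()      # claim a row for one request

    def free(self, slot: int) -> None:
        self.free_slots.append(slot)      # return the row to the pool

pool = TinyReqToTokenPool(size=4, max_context_len=16, device="cpu")
slot = pool.alloc()
pool.req_to_token[slot, :3] = torch.tensor([100, 101, 102], dtype=torch.int32)
pool.free(slot)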
python/sglang/srt/model_executor/model_runner.py
View file @ f86c1e61
@@ -87,6 +87,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )
 
         # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -94,6 +95,13 @@ class ModelRunner:
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -104,14 +112,6 @@ class ModelRunner:
             }
         )
 
-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
@@ -400,8 +400,7 @@ class ModelRunner:
         )
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
...
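Hedged arithmetic for the multimodal adjustment in the diff above; the 0.88 starting value is invented, while the 0.95 factor comes from the diff. Scaling the static memory fraction leaves headroom for the vision encoder:

mem_fraction_static = 0.88
chunked_prefill_size = 8192

is_multimodal_model = True
if is_multimodal_model:
    chunked_prefill_size = None   # disable chunked prefill
    mem_fraction_static *= 0.95
print(chunked_prefill_size, round(mem_fraction_static, 3))  # None 0.836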