sglang · Commit 59cbf476 (unverified)

Unify the memory pool api and tp worker API (#1724)

Authored Oct 19, 2024 by Lianmin Zheng; committed by GitHub on Oct 19, 2024.
Parent: 95946271
Showing 8 changed files with 87 additions and 25 deletions (+87, -25):
python/sglang/srt/managers/schedule_batch.py (+26, -10)
python/sglang/srt/managers/scheduler.py (+24, -8)
python/sglang/srt/managers/tp_worker.py (+6, -2)
python/sglang/srt/mem_cache/memory_pool.py (+7, -1)
python/sglang/srt/mem_cache/radix_cache.py (+4, -3)
python/sglang/srt/model_executor/forward_batch_info.py (+2, -0)
python/sglang/srt/model_executor/model_runner.py (+7, -0)
python/sglang/srt/sampling/sampling_batch_info.py (+11, -1)
python/sglang/srt/managers/schedule_batch.py
```diff
@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
```
```diff
@@ -522,12 +524,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
-                    req.prefix_indices
-                )
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
+                )

-            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
-            )
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
+            )

             # Compute the relative logprob_start_len in an extend batch
```
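This is the recurring pattern of the commit: direct fancy-indexed assignments into `req_to_token` are funneled through the new `ReqToTokenPool.write`, with the original index expression packed into a `(row, slice)` tuple. A standalone sketch (plain PyTorch, not sglang code) of why the two forms are equivalent:

```python
import torch

# `t[i, a:b] = v` is sugar for `t[(i, slice(a, b))] = v`, so passing the
# tuple through a method addresses exactly the same elements while giving
# pool subclasses a single hook for every write.
req_to_token = torch.zeros((8, 16), dtype=torch.int64)

def write(indices, values):
    req_to_token[indices] = values

values = torch.tensor([10, 11, 12, 13])
req_to_token[3, 0:4] = values        # old style: direct indexing
write((3, slice(0, 4)), values)      # new style: the same update via write()
assert torch.equal(req_to_token[3, :4], values)
```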
```diff
@@ -765,9 +767,8 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+        )
         self.seq_lens.add_(1)
```
```diff
@@ -848,7 +849,6 @@ class ScheduleBatch:
         extend_logprob_start_lens = self.extend_logprob_start_lens
         image_inputs = [r.image_inputs for r in self.reqs]
-        lora_paths = [req.lora_path for req in self.reqs]

         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
```
```diff
@@ -869,13 +869,14 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
             image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,
         )
```
```diff
@@ -911,6 +912,9 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
```
```diff
@@ -940,6 +944,7 @@ class ModelWorkerBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=self.extend_seq_lens,
```
```diff
@@ -950,3 +955,14 @@ class ModelWorkerBatch:
             sampling_info=self.sampling_info.copy(),
             mrope_positions_delta=self.mrope_positions_delta,
         )
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
```
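The new `ModelWorkerBatch.to` moves every tensor field with `non_blocking=True`. A general PyTorch point worth recalling here (not something this diff states): a host-to-device copy is only truly asynchronous when the source CPU tensor is in pinned memory, and the result must not be read before the copy stream synchronizes. A minimal illustration:

```python
import torch

# Minimal sketch of a non-blocking host-to-device move. Without a pinned
# (page-locked) source, PyTorch silently falls back to a synchronous copy,
# so the pin_memory() step is what makes the overlap real.
if torch.cuda.is_available():
    cpu_tensor = torch.arange(1024, dtype=torch.int32).pin_memory()
    gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)
    # ... other CPU work can proceed while the copy is in flight ...
    torch.cuda.synchronize()  # wait before depending on gpu_tensor
    assert gpu_tensor.device.type == "cuda"
```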
python/sglang/srt/managers/scheduler.py
```diff
@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
```
```diff
@@ -144,25 +145,27 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.server_args.enable_overlap_schedule:
+            TpWorkerClass = TpModelWorker
+        else:
+            TpWorkerClass = TpModelWorker
+        self.tp_worker = TpWorkerClass(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )

         # Init states for overlap schedule
         if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
         else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation

         # Get token and memory info from the model worker
         (
```
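Note that both branches of the new selection bind the same `TpWorkerClass = TpModelWorker`, so the `if/else` reads as scaffolding for a future overlap-capable worker class; the non-blocking behavior is instead chosen per method just below. The general shape of that dispatch pattern, as a hypothetical standalone sketch (all class and method names here are illustrative, not sglang's):

```python
# Hypothetical sketch of flag-driven dispatch: choose the worker class and
# bind the hot-path callables once at construction, so the event loop never
# re-checks the flag.
class BlockingWorker:
    def forward_batch_generation(self, batch):
        return f"ran {batch} synchronously"

class OverlapWorker(BlockingWorker):
    def forward_batch_generation_non_blocking(self, batch):
        return f"queued {batch} for overlapped execution"

def make_scheduler_hooks(enable_overlap_schedule: bool):
    worker_class = OverlapWorker if enable_overlap_schedule else BlockingWorker
    worker = worker_class()
    if enable_overlap_schedule:
        forward = worker.forward_batch_generation_non_blocking
    else:
        forward = worker.forward_batch_generation
    return worker, forward

worker, forward = make_scheduler_hooks(enable_overlap_schedule=True)
print(forward("batch-0"))  # -> queued batch-0 for overlapped execution
```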
```diff
@@ -172,9 +175,14 @@ class Scheduler:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-        ) = self.tp_worker.get_token_and_memory_info()
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

         # Print debug info
```
```diff
@@ -266,6 +274,7 @@ class Scheduler:
     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:
```
```diff
@@ -296,6 +305,7 @@ class Scheduler:
     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None
```
```diff
@@ -572,6 +582,7 @@ class Scheduler:
             else set([])
         )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths
```
```diff
@@ -673,6 +684,7 @@ class Scheduler:
         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch
```
```diff
@@ -712,6 +724,7 @@ class Scheduler:
         batch.prepare_for_decode()

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
```
```diff
@@ -933,6 +946,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
```
```diff
@@ -1030,6 +1044,7 @@ class Scheduler:
         )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
```
```diff
@@ -1070,6 +1085,7 @@ class Scheduler:
                 break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
```
python/sglang/srt/managers/tp_worker.py
```diff
@@ -27,7 +27,7 @@ import torch
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
```
```diff
@@ -111,7 +111,7 @@ class TpModelWorker:
         if server_args.enable_overlap_schedule:
             self.init_overlap_status()

-    def get_token_and_memory_info(self):
+    def get_worker_info(self):
         return (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
```
```diff
@@ -119,6 +119,10 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
+            global_server_args_dict,
+            self.model_runner.req_to_token_pool.size,
+            self.model_runner.req_to_token_pool.max_context_len,
+            self.model_runner.token_to_kv_pool.size,
         )

     def get_pad_input_ids_func(self):
```
python/sglang/srt/mem_cache/memory_pool.py
```diff
@@ -56,6 +56,12 @@ class ReqToTokenPool:
     def clear(self):
         self.free_slots = list(range(self.size))

+    def write(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def get_write_records(self):
+        return None
+

 class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
```
```diff
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
-        self.device = device

         self.free_slots = None
         self.is_not_in_free_group = True
```
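The base pool's `get_write_records` returning `None` hints at why the `write` indirection exists: a pool variant can log each write and hand the log to `ScheduleBatch.get_model_worker_batch`, matching the `(indices, values)` pairs that `ModelWorkerBatch.req_to_token_pool_records` carries and `ModelWorkerBatch.to` moves to the device. A hypothetical recording variant (not part of this commit) could look like:

```python
import torch

# Hedged sketch: a recording pool that funnels all updates through write()
# and exposes them as (indices, values) records. Class name and buffering
# policy are illustrative, not taken from sglang.
class RecordingReqToTokenPool:
    def __init__(self, size: int, max_context_len: int):
        self.size = size
        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int64
        )
        self.write_records = []

    def write(self, indices, values):
        self.req_to_token[indices] = values
        self.write_records.append((indices, values))

    def get_write_records(self):
        records = self.write_records
        self.write_records = []  # hand off and reset the log
        return records

pool = RecordingReqToTokenPool(size=4, max_context_len=8)
pool.write((0, slice(0, 3)), torch.tensor([5, 6, 7]))
print(pool.get_write_records())  # [((0, slice(0, 3)), tensor([5, 6, 7]))]
```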
python/sglang/srt/mem_cache/radix_cache.py
```diff
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
         assert len(new_indices) == len(token_ids)
-        self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-        ] = new_indices[len(req.prefix_indices) :]
+        self.req_to_token_pool.write(
+            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+            new_indices[len(req.prefix_indices) :],
+        )

         self.dec_lock_ref(req.last_node)
         self.inc_lock_ref(new_last_node)
```
python/sglang/srt/model_executor/forward_batch_info.py
```diff
@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
```
python/sglang/srt/model_executor/model_runner.py
```diff
@@ -131,6 +131,13 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True

+        if self.server_args.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results."
+            )
+
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
```
python/sglang/srt/sampling/sampling_batch_info.py
```diff
@@ -78,7 +78,7 @@ class SamplingBatchInfo:
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
-            device=batch.input_ids.device,
+            device=device,
         )

         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
```
```diff
@@ -224,3 +224,13 @@ class SamplingBatchInfo:
             vocab_size=self.vocab_size,
             device=self.device,
         )
+
+    def to(self, device: str):
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            value = getattr(self, item)
+            setattr(self, item, value.to(device, non_blocking=True))
```
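The new `SamplingBatchInfo.to` iterates a name list with `getattr`/`setattr` instead of writing four assignments, keeping the set of movable tensor fields in one place. Reduced to a self-contained sketch (the class here is an illustrative stand-in, not sglang's):

```python
import torch

# Standalone sketch of the attribute-list device move: adding a new tensor
# field later is a one-line change to the list, not a new assignment.
class SamplingTensors:
    def __init__(self, n: int):
        self.temperatures = torch.ones(n)
        self.top_ps = torch.full((n,), 0.9)
        self.top_ks = torch.full((n,), 40, dtype=torch.int64)
        self.min_ps = torch.zeros(n)

    def to(self, device: str):
        for item in ["temperatures", "top_ps", "top_ks", "min_ps"]:
            value = getattr(self, item)
            setattr(self, item, value.to(device, non_blocking=True))

info = SamplingTensors(4)
info.to("cpu")  # swap in "cuda" on a GPU machine
print(info.top_ks.device)
```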