sglang / Commits / 7fa54a1a

Make `req_pool_indices` on CPU (#960)

Unverified commit 7fa54a1a, authored Aug 07, 2024 by Liangsheng Yin, committed by GitHub on Aug 07, 2024.
Parent: 05abd126
Changes: showing 4 changed files with 105 additions and 109 deletions (+105 −109).
    python/sglang/global_config.py                +0   −1
    python/sglang/srt/managers/schedule_batch.py  +93  −89
    python/sglang/srt/managers/tp_worker.py       +3   −6
    python/sglang/srt/mem_cache/memory_pool.py    +9   −13
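Taken together, the four files move request-slot bookkeeping off the GPU: `ReqToTokenPool.alloc` now hands out plain Python ints from a free list, each `Req` remembers its slot in a new `req_pool_idx` field, and call sites that previously synced `req_pool_indices` back to the host read that field directly. The snippet below is a minimal illustrative sketch of the producer side only; `DummyReq` and the free-standing `alloc_req_slots` are simplified stand-ins, not the sglang implementation.

    # Illustrative sketch, not sglang code: request slots are handed out as plain
    # Python ints and stored on the request itself, so no CUDA tensor is involved.

    class DummyReq:
        def __init__(self, rid: str):
            self.rid = rid
            self.req_pool_idx = None        # field introduced by this commit

    def alloc_req_slots(free_slots: list, num_reqs: int) -> list:
        """Hand out request slots from a CPU-side free list (no host/device sync)."""
        if num_reqs > len(free_slots):
            raise RuntimeError(
                "Out of memory. Please set a smaller number for `--max-running-requests`."
            )
        picked, free_slots[:] = free_slots[:num_reqs], free_slots[num_reqs:]
        return picked

    free_slots = list(range(8))
    reqs = [DummyReq("a"), DummyReq("b"), DummyReq("c")]
    for req, slot in zip(reqs, alloc_req_slots(free_slots, len(reqs))):
        req.req_pool_idx = slot             # each request keeps its own CPU-side index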
python/sglang/global_config.py

@@ -19,7 +19,6 @@ class GlobalConfig:
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
-        self.new_token_ratio_recovery = 0.05

         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
python/sglang/srt/managers/schedule_batch.py

@@ -100,6 +100,9 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

+        # Memory info
+        self.req_pool_idx = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids  |

@@ -321,6 +324,9 @@ class ScheduleBatch:
             return_logprob=return_logprob,
         )

+    def batch_size(self):
+        return len(self.reqs) if self.reqs is not None else 0
+
     def is_empty(self):
         return len(self.reqs) == 0

@@ -328,118 +334,127 @@ class ScheduleBatch:
         # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)

+    def alloc_req_slots(self, num_reqs):
+        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+        return req_pool_indices
+
+    def alloc_token_slots(self, num_tokens: int):
+        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+        if out_cache_loc is None:
+            if self.tree_cache is not None:
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+            if out_cache_loc is None:
+                logger.error("Prefill out of memory. Try to lower your batch size.")
+                if self.tree_cache is not None:
+                    self.tree_cache.pretty_print()
+                exit(1)
+
+        return out_cache_loc
+
+    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+        device = "cuda"
+        bs, reqs = self.batch_size(), self.reqs
+        self.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        self.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        self.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        self.frequency_penalties = torch.tensor(
+            [r.sampling_params.frequency_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+        self.presence_penalties = torch.tensor(
+            [r.sampling_params.presence_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+
+        # Handle logit bias but only allocate when needed
+        self.logit_bias = None
+        for i in range(bs):
+            if reqs[i].sampling_params.dtype == "int":
+                if self.logit_bias is None:
+                    self.logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
+                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
-        bs = len(self.reqs)
+        bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
         prefix_indices = [r.prefix_indices for r in reqs]

         # Handle prefix
-        flatten_input_ids = []
         extend_lens = []
         prefix_lens = []
         seq_lens = []

-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "Out of memory. "
-                "Please set a smaller number for `--max-running-requests`."
-            )
-
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
+        req_pool_indices_cpu = self.alloc_req_slots(bs)
+
+        for i, req in enumerate(reqs):
+            req.req_pool_idx = req_pool_indices_cpu[i]
             extend_lens.append(len(input_ids[i]))

             if len(prefix_indices[i]) == 0:
                 prefix_lens.append(0)
             else:
                 prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+                self.req_to_token_pool.req_to_token[req.req_pool_idx][
                     : len(prefix_indices[i])
                 ] = prefix_indices[i]

             seq_lens.append(prefix_lens[-1] + extend_lens[-1])

-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
         # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-        if out_cache_loc is None:
-            if self.tree_cache is not None:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-            if out_cache_loc is None:
-                logger.error("Prefill out of memory. Try to lower your batch size.")
-                if self.tree_cache is not None:
-                    self.tree_cache.pretty_print()
-                exit(1)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         pt = 0
-        for i in range(bs):
-            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+        for i, req in enumerate(reqs):
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][
                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]

-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
         # Set fields
-        self.input_ids = torch.tensor(
-            flatten_input_ids, dtype=torch.int32, device=device
-        )
+        with torch.device("cuda"):
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
         ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.logit_bias = logit_bias
+        self.batch_sampling_params(vocab_size, int_token_logit_bias)

     def check_decode_mem(self):
-        bs = len(self.reqs)
+        bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
             return True

@@ -464,7 +479,6 @@ class ScheduleBatch:
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps

@@ -482,20 +496,20 @@ class ScheduleBatch:
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][: seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][last_uncached_pos : seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
                 self.tree_cache.dec_lock_ref(req.last_node)

@@ -533,8 +547,6 @@ class ScheduleBatch:
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
-        req_pool_indices_cpu = None
-
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(

@@ -584,13 +596,11 @@ class ScheduleBatch:
                     req.vid += 1

                 # insert the old request into tree_cache
-                if req_pool_indices_cpu is None:
-                    req_pool_indices_cpu = self.req_pool_indices.tolist()
                 self.tree_cache.cache_req(
                     rid=req.rid,
                     token_ids=cur_all_ids,
                     last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                 )

                 # unlock the last node

@@ -626,14 +636,8 @@ class ScheduleBatch:
         self.prefix_lens = None

         # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-        if self.out_cache_loc is None:
-            logger.error("Decode out of memory. Try to lower your batch size.")
-            if self.tree_cache is not None:
-                self.tree_cache.pretty_print()
-            exit(1)
+        bs = self.batch_size()
+        self.out_cache_loc = self.alloc_token_slots(bs)

         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
...
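One detail of the rewritten `prepare_for_extend`: the batch tensors are now created under `with torch.device("cuda"):` instead of passing `device=` to every constructor. The self-contained sketch below only illustrates that PyTorch pattern (available as a context manager in recent PyTorch releases); the variable names mirror the diff but are otherwise arbitrary, and it falls back to CPU so it runs anywhere.

    import torch

    # Sketch: a torch.device used as a context manager sets the default device for
    # tensor constructors, so the per-call device= arguments can be dropped.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    req_pool_indices_cpu = [3, 7, 1]   # plain Python ints, as alloc_req_slots returns
    seq_lens_list = [5, 2, 9]

    with torch.device(device):
        input_ids = torch.tensor([11, 12, 13], dtype=torch.int32)   # created on `device`
        req_pool_indices = torch.tensor(req_pool_indices_cpu)       # host list -> device tensor
        seq_lens = torch.tensor(seq_lens_list, dtype=torch.int32)
        position_ids_offsets = torch.zeros((len(seq_lens_list),), dtype=torch.int32)

    print(input_ids.device, req_pool_indices.device)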
python/sglang/srt/managers/tp_worker.py

@@ -200,7 +200,6 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
         try:

@@ -625,13 +624,12 @@ class ModelTpServer:
                     req.output_top_logprobs.append(output.output_top_logprobs[i])

     def cache_filled_batch(self, batch: ScheduleBatch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 rid=req.rid,
                 token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req_pool_indices_cpu[i],
+                req_pool_idx=req.req_pool_idx,
                 del_in_memory_pool=False,
                 old_last_node=req.last_node,
             )

@@ -639,7 +637,7 @@ class ModelTpServer:
             if req is self.current_inflight_req:
                 # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+                self.req_to_token_pool.free(req.req_pool_idx)

     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory

@@ -782,14 +780,13 @@ class ModelTpServer:
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
                     rid=req.rid,
                     token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                 )
                 self.tree_cache.dec_lock_ref(req.last_node)
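On the `tp_worker` side the change is purely a read-path simplification: instead of syncing `batch.req_pool_indices` to the host (`.cpu().numpy()` or `.tolist()`) and indexing it positionally, the loops read `req.req_pool_idx` from each request. A hypothetical before/after comparison of that pattern with toy objects (not the sglang classes):

    import torch

    class ToyReq:
        def __init__(self, idx):
            self.req_pool_idx = idx            # CPU-side int, set at schedule time

    reqs = [ToyReq(i) for i in (4, 2, 9)]
    req_pool_indices = torch.tensor([4, 2, 9])  # what the batch used to carry as a tensor

    # Old pattern: sync the whole index tensor to the host, then index by position.
    indices_cpu = req_pool_indices.cpu().numpy()
    old_view = [int(indices_cpu[i]) for i in range(len(reqs))]

    # New pattern: read the int stored on the request itself; no sync, no positional coupling.
    new_view = [req.req_pool_idx for req in reqs]

    assert old_view == new_view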
python/sglang/srt/mem_cache/memory_pool.py

@@ -16,6 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
+from typing import List

 import torch

@@ -27,34 +28,29 @@ class ReqToTokenPool:
     def __init__(self, size: int, max_context_len: int):
         self.size = size
-        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
-        self.can_use_mem_size = size

-    def alloc(self, need_size: int):
-        if need_size > self.can_use_mem_size:
+    def alloc(self, need_size: int) -> List[int]:
+        if need_size > len(self.free_slots):
             return None

-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
-        )
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= need_size
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]

         return select_index

     def free(self, free_index):
-        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
-            self.can_use_mem_size += 1
+            self.free_slots.append(free_index)
         else:
-            self.can_use_mem_size += free_index.shape[0]
+            self.free_slots.extend(free_index)

     def clear(self):
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = len(self.mem_state)
+        self.free_slots = list(range(self.size))


 class BaseTokenToKVPool:
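The rewritten `ReqToTokenPool` replaces the boolean `mem_state` mask on the GPU (and the `torch.nonzero` scan in `alloc`) with a plain Python free list, so slot allocation and release never launch CUDA work. Below is a standalone sketch of that free-list strategy; it is a simplified stand-in for illustration, not the sglang class itself.

    from typing import List, Optional, Union

    class FreeListPool:
        """Allocate integer slots from a CPU-side free list (no device tensors involved)."""

        def __init__(self, size: int):
            self.size = size
            self.free_slots: List[int] = list(range(size))

        def alloc(self, need_size: int) -> Optional[List[int]]:
            if need_size > len(self.free_slots):
                return None                                  # not enough free slots
            selected = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return selected                                  # plain list of ints

        def free(self, free_index: Union[int, List[int]]) -> None:
            if isinstance(free_index, int):
                self.free_slots.append(free_index)           # return a single slot
            else:
                self.free_slots.extend(free_index)           # or a batch of slots

        def clear(self) -> None:
            self.free_slots = list(range(self.size))

    pool = FreeListPool(4)
    slots = pool.alloc(3)          # [0, 1, 2]
    pool.free(slots[0])
    pool.free(slots[1:])
    assert pool.alloc(10) is None  # allocation fails gracefully when the pool is exhausted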