change / sglang · Commits

Commit 769bf11c (unverified)
Authored Oct 19, 2024 by Lianmin Zheng; committed by GitHub on Oct 19, 2024

Fix the race condition in overlap mode (#1712)
Parent: 3db43d1b
Showing 6 changed files with 21 additions and 38 deletions (+21, -38).
python/sglang/srt/managers/schedule_batch.py   +6  -9
python/sglang/srt/managers/scheduler.py        +2  -7
python/sglang/srt/managers/tp_worker.py        +7  -8
python/sglang/srt/mem_cache/chunk_cache.py     +2  -4
python/sglang/srt/mem_cache/radix_cache.py     +2 -10
python/sglang/srt/server.py                    +2  -0
python/sglang/srt/managers/schedule_batch.py (view file @ 769bf11c)
@@ -405,9 +405,9 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
@@ -874,12 +874,9 @@ class ScheduleBatch:
     def copy(self):
         return ScheduleBatch(
             reqs=self.reqs,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool=self.token_to_kv_pool,
-            tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
-            output_ids=self.output_ids,
             sampling_info=self.sampling_info,
+            out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
         )
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
             forward_mode=self.forward_mode,
             input_ids=self.input_ids.clone(),
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
+            seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
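The changes in this file work as a pair: `copy()` now carries only what result processing needs (`out_cache_loc` instead of the pools and `output_ids`), and the worker batch snapshots `seq_lens` with `.clone()` alongside the already-cloned `input_ids`. The clones matter because, under overlap scheduling, the scheduler thread keeps mutating these tensors for the next step while the worker thread is still reading the batch it was handed. A minimal sketch of the aliasing hazard the clones remove (toy values, not sglang code):

    import torch

    # Without a copy, the worker's "batch" aliases the scheduler's tensor.
    seq_lens = torch.tensor([5, 7, 9])
    worker_view = seq_lens          # what passing self.seq_lens amounts to
    seq_lens += 1                   # scheduler advances to the next decode step
    print(worker_view.tolist())    # [6, 8, 10] -- the worker sees torn state

    # With .clone(), the worker batch snapshots the values at hand-off time.
    seq_lens = torch.tensor([5, 7, 9])
    worker_copy = seq_lens.clone()  # what seq_lens=self.seq_lens.clone() does
    seq_lens += 1
    print(worker_copy.tolist())    # [5, 7, 9] -- stable snapshot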
python/sglang/srt/managers/scheduler.py (view file @ 769bf11c)
@@ -261,12 +261,7 @@ class Scheduler:
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
            )
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-            self.cache_finished_req = cache_finished_req
+            self.cache_finished_req = self.tree_cache.cache_finished_req
         else:
             self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
@@ -798,7 +793,6 @@ class Scheduler:
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -838,6 +832,7 @@
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue

             req.completion_tokens_wo_jump_forward += 1
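Together, these scheduler hunks drop the racy `free_delta` guess (whether `req` is still in `self.cur_batch.reqs` can change under the worker's feet) in favor of an exact rule: in overlap mode, a request that is already finished still had one speculative decode slot allocated for the in-flight step, and that slot, `batch.out_cache_loc[i : i + 1]`, is freed on the spot. A self-contained sketch with a toy free-list pool (illustrative class, not sglang's real `BaseTokenToKVPool`):

    import torch

    class ToyKVPool:
        """Free-list of KV-cache slot indices, standing in for the real pool."""

        def __init__(self, size: int):
            self.free_slots = list(range(size))

        def alloc(self, n: int) -> torch.Tensor:
            out = torch.tensor(self.free_slots[:n])
            del self.free_slots[:n]
            return out

        def free(self, indices: torch.Tensor) -> None:
            self.free_slots.extend(indices.tolist())

    pool = ToyKVPool(8)
    out_cache_loc = pool.alloc(3)       # one new KV slot per request in the batch
    finished = [False, True, False]     # req 1 finished while the step was in flight
    for i, done in enumerate(finished):
        if done:
            pool.free(out_cache_loc[i : i + 1])  # reclaim its speculative slot
    print(len(pool.free_slots))         # 6 -- nothing leaked for the finished request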
python/sglang/srt/managers/tp_worker.py (view file @ 769bf11c)
@@ -149,14 +149,12 @@ class TpModelWorker:
         )

         # Resolve future tokens in the input
         # logger.info(f"raw input {model_worker_batch.input_ids=}")
-        tic2 = time.time()
         resolved_input_ids = model_worker_batch.input_ids
         future_mask = resolved_input_ids < 0
         resolved_input_ids[future_mask] = self.future_token_ids_map[
             -resolved_input_ids[future_mask]
         ]
         # logger.info(f"resolved input {model_worker_batch.input_ids=}")

         # Run forward
         logits_output, next_token_ids = self.forward_batch_generation(
@@ -215,12 +213,13 @@ class TpModelWorker:
         self.future_logits_output_ct += 1

         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = -torch.arange(
-            self.future_token_ids_ct + 1,
-            self.future_token_ids_ct + 1 + bs,
-            dtype=torch.int32,
-            device=self.device,
-        )
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
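These two hunks show both halves of overlap mode's future-token scheme: placeholder ids are handed out as negative numbers before the real tokens exist, and the next step's input resolves them through `future_token_ids_map`; the fix wraps the placeholder allocation in `torch.cuda.stream(self.forward_stream)` so it is ordered with the forward pass instead of racing it on the default stream. A CPU-only sketch of the placeholder round trip (toy sizes; the real code uses int32 tensors on the GPU):

    import torch

    future_token_ids_map = torch.zeros(16, dtype=torch.long)  # slot 0 unused
    future_token_ids_ct = 0
    bs = 3

    # Hand out placeholders -1, -2, -3 before the real tokens exist.
    future_next_token_ids = -torch.arange(
        future_token_ids_ct + 1, future_token_ids_ct + 1 + bs
    )

    # Later, after sampling, the worker publishes the real ids into the map.
    sampled = torch.tensor([42, 7, 99])
    future_token_ids_map[1 : 1 + bs] = sampled

    # The next step's input still holds placeholders; resolve them in place,
    # mirroring resolved_input_ids[future_mask] = future_token_ids_map[...].
    input_ids = future_next_token_ids.clone()
    future_mask = input_ids < 0
    input_ids[future_mask] = future_token_ids_map[-input_ids[future_mask]]
    print(input_ids.tolist())  # [42, 7, 99]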
python/sglang/srt/mem_cache/chunk_cache.py (view file @ 769bf11c)
@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
             max_prefix_len = len(key)

         return entry.value[:max_prefix_len], entry

-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
         else:
             token_id_len = len(token_ids)
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : token_id_len + free_delta
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
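With `free_delta` gone, `cache_finished_req` frees exactly the slots the finished request's tokens occupy: row `req.req_pool_idx` of the `req_to_token` table maps token positions to KV-cache slot indices, and its first `token_id_len` entries go back to the pool. A toy-shaped sketch of that lookup (illustrative sizes):

    import torch

    # 4 request slots x 8 token positions; entries are KV-cache slot indices.
    req_to_token = torch.arange(32).reshape(4, 8)
    req_pool_idx, token_id_len = 2, 5

    kv_indices = req_to_token[req_pool_idx, :token_id_len]
    print(kv_indices.tolist())  # [16, 17, 18, 19, 20] -> returned to token_to_kv_pool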
python/sglang/srt/mem_cache/radix_cache.py (view file @ 769bf11c)
@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if self.disable:
             if token_ids is None:
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
                 token_ids_len = len(token_ids)

             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : token_ids_len + free_delta
+                req.req_pool_idx, :token_ids_len
             ]
             self.token_to_kv_pool.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
-        if free_delta:
-            self.token_to_kv_pool.free(
-                self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
-                ]
-            )

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
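In the radix path the same cleanup has one extra wrinkle: `self.insert(...)` returns `new_prefix_len`, the length of the request's token prefix that the tree already owns, so the request's duplicate KV slots in `[len(req.prefix_indices), new_prefix_len)` are released; the removed `if free_delta:` block was the speculative-slot free that now lives in the scheduler instead. A sketch of the duplicate-slot arithmetic (toy numbers, a hedged reading of the code above):

    import torch

    # Toy numbers: the request matched 3 cached prefix tokens at admission
    # (len(req.prefix_indices) == 3) and finished with 6 tokens; meanwhile the
    # tree's shared prefix grew to 5 tokens, so insert(...) returns 5.
    kv_indices = torch.arange(100, 106)  # the request's 6 KV slots
    matched_at_admission = 3
    new_prefix_len = 5

    # Slots 3..4 duplicate KV the tree already references; hand this request's
    # copies back, mirroring token_to_kv_pool.free(kv_indices[3:5]).
    duplicates = kv_indices[matched_at_admission:new_prefix_len]
    print(duplicates.tolist())  # [103, 104]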
python/sglang/srt/server.py (view file @ 769bf11c)
@@ -542,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return

+    # logger.info(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")