Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
769bf11c
Unverified
Commit
769bf11c
authored
Oct 19, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 19, 2024
Browse files
Fix the race condition in overlap mode (#1712)
parent
3db43d1b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
21 additions
and
38 deletions
+21
-38
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+6
-9
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+2
-7
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+7
-8
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+2
-4
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+2
-10
python/sglang/srt/server.py
python/sglang/srt/server.py
+2
-0
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
769bf11c
...
@@ -405,9 +405,9 @@ class ScheduleBatch:
...
@@ -405,9 +405,9 @@ class ScheduleBatch:
# Request, memory pool, and cache
# Request, memory pool, and cache
reqs
:
List
[
Req
]
reqs
:
List
[
Req
]
req_to_token_pool
:
ReqToTokenPool
req_to_token_pool
:
ReqToTokenPool
=
None
token_to_kv_pool
:
BaseTokenToKVPool
token_to_kv_pool
:
BaseTokenToKVPool
=
None
tree_cache
:
BasePrefixCache
tree_cache
:
BasePrefixCache
=
None
forward_mode
:
ForwardMode
=
None
forward_mode
:
ForwardMode
=
None
sampling_info
:
SamplingBatchInfo
=
None
sampling_info
:
SamplingBatchInfo
=
None
...
@@ -874,12 +874,9 @@ class ScheduleBatch:
...
@@ -874,12 +874,9 @@ class ScheduleBatch:
def
copy
(
self
):
def
copy
(
self
):
return
ScheduleBatch
(
return
ScheduleBatch
(
reqs
=
self
.
reqs
,
reqs
=
self
.
reqs
,
req_to_token_pool
=
self
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
token_to_kv_pool
,
tree_cache
=
self
.
tree_cache
,
forward_mode
=
self
.
forward_mode
,
forward_mode
=
self
.
forward_mode
,
out
put_ids
=
self
.
output_ids
,
out
_cache_loc
=
self
.
out_cache_loc
,
sampling_info
=
self
.
sampling_info
,
return_logprob
=
self
.
return_logprob
,
decoding_reqs
=
self
.
decoding_reqs
,
decoding_reqs
=
self
.
decoding_reqs
,
)
)
...
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
...
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
forward_mode
=
self
.
forward_mode
,
forward_mode
=
self
.
forward_mode
,
input_ids
=
self
.
input_ids
.
clone
(),
input_ids
=
self
.
input_ids
.
clone
(),
req_pool_indices
=
self
.
req_pool_indices
,
req_pool_indices
=
self
.
req_pool_indices
,
seq_lens
=
self
.
seq_lens
,
seq_lens
=
self
.
seq_lens
.
clone
()
,
out_cache_loc
=
self
.
out_cache_loc
,
out_cache_loc
=
self
.
out_cache_loc
,
return_logprob
=
self
.
return_logprob
,
return_logprob
=
self
.
return_logprob
,
top_logprobs_nums
=
self
.
top_logprobs_nums
,
top_logprobs_nums
=
self
.
top_logprobs_nums
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
769bf11c
...
@@ -261,12 +261,7 @@ class Scheduler:
...
@@ -261,12 +261,7 @@ class Scheduler:
self
.
resolve_next_token_ids
=
(
self
.
resolve_next_token_ids
=
(
lambda
bid
,
x
:
self
.
tp_worker
.
resolve_future_token_ids
(
bid
)
lambda
bid
,
x
:
self
.
tp_worker
.
resolve_future_token_ids
(
bid
)
)
)
self
.
cache_finished_req
=
self
.
tree_cache
.
cache_finished_req
def
cache_finished_req
(
req
):
free_delta
=
int
(
self
.
running_batch
and
req
in
self
.
cur_batch
.
reqs
)
self
.
tree_cache
.
cache_finished_req
(
req
,
free_delta
=
free_delta
)
self
.
cache_finished_req
=
cache_finished_req
else
:
else
:
self
.
forward_batch_generation
=
self
.
tp_worker
.
forward_batch_generation
self
.
forward_batch_generation
=
self
.
tp_worker
.
forward_batch_generation
self
.
resolve_next_token_ids
=
lambda
bid
,
x
:
x
.
tolist
()
self
.
resolve_next_token_ids
=
lambda
bid
,
x
:
x
.
tolist
()
...
@@ -798,7 +793,6 @@ class Scheduler:
...
@@ -798,7 +793,6 @@ class Scheduler:
i
,
req
,
logprob_pt
,
next_token_ids
,
logits_output
i
,
req
,
logprob_pt
,
next_token_ids
,
logits_output
)
)
else
:
# embedding or reward model
else
:
# embedding or reward model
assert
batch
.
extend_num_tokens
!=
0
embeddings
,
bid
=
result
embeddings
,
bid
=
result
embeddings
=
embeddings
.
tolist
()
embeddings
=
embeddings
.
tolist
()
...
@@ -838,6 +832,7 @@ class Scheduler:
...
@@ -838,6 +832,7 @@ class Scheduler:
# Check finish condition
# Check finish condition
for
i
,
(
req
,
next_token_id
)
in
enumerate
(
zip
(
batch
.
reqs
,
next_token_ids
)):
for
i
,
(
req
,
next_token_id
)
in
enumerate
(
zip
(
batch
.
reqs
,
next_token_ids
)):
if
self
.
server_args
.
enable_overlap_schedule
and
req
.
finished
():
if
self
.
server_args
.
enable_overlap_schedule
and
req
.
finished
():
self
.
token_to_kv_pool
.
free
(
batch
.
out_cache_loc
[
i
:
i
+
1
])
continue
continue
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
completion_tokens_wo_jump_forward
+=
1
...
...
python/sglang/srt/managers/tp_worker.py
View file @
769bf11c
...
@@ -149,14 +149,12 @@ class TpModelWorker:
...
@@ -149,14 +149,12 @@ class TpModelWorker:
)
)
# Resolve future tokens in the input
# Resolve future tokens in the input
# logger.info(f"raw input {model_worker_batch.input_ids=}")
tic2
=
time
.
time
()
tic2
=
time
.
time
()
resolved_input_ids
=
model_worker_batch
.
input_ids
resolved_input_ids
=
model_worker_batch
.
input_ids
future_mask
=
resolved_input_ids
<
0
future_mask
=
resolved_input_ids
<
0
resolved_input_ids
[
future_mask
]
=
self
.
future_token_ids_map
[
resolved_input_ids
[
future_mask
]
=
self
.
future_token_ids_map
[
-
resolved_input_ids
[
future_mask
]
-
resolved_input_ids
[
future_mask
]
]
]
# logger.info(f"resolved input {model_worker_batch.input_ids=}")
# Run forward
# Run forward
logits_output
,
next_token_ids
=
self
.
forward_batch_generation
(
logits_output
,
next_token_ids
=
self
.
forward_batch_generation
(
...
@@ -215,12 +213,13 @@ class TpModelWorker:
...
@@ -215,12 +213,13 @@ class TpModelWorker:
self
.
future_logits_output_ct
+=
1
self
.
future_logits_output_ct
+=
1
bs
=
len
(
model_worker_batch
.
seq_lens
)
bs
=
len
(
model_worker_batch
.
seq_lens
)
future_next_token_ids
=
-
torch
.
arange
(
with
torch
.
cuda
.
stream
(
self
.
forward_stream
):
self
.
future_token_ids_ct
+
1
,
future_next_token_ids
=
-
torch
.
arange
(
self
.
future_token_ids_ct
+
1
+
bs
,
self
.
future_token_ids_ct
+
1
,
dtype
=
torch
.
int32
,
self
.
future_token_ids_ct
+
1
+
bs
,
device
=
self
.
device
,
dtype
=
torch
.
int32
,
)
device
=
self
.
device
,
)
self
.
future_token_ids_ct
=
(
self
.
future_token_ids_ct
=
(
self
.
future_token_ids_ct
+
bs
self
.
future_token_ids_ct
+
bs
)
%
self
.
future_token_ids_limit
)
%
self
.
future_token_ids_limit
...
...
python/sglang/srt/mem_cache/chunk_cache.py
View file @
769bf11c
...
@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
...
@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
max_prefix_len
=
len
(
key
)
max_prefix_len
=
len
(
key
)
return
entry
.
value
[:
max_prefix_len
],
entry
return
entry
.
value
[:
max_prefix_len
],
entry
def
cache_finished_req
(
def
cache_finished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
,
free_delta
:
int
=
0
):
if
token_ids
is
None
:
if
token_ids
is
None
:
token_id_len
=
len
(
req
.
origin_input_ids
)
+
len
(
req
.
output_ids
)
-
1
token_id_len
=
len
(
req
.
origin_input_ids
)
+
len
(
req
.
output_ids
)
-
1
else
:
else
:
token_id_len
=
len
(
token_ids
)
token_id_len
=
len
(
token_ids
)
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
:
token_id_len
+
free_delta
req
.
req_pool_idx
,
:
token_id_len
]
]
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
token_to_kv_pool
.
free
(
kv_indices
)
self
.
token_to_kv_pool
.
free
(
kv_indices
)
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
769bf11c
...
@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
...
@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
value
=
[
x
for
x
in
key
]
value
=
[
x
for
x
in
key
]
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
)
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
)
def
cache_finished_req
(
def
cache_finished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
,
free_delta
:
int
=
0
):
"""Cache request when it finishes."""
"""Cache request when it finishes."""
if
self
.
disable
:
if
self
.
disable
:
if
token_ids
is
None
:
if
token_ids
is
None
:
...
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
...
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
token_ids_len
=
len
(
token_ids
)
token_ids_len
=
len
(
token_ids
)
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
:
token_ids_len
+
free_delta
req
.
req_pool_idx
,
:
token_ids_len
]
]
self
.
token_to_kv_pool
.
free
(
kv_indices
)
self
.
token_to_kv_pool
.
free
(
kv_indices
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
...
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
...
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool
# Radix Cache takes one ref in memory pool
new_prefix_len
=
self
.
insert
(
token_ids
,
kv_indices
.
clone
())
new_prefix_len
=
self
.
insert
(
token_ids
,
kv_indices
.
clone
())
self
.
token_to_kv_pool
.
free
(
kv_indices
[
len
(
req
.
prefix_indices
)
:
new_prefix_len
])
self
.
token_to_kv_pool
.
free
(
kv_indices
[
len
(
req
.
prefix_indices
)
:
new_prefix_len
])
if
free_delta
:
self
.
token_to_kv_pool
.
free
(
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
len
(
token_ids
)
:
len
(
token_ids
)
+
1
]
)
# Remove req slot release the cache lock
# Remove req slot release the cache lock
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
...
...
python/sglang/srt/server.py
View file @
769bf11c
...
@@ -542,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
...
@@ -542,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
kill_child_process
(
pid
,
including_parent
=
False
)
kill_child_process
(
pid
,
including_parent
=
False
)
return
return
# logger.info(f"{res.json()=}")
logger
.
info
(
"The server is fired up and ready to roll!"
)
logger
.
info
(
"The server is fired up and ready to roll!"
)
if
pipe_finish_writer
is
not
None
:
if
pipe_finish_writer
is
not
None
:
pipe_finish_writer
.
send
(
"ready"
)
pipe_finish_writer
.
send
(
"ready"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment