zhaoyu6/sglang · commit 769bf11c
"host/online_compile/include/tmp_dir.hpp" did not exist on "1264925422920f24b3bb4fa34f178e31a23c97b5"
Unverified commit, authored Oct 19, 2024 by Lianmin Zheng; committed by GitHub on Oct 19, 2024.
Fix the race condition in overlap mode (#1712)

Parent: 3db43d1b

Showing 6 changed files with 21 additions and 38 deletions (+21, -38):
python/sglang/srt/managers/schedule_batch.py  +6  -9
python/sglang/srt/managers/scheduler.py       +2  -7
python/sglang/srt/managers/tp_worker.py       +7  -8
python/sglang/srt/mem_cache/chunk_cache.py    +2  -4
python/sglang/srt/mem_cache/radix_cache.py    +2  -10
python/sglang/srt/server.py                   +2  -0
python/sglang/srt/managers/schedule_batch.py

...
@@ -405,9 +405,9 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
...
@@ -874,12 +874,9 @@ class ScheduleBatch:
     def copy(self):
         return ScheduleBatch(
             reqs=self.reqs,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool=self.token_to_kv_pool,
-            tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
             output_ids=self.output_ids,
             out_cache_loc=self.out_cache_loc,
             sampling_info=self.sampling_info,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
         )
...
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
             forward_mode=self.forward_mode,
             input_ids=self.input_ids.clone(),
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
+            seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
...
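
These hunks work together: the pool and cache fields gain None defaults so that copy() can return a lightweight ScheduleBatch without the shared pool and cache references, and the batch handed to the worker carries cloned tensors such as seq_lens. A minimal sketch (toy tensors, not sglang code) of why the clone matters under overlap scheduling:

import torch

# Minimal sketch (toy tensors, not sglang code) of why the worker batch now
# gets seq_lens.clone(): if the worker thread keeps a reference to the
# scheduler's tensor, an in-place update for the next step is visible to the
# still-running worker.
scheduler_seq_lens = torch.tensor([5, 7, 9])

aliased = scheduler_seq_lens           # what passing self.seq_lens shared
snapshot = scheduler_seq_lens.clone()  # what self.seq_lens.clone() hands over

scheduler_seq_lens.add_(1)             # scheduler advances to the next step

print(aliased.tolist())   # [6, 8, 10] -- the alias sees the mutation
print(snapshot.tolist())  # [5, 7, 9]  -- the clone is a stable snapshot

The same aliasing argument applies to input_ids, which the worker batch also clones.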
python/sglang/srt/managers/scheduler.py

...
@@ -261,12 +261,7 @@ class Scheduler:
            self.resolve_next_token_ids = (
                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
            )
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-            self.cache_finished_req = cache_finished_req
+            self.cache_finished_req = self.tree_cache.cache_finished_req
         else:
             self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
...
@@ -798,7 +793,6 @@ class Scheduler:
                     i, req, logprob_pt, next_token_ids, logits_output
                 )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             embeddings, bid = result
             embeddings = embeddings.tolist()
...
@@ -838,6 +832,7 @@ class Scheduler:
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue

             req.completion_tokens_wo_jump_forward += 1
...
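
The first hunk deletes the free_delta wrapper around cache_finished_req; the last hunk is its replacement: when overlap scheduling is on and a request has already finished, the speculative KV slot reserved in out_cache_loc for its next token is returned to the pool right in the result loop. A toy sketch (hypothetical ToyKVPool, not sglang's allocator) of that bookkeeping:

from dataclasses import dataclass, field
from typing import List, Set

# Toy sketch (hypothetical ToyKVPool, not sglang's allocator): overlap mode
# pre-allocates one KV slot per request for the next token, so a request that
# finishes must hand its speculative slot back, or the pool leaks one slot
# per finished request.

@dataclass
class ToyKVPool:
    free_slots: List[int] = field(default_factory=lambda: list(range(8)))

    def alloc(self, n: int) -> List[int]:
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots: List[int]) -> None:
        self.free_slots.extend(slots)

pool = ToyKVPool()
out_cache_loc = pool.alloc(3)      # one speculative slot per request
finished: Set[int] = {1}           # request 1 finished this step

for i in range(3):
    if i in finished:
        pool.free(out_cache_loc[i : i + 1])  # return the unused slot
print(len(pool.free_slots))        # 6: the 5 never allocated plus 1 reclaimed

Freeing the slot where the scheduler already knows the request has finished removes the need to guess, inside the cache, whether an extra slot was allocated, which is what the deleted free_delta computation raced to do.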
python/sglang/srt/managers/tp_worker.py

...
@@ -149,14 +149,12 @@ class TpModelWorker:
        )

        # Resolve future tokens in the input
-        # logger.info(f"raw input {model_worker_batch.input_ids=}")
        tic2 = time.time()
        resolved_input_ids = model_worker_batch.input_ids
        future_mask = resolved_input_ids < 0
        resolved_input_ids[future_mask] = self.future_token_ids_map[
            -resolved_input_ids[future_mask]
        ]
-        # logger.info(f"resolved input {model_worker_batch.input_ids=}")

        # Run forward
        logits_output, next_token_ids = self.forward_batch_generation(
...
@@ -215,12 +213,13 @@ class TpModelWorker:
        self.future_logits_output_ct += 1

        bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = -torch.arange(
-            self.future_token_ids_ct + 1,
-            self.future_token_ids_ct + 1 + bs,
-            dtype=torch.int32,
-            device=self.device,
-        )
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
        self.future_token_ids_ct = (
            self.future_token_ids_ct + bs
        ) % self.future_token_ids_limit
...
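
These hunks show both the placeholder scheme overlap mode relies on and the actual race fix: the input ids of a batch scheduled before the previous batch has sampled are negative indices into future_token_ids_map, and the fix allocates those placeholders under torch.cuda.stream(self.forward_stream) so they are created in the same stream order as the forward pass that fills and reads the map. A CPU-only sketch of the resolution step (toy sizes; int64 for portable indexing, and the stream context is omitted because it needs a GPU):

import torch

# CPU-only sketch of the future-token-id scheme in the hunks above (toy
# sizes, not sglang code).
future_token_ids_map = torch.zeros(16, dtype=torch.int64)

# Batch N is queued before batch N-1 has sampled: its last positions are
# negative placeholders -(ct+1) .. -(ct+bs) pointing into the map.
ct, bs = 0, 3
future_next_token_ids = -torch.arange(ct + 1, ct + 1 + bs, dtype=torch.int64)
input_ids = torch.cat([torch.tensor([11, 22]), future_next_token_ids])

# Sampling batch N-1 later writes the real token ids into the reserved slots.
future_token_ids_map[1 : 1 + bs] = torch.tensor([7, 8, 9])

# Resolution, as in TpModelWorker: negative ids are looked up in the map.
future_mask = input_ids < 0
input_ids[future_mask] = future_token_ids_map[-input_ids[future_mask]]
print(input_ids.tolist())  # [11, 22, 7, 8, 9]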
python/sglang/srt/mem_cache/chunk_cache.py

...
@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
         max_prefix_len = len(key)
         return entry.value[:max_prefix_len], entry

-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
         else:
             token_id_len = len(token_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : token_id_len + free_delta
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
...
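
With the scheduler now freeing the speculative slot itself, cache_finished_req goes back to freeing exactly the tokens the request produced. A toy illustration (made-up indices, not sglang code) of how the freed slice changes:

import torch

# Toy illustration (made-up indices, not sglang code) of what the dropped
# free_delta argument did to the freed range: with free_delta == 1, the
# slice covered one slot past the request's own tokens.
req_to_token = torch.tensor([[10, 11, 12, 13, 14, -1, -1, -1]])
req_pool_idx, token_id_len = 0, 4

old_freed = req_to_token[req_pool_idx, : token_id_len + 1]  # free_delta == 1
new_freed = req_to_token[req_pool_idx, :token_id_len]       # after this commit

print(old_freed.tolist())  # [10, 11, 12, 13, 14] -- includes the speculative slot
print(new_freed.tolist())  # [10, 11, 12, 13]     -- only the request's tokens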
python/sglang/srt/mem_cache/radix_cache.py

...
@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if self.disable:
             if token_ids is None:
...
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
             token_ids_len = len(token_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : token_ids_len + free_delta
+            req.req_pool_idx, :token_ids_len
         ]
         self.token_to_kv_pool.free(kv_indices)
         self.req_to_token_pool.free(req.req_pool_idx)
...
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
-        if free_delta:
-            self.token_to_kv_pool.free(
-                self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
-                ]
-            )

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
...
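
The deleted block freed the one speculative slot from inside the radix cache, guarded by free_delta. After this commit the scheduler frees the same physical slot through batch.out_cache_loc, as the scheduler.py hunk above shows. A toy check of that equivalence (made-up indices, not sglang code):

import torch

# Toy check (made-up indices, not sglang code): the slot the deleted
# free_delta branch released, req_to_token[req_pool_idx, len(token_ids) :
# len(token_ids) + 1], is the same physical KV slot the scheduler now frees
# as batch.out_cache_loc[i : i + 1].
req_to_token = torch.full((2, 8), -1, dtype=torch.int64)
req_pool_idx = 0
token_ids = [3, 1, 4, 1]                       # tokens cached for this request
req_to_token[req_pool_idx, :5] = torch.tensor([10, 11, 12, 13, 14])

out_cache_loc_i = torch.tensor([14])           # slot pre-allocated this step
old_path = req_to_token[req_pool_idx, len(token_ids) : len(token_ids) + 1]
assert torch.equal(old_path, out_cache_loc_i)  # same slot, freed once either way
print(old_path.tolist())                       # [14]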
python/sglang/srt/server.py

...
@@ -542,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
            kill_child_process(pid, including_parent=False)
            return

+    # logger.info(f"{res.json()=}")
+
    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")
...