Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b48edff6
Unverified
Commit
b48edff6
authored
Oct 20, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 20, 2024
Browse files
Split the overlapped version of TpModelWorkerClient into a separate file (#1726)
parent
593b19f2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
217 additions
and
131 deletions
+217
-131
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+4
-4
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+8
-12
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+0
-109
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+174
-0
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+29
-5
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+1
-0
test/srt/test_vision_openai_server.py
test/srt/test_vision_openai_server.py
+1
-1
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
b48edff6
...
...
@@ -639,8 +639,8 @@ class ScheduleBatch:
if
isinstance
(
self
.
tree_cache
,
ChunkCache
):
# ChunkCache does not have eviction
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
][
:
seq_lens_cpu
[
idx
]
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
:
seq_lens_cpu
[
idx
]
]
self
.
token_to_kv_pool
.
free
(
token_indices
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
...
...
@@ -648,8 +648,8 @@ class ScheduleBatch:
else
:
# TODO: apply more fine-grained retraction
last_uncached_pos
=
len
(
req
.
prefix_indices
)
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
][
last_uncached_pos
:
seq_lens_cpu
[
idx
]
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
last_uncached_pos
:
seq_lens_cpu
[
idx
]
]
self
.
token_to_kv_pool
.
free
(
token_indices
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
...
...
python/sglang/srt/managers/scheduler.py
View file @
b48edff6
...
...
@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_policy import (
SchedulePolicy
,
)
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.managers.tp_worker_overlap_thread
import
TpModelWorkerClient
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
...
@@ -146,9 +147,14 @@ class Scheduler:
# Launch a tensor parallel worker
if
self
.
server_args
.
enable_overlap_schedule
:
TpWorkerClass
=
TpModelWorker
TpWorkerClass
=
TpModelWorkerClient
self
.
resolve_next_token_ids
=
(
lambda
bid
,
x
:
self
.
tp_worker
.
resolve_future_token_ids
(
bid
)
)
else
:
TpWorkerClass
=
TpModelWorker
self
.
resolve_next_token_ids
=
lambda
bid
,
x
:
x
.
tolist
()
self
.
tp_worker
=
TpWorkerClass
(
server_args
=
server_args
,
gpu_id
=
gpu_id
,
...
...
@@ -156,16 +162,6 @@ class Scheduler:
dp_rank
=
dp_rank
,
nccl_port
=
port_args
.
nccl_port
,
)
if
self
.
server_args
.
enable_overlap_schedule
:
self
.
resolve_next_token_ids
=
(
lambda
bid
,
x
:
self
.
tp_worker
.
resolve_future_token_ids
(
bid
)
)
self
.
forward_batch_generation
=
(
self
.
tp_worker
.
forward_batch_generation_non_blocking
)
else
:
self
.
resolve_next_token_ids
=
lambda
bid
,
x
:
x
.
tolist
()
self
.
forward_batch_generation
=
self
.
tp_worker
.
forward_batch_generation
# Get token and memory info from the model worker
(
...
...
@@ -728,7 +724,7 @@ class Scheduler:
if
self
.
is_generation
:
if
batch
.
forward_mode
.
is_decode
()
or
batch
.
extend_num_tokens
!=
0
:
model_worker_batch
=
batch
.
get_model_worker_batch
()
logits_output
,
next_token_ids
=
self
.
forward_batch_generation
(
logits_output
,
next_token_ids
=
self
.
tp_worker
.
forward_batch_generation
(
model_worker_batch
)
else
:
...
...
python/sglang/srt/managers/tp_worker.py
View file @
b48edff6
...
...
@@ -17,13 +17,8 @@ limitations under the License.
import
json
import
logging
import
threading
import
time
from
queue
import
Queue
from
typing
import
Optional
import
torch
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.managers.io_struct
import
UpdateWeightReqInput
...
...
@@ -108,9 +103,6 @@ class TpModelWorker:
)[
0
]
set_random_seed
(
self
.
random_seed
)
if
server_args
.
enable_overlap_schedule
:
self
.
init_overlap_status
()
def
get_worker_info
(
self
):
return
(
self
.
max_total_num_tokens
,
...
...
@@ -137,81 +129,6 @@ class TpModelWorker:
self
.
model_runner
.
token_to_kv_pool
,
)
def
init_overlap_status
(
self
):
self
.
future_logits_output_dict
=
dict
()
self
.
future_logits_output_ct
=
0
self
.
future_token_ids_ct
=
0
self
.
future_token_ids_map
=
torch
.
empty
(
(
self
.
max_running_requests
*
5
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
future_token_ids_limit
=
self
.
max_running_requests
*
3
self
.
future_token_ids_output
=
dict
()
self
.
future_event_map
=
dict
()
self
.
forward_queue
=
Queue
()
self
.
forward_stream
=
torch
.
cuda
.
Stream
()
self
.
forward_thread
=
threading
.
Thread
(
target
=
self
.
forward_thread_func
,
)
self
.
forward_thread
.
start
()
def
forward_thread_func
(
self
):
with
torch
.
cuda
.
stream
(
self
.
forward_stream
):
self
.
forward_thread_func_
()
@
torch
.
inference_mode
()
def
forward_thread_func_
(
self
):
while
True
:
tic1
=
time
.
time
()
model_worker_batch
,
future_logits_output
,
future_next_token_ids
=
(
self
.
forward_queue
.
get
()
)
# Resolve future tokens in the input
tic2
=
time
.
time
()
resolved_input_ids
=
model_worker_batch
.
input_ids
future_mask
=
resolved_input_ids
<
0
resolved_input_ids
[
future_mask
]
=
self
.
future_token_ids_map
[
-
resolved_input_ids
[
future_mask
]
]
# Run forward
logits_output
,
next_token_ids
=
self
.
forward_batch_generation
(
model_worker_batch
)
# Set future values
if
model_worker_batch
.
return_logprob
:
self
.
future_logits_output_dict
[
future_logits_output
]
=
logits_output
# logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
self
.
future_token_ids_map
[
-
future_next_token_ids
]
=
next_token_ids
.
to
(
torch
.
int32
)
# logger.info("Set event")
self
.
future_token_ids_output
[
model_worker_batch
.
bid
]
=
(
next_token_ids
.
tolist
()
)
self
.
future_event_map
[
model_worker_batch
.
bid
].
set
()
if
False
:
tic3
=
time
.
time
()
self
.
acc_time_with_waiting
+=
tic3
-
tic1
self
.
acc_time_without_waiting
+=
tic3
-
tic2
if
self
.
forward_queue
.
qsize
()
==
0
:
logger
.
info
(
f
"
{
self
.
acc_time_with_waiting
=
:.
3
f
}
,
{
self
.
acc_time_without_waiting
=
:.
3
f
}
,
{
self
.
forward_queue
.
qsize
()
=
}
"
)
def
resolve_future_token_ids
(
self
,
bid
:
int
):
self
.
future_event_map
[
bid
].
wait
()
ret
=
self
.
future_token_ids_output
[
bid
]
del
self
.
future_event_map
[
bid
]
return
ret
def
resolve_future_logits_output
(
self
,
future_obj
):
return
self
.
future_logits_output_dict
.
pop
(
future_obj
)
def
forward_batch_generation
(
self
,
model_worker_batch
:
ModelWorkerBatch
):
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
...
...
@@ -224,32 +141,6 @@ class TpModelWorker:
embeddings
=
logits_output
.
embeddings
return
embeddings
def
forward_batch_generation_non_blocking
(
self
,
model_worker_batch
:
ModelWorkerBatch
):
# Allocate output future objects
future_logits_output
=
self
.
future_logits_output_ct
self
.
future_logits_output_ct
+=
1
bs
=
len
(
model_worker_batch
.
seq_lens
)
with
torch
.
cuda
.
stream
(
self
.
forward_stream
):
future_next_token_ids
=
-
torch
.
arange
(
self
.
future_token_ids_ct
+
1
,
self
.
future_token_ids_ct
+
1
+
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
future_token_ids_ct
=
(
self
.
future_token_ids_ct
+
bs
)
%
self
.
future_token_ids_limit
ret
=
future_logits_output
,
future_next_token_ids
self
.
future_event_map
[
model_worker_batch
.
bid
]
=
threading
.
Event
()
self
.
forward_queue
.
put
(
(
model_worker_batch
.
copy
(),
future_logits_output
,
future_next_token_ids
)
)
return
ret
def
update_weights
(
self
,
recv_req
:
UpdateWeightReqInput
):
success
,
message
=
self
.
model_runner
.
update_weights
(
recv_req
.
model_path
,
recv_req
.
load_format
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
0 → 100644
View file @
b48edff6
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""A tensor parallel worker."""
import
logging
import
threading
import
time
from
queue
import
Queue
from
typing
import
Optional
import
torch
from
sglang.srt.managers.io_struct
import
UpdateWeightReqInput
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.server_args
import
ServerArgs
logger
=
logging
.
getLogger
(
__name__
)
class
TpModelWorkerClient
:
"""A tensor parallel model worker."""
def
__init__
(
self
,
server_args
:
ServerArgs
,
gpu_id
:
int
,
tp_rank
:
int
,
dp_rank
:
Optional
[
int
],
nccl_port
:
int
,
):
# Load the model
self
.
worker
=
TpModelWorker
(
server_args
,
gpu_id
,
tp_rank
,
dp_rank
,
nccl_port
)
self
.
max_running_requests
=
self
.
worker
.
max_running_requests
self
.
device
=
self
.
worker
.
device
# Create future mappings
self
.
future_logits_output_dict
=
dict
()
self
.
future_logits_output_ct
=
0
self
.
future_token_ids_ct
=
0
self
.
future_token_ids_map
=
torch
.
empty
(
(
self
.
max_running_requests
*
5
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
future_token_ids_limit
=
self
.
max_running_requests
*
3
self
.
future_token_ids_output
=
dict
()
# Launch a thread
self
.
future_event_map
=
dict
()
self
.
forward_queue
=
Queue
()
self
.
forward_stream
=
torch
.
cuda
.
Stream
()
self
.
forward_thread
=
threading
.
Thread
(
target
=
self
.
forward_thread_func
,
)
self
.
forward_thread
.
start
()
def
get_worker_info
(
self
):
return
self
.
worker
.
get_worker_info
()
def
get_pad_input_ids_func
(
self
):
return
self
.
worker
.
get_pad_input_ids_func
()
def
get_tp_cpu_group
(
self
):
return
self
.
worker
.
get_tp_cpu_group
()
def
get_memory_pool
(
self
):
return
(
self
.
worker
.
model_runner
.
req_to_token_pool
,
self
.
worker
.
model_runner
.
token_to_kv_pool
,
)
def
forward_thread_func
(
self
):
with
torch
.
cuda
.
stream
(
self
.
forward_stream
):
self
.
forward_thread_func_
()
@
torch
.
inference_mode
()
def
forward_thread_func_
(
self
):
while
True
:
tic1
=
time
.
time
()
model_worker_batch
,
future_logits_output
,
future_next_token_ids
=
(
self
.
forward_queue
.
get
()
)
# Resolve future tokens in the input
tic2
=
time
.
time
()
resolved_input_ids
=
model_worker_batch
.
input_ids
future_mask
=
resolved_input_ids
<
0
resolved_input_ids
[
future_mask
]
=
self
.
future_token_ids_map
[
-
resolved_input_ids
[
future_mask
]
]
# Run forward
logits_output
,
next_token_ids
=
self
.
worker
.
forward_batch_generation
(
model_worker_batch
)
# Set future values
if
model_worker_batch
.
return_logprob
:
self
.
future_logits_output_dict
[
future_logits_output
]
=
logits_output
self
.
future_token_ids_map
[
-
future_next_token_ids
]
=
next_token_ids
.
to
(
torch
.
int32
)
self
.
future_token_ids_output
[
model_worker_batch
.
bid
]
=
(
next_token_ids
.
tolist
()
)
self
.
future_event_map
[
model_worker_batch
.
bid
].
set
()
if
False
:
tic3
=
time
.
time
()
self
.
acc_time_with_waiting
+=
tic3
-
tic1
self
.
acc_time_without_waiting
+=
tic3
-
tic2
if
self
.
forward_queue
.
qsize
()
==
0
:
logger
.
info
(
f
"
{
self
.
acc_time_with_waiting
=
:.
3
f
}
,
{
self
.
acc_time_without_waiting
=
:.
3
f
}
,
{
self
.
forward_queue
.
qsize
()
=
}
"
)
def
resolve_future_token_ids
(
self
,
bid
:
int
):
self
.
future_event_map
[
bid
].
wait
()
ret
=
self
.
future_token_ids_output
[
bid
]
del
self
.
future_event_map
[
bid
]
return
ret
def
resolve_future_logits_output
(
self
,
future_obj
):
return
self
.
future_logits_output_dict
.
pop
(
future_obj
)
def
forward_batch_generation
(
self
,
model_worker_batch
:
ModelWorkerBatch
):
# Allocate output future objects
future_logits_output
=
self
.
future_logits_output_ct
self
.
future_logits_output_ct
+=
1
bs
=
len
(
model_worker_batch
.
seq_lens
)
with
torch
.
cuda
.
stream
(
self
.
forward_stream
):
future_next_token_ids
=
-
torch
.
arange
(
self
.
future_token_ids_ct
+
1
,
self
.
future_token_ids_ct
+
1
+
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
future_token_ids_ct
=
(
self
.
future_token_ids_ct
+
bs
)
%
self
.
future_token_ids_limit
ret
=
future_logits_output
,
future_next_token_ids
self
.
future_event_map
[
model_worker_batch
.
bid
]
=
threading
.
Event
()
self
.
forward_queue
.
put
(
(
model_worker_batch
.
copy
(),
future_logits_output
,
future_next_token_ids
)
)
return
ret
def
forward_batch_embedding
(
self
,
model_worker_batch
:
ModelWorkerBatch
):
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
embeddings
=
logits_output
.
embeddings
return
embeddings
def
update_weights
(
self
,
recv_req
:
UpdateWeightReqInput
):
success
,
message
=
self
.
model_runner
.
update_weights
(
recv_req
.
model_path
,
recv_req
.
load_format
)
return
success
,
message
python/sglang/srt/mem_cache/memory_pool.py
View file @
b48edff6
...
...
@@ -13,7 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Memory pool."""
"""
Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a a request to its token locations.
BaseTokenToKVPool maps a token location to its KV cache data.
"""
import
logging
from
typing
import
List
,
Tuple
,
Union
...
...
@@ -26,7 +32,7 @@ logger = logging.getLogger(__name__)
class
ReqToTokenPool
:
"""A memory pool that maps a request to its token locations."""
def
__init__
(
self
,
size
:
int
,
max_context_len
:
int
,
device
:
str
):
def
__init__
(
self
,
size
:
int
,
max_context_len
:
int
,
device
:
str
,
use_records
:
bool
):
self
.
size
=
size
self
.
max_context_len
=
max_context_len
self
.
device
=
device
...
...
@@ -34,6 +40,13 @@ class ReqToTokenPool:
(
size
,
max_context_len
),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
free_slots
=
list
(
range
(
size
))
self
.
write_records
=
[]
if
use_records
:
# records all write operations
self
.
write
=
self
.
write_with_records
else
:
self
.
write
=
self
.
write_without_records
def
available_size
(
self
):
return
len
(
self
.
free_slots
)
...
...
@@ -55,16 +68,27 @@ class ReqToTokenPool:
def
clear
(
self
):
self
.
free_slots
=
list
(
range
(
self
.
size
))
self
.
write_records
=
[]
def
write_without_records
(
self
,
indices
,
values
):
self
.
req_to_token
[
indices
]
=
values
def
write
(
self
,
indices
,
values
):
def
write
_with_records
(
self
,
indices
,
values
):
self
.
req_to_token
[
indices
]
=
values
self
.
write_records
.
append
((
indices
,
values
))
def
get_write_records
(
self
):
return
None
ret
=
self
.
write_records
self
.
write_records
=
[]
return
ret
def
apply_write_records
(
self
,
write_records
:
List
[
Tuple
]):
for
indices
,
values
in
write_records
:
self
.
req_to_token
[
indices
]
=
values
class
BaseTokenToKVPool
:
"""A memory pool that maps a token to its kv cache
locations
"""
"""A memory pool that maps a token
location
to its kv cache
data.
"""
def
__init__
(
self
,
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
b48edff6
...
...
@@ -461,6 +461,7 @@ class ModelRunner:
size
=
max_num_reqs
+
1
,
max_context_len
=
self
.
model_config
.
context_len
+
4
,
device
=
self
.
device
,
use_records
=
False
,
)
if
(
self
.
model_config
.
attention_arch
==
AttentionArch
.
MLA
...
...
test/srt/test_vision_openai_server.py
View file @
b48edff6
...
...
@@ -170,7 +170,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
text
=
response
.
choices
[
0
].
message
.
content
assert
isinstance
(
text
,
str
)
print
(
text
)
assert
"man"
in
text
and
"taxi
"
in
text
,
text
assert
"man"
in
text
or
"cab
"
in
text
,
text
assert
"logo"
in
text
,
text
assert
response
.
id
assert
response
.
created
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment