change / sglang / Commits / 32eb6e96

Unverified commit 32eb6e96, authored Oct 03, 2024 by Lianmin Zheng, committed via GitHub on Oct 03, 2024.

Organize sampling batch info better (#1562)

Parent: e0b5dbce

Showing 8 changed files with 43 additions and 35 deletions (+43, -35).
Changed files:

  python/sglang/srt/managers/scheduler.py             +6   -4
  python/sglang/srt/managers/tokenizer_manager.py     +3   -1
  python/sglang/srt/mem_cache/memory_pool.py          +2   -5
  python/sglang/srt/model_executor/model_runner.py    +2   -2
  python/sglang/srt/sampling/sampling_batch_info.py   +20  -17
  python/sglang/srt/server.py                         +6   -3
  python/sglang/srt/server_args.py                    +3   -2
  test/killall_sglang.sh                              +1   -1
python/sglang/srt/managers/scheduler.py

@@ -96,7 +96,9 @@ class Scheduler:
         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+            )
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(

@@ -141,9 +143,6 @@ class Scheduler:
             nccl_port=port_args.nccl_ports[0],
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )

         # Get token and memory info from the tp worker
         (

@@ -154,6 +153,9 @@ class Scheduler:
             self.random_seed,
         ) = self.tp_worker.get_token_and_memory_info()
         set_random_seed(self.random_seed)
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )

         # Print debug info
         logger.info(
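These hunks rename the scheduler's inbound port to scheduler_input_port and move the pad_input_ids_func lookup to after the token and memory info is fetched from the TP worker. For context, a minimal self-contained sketch of the ZMQ PULL/PUSH pairing used between the tokenizer and the scheduler (assumes pyzmq; the port number is illustrative, not taken from the commit):

    import zmq

    context = zmq.Context()

    # Scheduler side: bind a PULL socket and wait for tokenized requests.
    recv_from_tokenizer = context.socket(zmq.PULL)
    recv_from_tokenizer.bind("tcp://127.0.0.1:30000")

    # Tokenizer side: connect a PUSH socket to the same address and send work.
    send_to_scheduler = context.socket(zmq.PUSH)
    send_to_scheduler.connect("tcp://127.0.0.1:30000")

    send_to_scheduler.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})
    print(recv_from_tokenizer.recv_pyobj())  # -> {'rid': 'req-0', ...}

PUSH/PULL gives one-way, fire-and-forget delivery, which is why each direction (tokenizer to scheduler, scheduler to detokenizer) gets its own socket pair and its own port.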
python/sglang/srt/managers/tokenizer_manager.py

@@ -87,7 +87,9 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+        self.send_to_scheduler.connect(
+            f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+        )

         # Read model args
         self.model_path = server_args.model_path
python/sglang/srt/mem_cache/memory_pool.py

@@ -30,6 +30,7 @@ class ReqToTokenPool:
     def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
         self.max_context_len = max_context_len
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device

@@ -54,7 +55,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(

@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = np.arange(1, self.size + 1)

-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

-    @abstractmethod
     def set_kv_buffer(self, layer_id: int,
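The ABC base and @abstractmethod decorators are dropped from BaseTokenToKVPool; subclasses still override the buffer accessors, but the base class is now instantiable. The practical difference is when misuse fails, illustrated by this hypothetical sketch (not code from the commit):

    from abc import ABC, abstractmethod

    class AbstractPool(ABC):
        @abstractmethod
        def get_key_buffer(self, layer_id: int): ...

    class PlainPool:
        def get_key_buffer(self, layer_id: int):
            raise NotImplementedError()

    # With ABC, instantiation itself fails:
    try:
        AbstractPool()
    except TypeError as e:
        print("ABC:", e)  # can't instantiate abstract class AbstractPool ...

    # Without ABC, the object is created and the error is deferred
    # until an unimplemented method is actually called:
    pool = PlainPool()
    try:
        pool.get_key_buffer(0)
    except NotImplementedError:
        print("plain class: error raised only on call")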
python/sglang/srt/model_executor/model_runner.py

@@ -411,8 +411,8 @@ class ModelRunner:
         device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
             device=device,
         )
         if (
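The pool is now constructed with keyword arguments, so the call site stays correct and self-documenting even if the parameter order of ReqToTokenPool.__init__ changes. A toy illustration with a hypothetical signature:

    def make_pool(size: int, max_context_len: int, device: str):
        return (size, max_context_len, device)

    # Positional: silently wrong if the first two arguments are swapped.
    # Keyword: order-independent and readable at the call site.
    print(make_pool(size=257, max_context_len=4100, device="cuda"))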
python/sglang/srt/sampling/sampling_batch_info.py

@@ -14,16 +14,17 @@ if TYPE_CHECKING:
 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Basic Info
-    vocab_size: int
-
     # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool

     # Bias Tensors
+    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None

@@ -31,9 +32,6 @@ class SamplingBatchInfo:
     regex_fsms: List[RegexGuide] = None
     regex_fsm_states: List[int] = None

-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None

@@ -42,25 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        ret = cls(vocab_size=vocab_size)
         with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-            ret.top_ps = torch.tensor(
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-            ret.top_ks = torch.tensor(
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-            ret.min_ps = torch.tensor(
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )

-        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
-        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
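The reorganization turns temperatures, top_ps, top_ks, min_ps, need_min_p_sampling, and vocab_size into required constructor arguments instead of attributes patched onto a half-built object, so from_schedule_batch now builds the tensors first and calls cls(...) exactly once. One constraint behind the field reordering: in a Python dataclass, fields without defaults must be declared before fields with defaults. A simplified stand-in for the real class:

    import dataclasses
    from typing import Optional

    @dataclasses.dataclass
    class BatchInfo:
        # Required fields (no default) must come first ...
        temperatures: list
        need_min_p_sampling: bool
        vocab_size: int
        # ... and defaulted, optional fields follow.
        logit_bias: Optional[list] = None

    info = BatchInfo(temperatures=[0.7, 1.0], need_min_p_sampling=False, vocab_size=32000)
    print(info)

Putting a defaulted field above a required one would raise "TypeError: non-default argument follows default argument" at class-definition time.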
python/sglang/srt/server.py

@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:
 @app.get("/get_model_info")
 async def get_model_info():
     """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,

@@ -127,11 +128,13 @@ async def get_model_info():
 @app.get("/get_server_args")
 async def get_server_args():
     """Get the server arguments."""
     return dataclasses.asdict(tokenizer_manager.server_args)

 @app.get("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "

@@ -142,7 +145,7 @@ async def flush_cache():
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:

@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
 async def judge_request(obj: RewardReqInput, request: Request):
-    """Handle an embedding request."""
+    """Handle a reward model request."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret

@@ -307,7 +310,7 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        scheduler_port=ports[1],
+        scheduler_input_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
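One visible effect: /get_model_info now also reports is_generation alongside model_path. A hypothetical client-side check (assumes a locally running sglang server on port 30000 and the requests package; the port is illustrative):

    import requests

    info = requests.get("http://127.0.0.1:30000/get_model_info").json()
    print(info["model_path"], info["is_generation"])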
python/sglang/srt/server_args.py

@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 class PortArgs:
     # The port for tokenizer to receive inputs from detokenizer (zmq)
     tokenizer_port: int

-    # The port for scheduler to receive inputs from tokenizer (zmq)
-    scheduler_port: int
+    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_port: int

     # The port for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_port: int

     # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
test/killall_sglang.sh

kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')