change / sglang · Commits · 32eb6e96

Commit 32eb6e96 (Unverified)
Authored Oct 03, 2024 by Lianmin Zheng, committed via GitHub on Oct 03, 2024
Parent: e0b5dbce

Organize sampling batch info better (#1562)
Showing 8 changed files with 43 additions and 35 deletions (+43, -35)
python/sglang/srt/managers/scheduler.py            +6  -4
python/sglang/srt/managers/tokenizer_manager.py    +3  -1
python/sglang/srt/mem_cache/memory_pool.py         +2  -5
python/sglang/srt/model_executor/model_runner.py   +2  -2
python/sglang/srt/sampling/sampling_batch_info.py  +20 -17
python/sglang/srt/server.py                        +6  -3
python/sglang/srt/server_args.py                   +3  -2
test/killall_sglang.sh                             +1  -1
python/sglang/srt/managers/scheduler.py

@@ -96,7 +96,9 @@ class Scheduler:
         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_input_port}")

             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(
@@ -141,9 +143,6 @@ class Scheduler:
                 nccl_port=port_args.nccl_ports[0],
             )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )

         # Get token and memory info from the tp worker
         (
@@ -154,6 +153,9 @@ class Scheduler:
             self.random_seed,
         ) = self.tp_worker.get_token_and_memory_info()
         set_random_seed(self.random_seed)
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )

         # Print debug info
         logger.info(
python/sglang/srt/managers/tokenizer_manager.py

@@ -87,7 +87,9 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_input_port}")

         # Read model args
         self.model_path = server_args.model_path
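For context, the bind/connect pair touched in these two files follows the standard pyzmq PUSH/PULL pattern: the scheduler binds a PULL socket on the renamed scheduler_input_port, and the tokenizer manager connects a PUSH socket to the same address. A minimal standalone sketch of that wiring; the port number and payload below are arbitrary illustrations, not taken from the commit:

import zmq

context = zmq.Context()

# Receiving side (scheduler): bind a PULL socket on its input port.
recv_sock = context.socket(zmq.PULL)
recv_sock.bind("tcp://127.0.0.1:5757")  # stands in for port_args.scheduler_input_port

# Sending side (tokenizer manager): connect a PUSH socket to the same port.
send_sock = context.socket(zmq.PUSH)
send_sock.connect("tcp://127.0.0.1:5757")

send_sock.send_pyobj({"rid": "demo-request", "input_ids": [1, 2, 3]})
print(recv_sock.recv_pyobj())  # {'rid': 'demo-request', 'input_ids': [1, 2, 3]}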
python/sglang/srt/mem_cache/memory_pool.py

@@ -30,6 +30,7 @@ class ReqToTokenPool:
     def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
+        self.max_context_len = max_context_len
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device
@@ -54,7 +55,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = np.arange(1, self.size + 1)

-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
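The BaseTokenToKVPool change drops the ABC/abstractmethod machinery in favor of plain methods that raise NotImplementedError. A concrete pool still overrides the same methods; the practical difference is that a missing override now surfaces at call time instead of at instantiation. A minimal sketch under that assumption; the ToyKVPool subclass and its buffer layout are illustrative, not part of the commit:

import torch
from typing import Tuple

class BaseTokenToKVPool:
    """A memory pool that maps a token to its kv cache locations"""

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

class ToyKVPool(BaseTokenToKVPool):
    # Hypothetical subclass: one (key, value) buffer per layer.
    def __init__(self, num_layers: int, size: int, head_dim: int):
        self.k = [torch.zeros(size, head_dim) for _ in range(num_layers)]
        self.v = [torch.zeros(size, head_dim) for _ in range(num_layers)]

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        return self.k[layer_id]

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        return self.v[layer_id]

BaseTokenToKVPool()                     # legal now that ABC is gone; methods still raise if called
print(ToyKVPool(2, 16, 8).get_kv_buffer(0)[0].shape)  # torch.Size([16, 8])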
python/sglang/srt/model_executor/model_runner.py

@@ -411,8 +411,8 @@ class ModelRunner:
         device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
             device=device,
         )
         if (
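This hunk only switches the ReqToTokenPool call to keyword arguments, matching the constructor shown in the memory_pool.py hunk above. A small sketch of what that constructor allocates; the sizes and the CPU device below are made up for illustration:

import torch

class ReqToTokenPool:
    def __init__(self, size: int, max_context_len: int, device: str):
        self.size = size
        self.max_context_len = max_context_len
        self.free_slots = list(range(size))
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )

# Mirrors the keyword-argument call site in ModelRunner (numbers are illustrative).
pool = ReqToTokenPool(size=256 + 1, max_context_len=2048 + 4, device="cpu")
print(pool.req_to_token.shape)  # torch.Size([257, 2052])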
python/sglang/srt/sampling/sampling_batch_info.py

@@ -14,16 +14,17 @@ if TYPE_CHECKING:
 @dataclasses.dataclass
 class SamplingBatchInfo:
+    # Basic Info
+    vocab_size: int
+
     # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool

     # Bias Tensors
-    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
     regex_fsms: List[RegexGuide] = None
     regex_fsm_states: List[int] = None

-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        ret = cls(vocab_size=vocab_size)
         with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-            ret.top_ps = torch.tensor(
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-            ret.top_ks = torch.tensor(
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-            ret.min_ps = torch.tensor(
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )

+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
-        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
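Taken together, these hunks are the core of the commit: SamplingBatchInfo goes from a mostly-defaulted dataclass that was mutated field by field to one whose core fields are required constructor arguments, built in a single cls(...) call inside from_schedule_batch. A minimal sketch of constructing the reorganized dataclass directly; the class body is trimmed to the fields visible in this diff and the tensor values are dummies for a two-request batch:

import dataclasses
import torch

@dataclasses.dataclass
class SamplingBatchInfo:
    # Basic Info
    vocab_size: int

    # Batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # Dispatch in CUDA graph
    need_min_p_sampling: bool

    # Bias Tensors
    logit_bias: torch.Tensor = None
    vocab_mask: torch.Tensor = None

# Dummy values; the real ones come from each request's sampling_params.
info = SamplingBatchInfo(
    vocab_size=32000,
    temperatures=torch.tensor([[0.7], [1.0]]),
    top_ps=torch.tensor([0.9, 1.0]),
    top_ks=torch.tensor([40, 1], dtype=torch.int),
    min_ps=torch.tensor([0.0, 0.05]),
    need_min_p_sampling=True,
)

Making the batched sampling fields required means a SamplingBatchInfo can no longer exist in a half-initialized state, which is what the removed ret = cls(vocab_size=vocab_size) pattern allowed.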
python/sglang/srt/server.py

@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:
 @app.get("/get_model_info")
 async def get_model_info():
+    """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
         "is_generation": tokenizer_manager.is_generation,
@@ -127,11 +128,13 @@ async def get_model_info():
 @app.get("/get_server_args")
 async def get_server_args():
+    """Get the server arguments."""
     return dataclasses.asdict(tokenizer_manager.server_args)


 @app.get("/flush_cache")
 async def flush_cache():
+    """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
@@ -142,7 +145,7 @@ async def flush_cache():
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
+    """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
 async def judge_request(obj: RewardReqInput, request: Request):
-    """Handle an embedding request."""
+    """Handle a reward model request."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -307,7 +310,7 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        scheduler_port=ports[1],
+        scheduler_input_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
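The added docstrings document the informational HTTP endpoints. A quick usage sketch against a locally running server; the host and port are assumptions (sglang commonly serves on 30000 by default), not something this commit sets:

import requests

base = "http://127.0.0.1:30000"  # assumed local sglang server

print(requests.get(f"{base}/get_model_info").json())   # model_path, is_generation, ...
print(requests.get(f"{base}/get_server_args").json())  # dataclasses.asdict(server_args)
requests.get(f"{base}/flush_cache")                     # "Cache flushed. ..."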
python/sglang/srt/server_args.py

@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 class PortArgs:
     # The port for tokenizer to receive inputs from detokenizer (zmq)
     tokenizer_port: int

-    # The port for scheduler to receive inputs from tokenizer (zmq)
-    scheduler_port: int
+    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_port: int

     # The port for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_port: int

     # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
test/killall_sglang.sh

-kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')