Unverified commit 32eb6e96: "Organize sampling batch info better (#1562)"
Authored Oct 03, 2024 by Lianmin Zheng; committed by GitHub on Oct 03, 2024.
Parent commit: e0b5dbce
Showing 8 changed files with 43 additions and 35 deletions (+43, -35).
python/sglang/srt/managers/scheduler.py (+6, -4)
python/sglang/srt/managers/tokenizer_manager.py (+3, -1)
python/sglang/srt/mem_cache/memory_pool.py (+2, -5)
python/sglang/srt/model_executor/model_runner.py (+2, -2)
python/sglang/srt/sampling/sampling_batch_info.py (+20, -17)
python/sglang/srt/server.py (+6, -3)
python/sglang/srt/server_args.py (+3, -2)
test/killall_sglang.sh (+1, -1)
python/sglang/srt/managers/scheduler.py

```diff
@@ -96,7 +96,9 @@ class Scheduler:
         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+            )
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(
@@ -141,9 +143,6 @@ class Scheduler:
                 nccl_port=port_args.nccl_ports[0],
             )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )

         # Get token and memory info from the tp worker
         (
@@ -154,6 +153,9 @@ class Scheduler:
             self.random_seed,
         ) = self.tp_worker.get_token_and_memory_info()
         set_random_seed(self.random_seed)
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )

         # Print debug info
         logger.info(
```
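Two things happen here: the PULL socket that receives tokenizer input now binds to the renamed `port_args.scheduler_input_port`, and the `pad_input_ids_func` lookup moves below `get_token_and_memory_info()` so worker-derived attributes are set together. The lookup itself uses three-argument `getattr` to treat `pad_input_ids` as an optional model hook. A minimal sketch of that pattern; the two model classes are hypothetical stand-ins, not sglang code:

```python
# Optional-hook lookup: getattr with a None default lets callers support
# models that define pad_input_ids without requiring it on every model.

class ModelWithPadding:
    def pad_input_ids(self, input_ids, pad_value):
        # Illustrative only: append a fixed amount of padding.
        return input_ids + [pad_value] * 4

class PlainModel:
    pass

for model in (ModelWithPadding(), PlainModel()):
    pad_input_ids_func = getattr(model, "pad_input_ids", None)
    if pad_input_ids_func is not None:
        print(pad_input_ids_func([1, 2, 3], 0))
    else:
        print("model does not pad input ids")
```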
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -87,7 +87,9 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+        self.send_to_scheduler.connect(
+            f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+        )

         # Read model args
         self.model_path = server_args.model_path
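This is the sending half of the pipe the scheduler diff binds above: the tokenizer's PUSH socket connects to the same renamed `scheduler_input_port`. A minimal, self-contained pyzmq sketch of that PUSH/PULL wiring; the port number is arbitrary here and stands in for `PortArgs.scheduler_input_port`:

```python
import zmq

context = zmq.Context()

# Scheduler side: bind a PULL socket and receive.
recv_from_tokenizer = context.socket(zmq.PULL)
recv_from_tokenizer.bind("tcp://127.0.0.1:5757")

# Tokenizer side: connect a PUSH socket to the same address and send.
send_to_scheduler = context.socket(zmq.PUSH)
send_to_scheduler.connect("tcp://127.0.0.1:5757")

send_to_scheduler.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})
print(recv_from_tokenizer.recv_pyobj())  # {'rid': 'req-0', 'input_ids': [1, 2, 3]}
```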
python/sglang/srt/mem_cache/memory_pool.py

```diff
@@ -30,6 +30,7 @@ class ReqToTokenPool:
     def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
+        self.max_context_len = max_context_len
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device
@@ -54,7 +55,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))

-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = np.arange(1, self.size + 1)

-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
```
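`BaseTokenToKVPool` stops being an `abc.ABC`, and its buffer accessors drop `@abstractmethod` in favor of plain methods that raise `NotImplementedError()`. The observable difference between the two styles, sketched below without claiming the author's motive: an ABC refuses to instantiate at all, while the plain base class fails only when an unimplemented method is actually called.

```python
from abc import ABC, abstractmethod

class AbstractPool(ABC):
    @abstractmethod
    def get_key_buffer(self, layer_id: int): ...

class PlainPool:
    def get_key_buffer(self, layer_id: int):
        raise NotImplementedError()

try:
    AbstractPool()          # fails eagerly, at construction time
except TypeError as e:
    print(e)

pool = PlainPool()          # instantiation succeeds
try:
    pool.get_key_buffer(0)  # fails lazily, only when the method is called
except NotImplementedError:
    print("subclass must implement get_key_buffer")
```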
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -411,8 +411,8 @@ class ModelRunner:
         device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
             device=device,
         )
         if (
```
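The `ReqToTokenPool` call switches from positional to keyword arguments, matching the constructor shown in the memory_pool.py diff and guarding against transposing `size` and `max_context_len`, two parameters of the same type. A small illustration; `make_pool` is a hypothetical function mirroring that signature:

```python
def make_pool(size: int, max_context_len: int, device: str):
    return (size, max_context_len, device)

# Positional call: transposing the two ints goes unnoticed by the type system.
print(make_pool(4096 + 4, 256 + 1, "cuda"))   # (4100, 257, 'cuda') -- wrong order, no error

# Keyword call: the intent is explicit and order-independent.
print(make_pool(size=256 + 1, max_context_len=4096 + 4, device="cuda"))
```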
python/sglang/srt/sampling/sampling_batch_info.py

```diff
@@ -14,16 +14,17 @@ if TYPE_CHECKING:
 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Basic Info
-    vocab_size: int
-
     # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool

     # Bias Tensors
+    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
     regex_fsms: List[RegexGuide] = None
     regex_fsm_states: List[int] = None

-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        ret = cls(vocab_size=vocab_size)
-
         with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-            ret.top_ps = torch.tensor(
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-            ret.top_ks = torch.tensor(
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-            ret.min_ps = torch.tensor(
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )

+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
-        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
```
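This file is the core of the commit. The required sampling tensors (`temperatures`, `top_ps`, `top_ks`, `min_ps`) and `need_min_p_sampling` lose their `= None` / `= False` defaults, and `from_schedule_batch` now builds every tensor first and calls the constructor once, instead of creating a half-empty instance and mutating it field by field. A reduced, runnable sketch of the same pattern; `Info` is a simplified stand-in for `SamplingBatchInfo`, not the real class:

```python
import dataclasses
from typing import List

import torch

@dataclasses.dataclass
class Info:
    temperatures: torch.Tensor   # required: omitting it is now a TypeError
    need_min_p_sampling: bool
    vocab_size: int

    @classmethod
    def from_params(cls, temps: List[float], min_ps: List[float], vocab_size: int):
        # Build all derived values first, then construct the object once.
        temperatures = torch.tensor(temps, dtype=torch.float).view(-1, 1)
        return cls(
            temperatures=temperatures,
            need_min_p_sampling=any(p > 0 for p in min_ps),
            vocab_size=vocab_size,
        )

info = Info.from_params([0.7, 1.0], [0.0, 0.05], vocab_size=32000)
print(info.need_min_p_sampling)  # True
```

With the defaults gone, a forgotten field fails immediately at construction, so the object can never be observed half-initialized.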
python/sglang/srt/server.py

```diff
@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:
 @app.get("/get_model_info")
 async def get_model_info():
+    """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
         "is_generation": tokenizer_manager.is_generation,
@@ -127,11 +128,13 @@ async def get_model_info():
 @app.get("/get_server_args")
 async def get_server_args():
+    """Get the server arguments."""
     return dataclasses.asdict(tokenizer_manager.server_args)

 @app.get("/flush_cache")
 async def flush_cache():
+    """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
@@ -142,7 +145,7 @@ async def flush_cache():
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
+    """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
 async def judge_request(obj: RewardReqInput, request: Request):
-    """Handle an embedding request."""
+    """Handle a reward model request."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -307,7 +310,7 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        scheduler_port=ports[1],
+        scheduler_input_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
```
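Most of the server.py additions are docstrings on route handlers, plus a corrected docstring on `judge_request` (it serves reward-model requests, not embeddings). With FastAPI these docstrings are more than comments: the framework copies a handler's docstring into the generated OpenAPI schema, so the text appears as the endpoint description on the interactive /docs page. A minimal sketch, assuming the file is saved as `app.py` and served with `uvicorn app:app`:

```python
from fastapi import FastAPI

app = FastAPI()

@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""  # shown as the description in /docs
    # Placeholder payload for illustration; sglang reads these from its
    # tokenizer_manager instead.
    return {"model_path": "example/model", "is_generation": True}
```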
python/sglang/srt/server_args.py

```diff
@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 class PortArgs:
     # The port for tokenizer to receive inputs from detokenizer (zmq)
     tokenizer_port: int

-    # The port for scheduler to receive inputs from tokenizer (zmq)
-    scheduler_port: int
+    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_port: int

     # The port for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_port: int

     # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
```
test/killall_sglang.sh

```diff
-kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
```
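In this pipeline, `grep -v 'grep'` drops the `grep` process itself from the `ps aux` listing (its own command line also contains the search pattern), and `awk '{print $2}'` extracts the PID column to feed `kill -9`. The updated pattern targets the `multiprocessing.spawn` worker processes rather than the `sglang.launch_server` entry process.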