sglang commit 63ba2f8d (unverified)
Authored by Lianmin Zheng on Sep 30, 2024; committed via GitHub on Sep 30, 2024.
Parent commit: 36d5acfc

Clean up batch data structures: Introducing ModelWorkerBatch (#1544)

Showing 9 changed files with 274 additions and 155 deletions (+274, -155).
python/sglang/bench_latency.py                            +15   -8
python/sglang/srt/layers/logits_processor.py               +5   -1
python/sglang/srt/managers/schedule_batch.py              +122  -48
python/sglang/srt/managers/scheduler.py                    +12  -11
python/sglang/srt/managers/tp_worker.py                     +6   -3
python/sglang/srt/model_executor/forward_batch_info.py     +64  -27
python/sglang/srt/model_executor/model_runner.py           +25  -30
python/sglang/srt/sampling/sampling_batch_info.py          +25  -16
python/sglang/srt/server.py                                 +0  -11
python/sglang/bench_latency.py (+15, -8)

@@ -62,11 +62,13 @@ import torch.distributed as dist
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    allocate_init_ports,
     configure_logger,
     kill_child_process,
     suppress_other_loggers,

@@ -125,6 +127,11 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

+    server_args.port, server_args.additional_ports = allocate_init_ports(
+        server_args.port,
+        server_args.additional_ports,
+        server_args.dp_size,
+    )
     model_config = ModelConfig(
         server_args.model_path,
         server_args.trust_remote_code,

@@ -136,7 +143,7 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=28888,
+        nccl_port=server_args.additional_ports[-1],
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")

@@ -225,17 +232,19 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    forward_batch = batch.get_forward_batch()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    forward_batch = batch.get_forward_batch()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits

@@ -357,7 +366,6 @@ def latency_test(
     tp_rank,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
-    _set_envs_and_config(server_args)
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model

@@ -463,6 +471,7 @@ def plot_latency_test(
 def main(server_args, bench_args):
+    _set_envs_and_config(server_args)
     if server_args.model_path:
         if bench_args.correctness_test:

@@ -513,8 +522,6 @@ if __name__ == "__main__":
         format="%(message)s",
     )

-    multiprocessing.set_start_method("spawn", force=True)
-
     try:
         main(server_args, bench_args)
     except Exception as e:
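The hunks above replace the old one-step batch.get_forward_batch() with a two-step hand-off. A minimal sketch of the new per-step call sequence, mirroring what extend() and decode() now do (it assumes an sglang checkout at this commit and already-constructed ScheduleBatch/ModelRunner objects, so it is illustrative rather than standalone):

    # Sketch of the per-step flow used by extend()/decode() after this commit.
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch

    def run_one_step(batch, model_runner):
        # ScheduleBatch (CPU scheduling data) -> ModelWorkerBatch (per-step inputs)
        model_worker_batch = batch.get_model_worker_batch()
        # ModelWorkerBatch -> ForwardBatch (GPU tensors + attention metadata)
        forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
        logits_output = model_runner.forward(forward_batch)
        next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
        return next_token_ids, logits_output.next_token_logits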
python/sglang/srt/layers/logits_processor.py (+5, -1)

@@ -62,7 +62,11 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+        if forward_batch.return_logprob:
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+        else:
+            return_top_logprob = False

         if forward_batch.forward_mode.is_extend():
             extend_logprob_pruned_lens_cpu = [
                 extend_len - start_len
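The guard matters because ForwardBatch.top_logprobs_nums is now declared Optional; when no request asks for logprobs it is presumably left as None, and the old unconditional any(...) would fail. A tiny standalone illustration of the guarded logic (not sglang code):

    from typing import List, Optional

    def needs_top_logprob(return_logprob: bool, top_logprobs_nums: Optional[List[int]]) -> bool:
        # Mirrors the guard above: only inspect the list when logprobs are requested.
        if return_logprob:
            return any(x > 0 for x in top_logprobs_nums)
        return False

    assert needs_top_logprob(True, [0, 2, 0]) is True
    assert needs_top_logprob(False, None) is False   # no TypeError on the None case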
python/sglang/srt/managers/schedule_batch.py (+122, -48)

-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");

@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Meta data for requests and batches"""
+"""
+Store information about requests and batches.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""

 import logging
 from dataclasses import dataclass

@@ -29,7 +39,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

@@ -105,6 +115,8 @@ class FINISH_ABORT(BaseFinishReason):
 @dataclass
 class ImageInputs:
+    """The image related inputs."""
+
     pixel_values: torch.Tensor
     image_hash: int
     image_sizes: Optional[list] = None

@@ -137,7 +149,7 @@ class ImageInputs:
 class Req:
-    """Store all inforamtion of a request."""
+    """The input and output status of a request."""

     def __init__(
         self,

@@ -393,20 +405,20 @@ class ScheduleBatch:
     sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
+    input_ids: List[int] = None
+    req_pool_indices: List[int] = None
+    seq_lens: List[int] = None
     out_cache_loc: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # For mixed chunekd prefill
-    prefix_lens_cpu: List[int] = None
-    running_bs: int = None

     # For processing logprobs
     return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
+    top_logprobs_nums: Optional[List[int]] = None
+
+    # For extend and mixed chunekd prefill
+    prefix_lens: List[int] = None
+    extend_lens: List[int] = None
+    extend_num_tokens: int = None
+    running_bs: int = None

     # Stream
     has_stream: bool = False

@@ -466,12 +478,12 @@ class ScheduleBatch:
         seq_lens = []

         # Allocate memory
-        req_pool_indices_cpu = self.alloc_req_slots(bs)
+        req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         pt = 0
         for i, req in enumerate(reqs):
-            req.req_pool_idx = req_pool_indices_cpu[i]
+            req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len

@@ -497,22 +509,19 @@ class ScheduleBatch:
             pt += req.extend_input_len

         # Set fields
-        with torch.device("cuda"):
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
-            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+        self.input_ids = sum(input_ids, [])
+        self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda")
+        self.seq_lens = torch.tensor(seq_lens, device="cuda")

         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
-        if self.return_logprob:
-            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
-        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
-        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

-    def get_forward_batch(self):
-        return ForwardBatch.from_schedule_batch(self)
-
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED

@@ -522,24 +531,24 @@ class ScheduleBatch:
             req.fill_ids = req.origin_input_ids + req.output_ids
             req.extend_input_len = 1

-        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        input_ids = self.input_ids + running_batch.input_ids
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
         extend_num_tokens = self.extend_num_tokens + running_bs

-        self.merge(running_batch)
+        self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-        self.prefix_lens_cpu.extend(
+        self.prefix_lens.extend(
             [
                 len(r.origin_input_ids) + len(r.output_ids) - 1
                 for r in running_batch.reqs
             ]
         )
-        self.extend_lens_cpu.extend([1] * running_bs)
-        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
+        self.extend_lens.extend([1] * running_bs)
+        self.extend_logprob_start_lens.extend([0] * running_bs)

     def check_decode_mem(self):
         bs = len(self.reqs)

@@ -631,7 +640,7 @@ class ScheduleBatch:
         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self, model_runner):
+    def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]

@@ -688,7 +697,7 @@ class ScheduleBatch:
                     # re-applying image padding
                     if req.image_inputs is not None:
-                        req.origin_input_ids = model_runner.model.pad_input_ids(
+                        req.origin_input_ids = pad_input_ids_func(
                             req.origin_input_ids_unpadded, req.image_inputs
                         )

@@ -708,7 +717,7 @@ class ScheduleBatch:
             for r in self.reqs
         ]
-        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
+        self.input_ids = input_ids
         self.seq_lens.add_(1)

         # Alloc mem

@@ -731,32 +740,97 @@ class ScheduleBatch:
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-        self.seq_lens = self.seq_lens[new_indices]
+        self.input_ids = None
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.position_ids_offsets = self.position_ids_offsets[new_indices]
+        self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob:
+            self.top_logprobs_nums = [
+                self.top_logprobs_nums[i] for i in unfinished_indices
+            ]
         self.has_stream = any(req.stream for req in self.reqs)

-        self.sampling_info.filter(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(unfinished_indices, new_indices)

-    def merge(self, other: "ScheduleBatch"):
+    def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.sampling_info.merge(other.sampling_info)
+        self.sampling_info.merge_batch(other.sampling_info)

         self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
         self.out_cache_loc = None
-        self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob and other.return_logprob:
+            self.top_logprobs_nums.extend(other.top_logprobs_nums)
+        elif self.return_logprob:
+            self.top_logprobs_nums.extend([0] * len(other.reqs))
+        elif other.return_logprob:
+            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.has_stream = any(req.stream for req in self.reqs)

+    def get_model_worker_batch(self):
+        if self.forward_mode.is_decode():
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
+                image_inputs
+            ) = None
+        else:
+            extend_seq_lens = self.extend_lens
+            extend_prefix_lens = self.prefix_lens
+            extend_logprob_start_lens = self.extend_logprob_start_lens
+            image_inputs = [r.image_inputs for r in self.reqs]
+
+        lora_paths = [req.lora_path for req in self.reqs]
+        self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
+
+        return ModelWorkerBatch(
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids,
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=extend_seq_lens,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_logprob_start_lens=extend_logprob_start_lens,
+            image_inputs=image_inputs,
+            lora_paths=lora_paths,
+            sampling_info=self.sampling_info,
+        )
+
+
+@dataclass
+class ModelWorkerBatch:
+    # The forward mode
+    forward_mode: ForwardMode
+    # The input ids
+    input_ids: List[int]
+    # The indices of requests in the req_to_token_pool
+    req_pool_indices: torch.Tensor
+    # The sequence length
+    seq_lens: torch.Tensor
+    # The indices of output tokens in the token_to_kv_pool
+    out_cache_loc: torch.Tensor
+
+    # For logprob
+    return_logprob: bool
+    top_logprobs_nums: Optional[List[int]]
+
+    # For extend
+    extend_seq_lens: Optional[List[int]]
+    extend_prefix_lens: Optional[List[int]]
+    extend_logprob_start_lens: Optional[List[int]]
+
+    # For multimodal
+    image_inputs: Optional[List[ImageInputs]]
+
+    # For LoRA
+    lora_paths: Optional[List[str]]
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo
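ModelWorkerBatch is the new, deliberately small hand-off object between the scheduler and the TP worker. A simplified, self-contained stand-in (toy classes, not the real sglang ones; field names follow the diff) showing the ScheduleBatch -> ModelWorkerBatch -> ForwardBatch hand-off the new docstring describes:

    # Toy sketch of the three-stage flow; the real classes live in schedule_batch.py,
    # tp_worker.py and forward_batch_info.py. Runs on CPU for illustration.
    from dataclasses import dataclass
    from typing import List, Optional

    import torch


    @dataclass
    class ToyModelWorkerBatch:
        # Per-step inputs the scheduler hands to the worker (mostly CPU data).
        input_ids: List[int]
        seq_lens: torch.Tensor
        out_cache_loc: torch.Tensor
        extend_seq_lens: Optional[List[int]] = None
        extend_prefix_lens: Optional[List[int]] = None


    @dataclass
    class ToyForwardBatch:
        # Tensor data derived from the worker batch right before the model call.
        input_ids: torch.Tensor
        seq_lens: torch.Tensor
        positions: torch.Tensor

        @classmethod
        def init_new(cls, batch: ToyModelWorkerBatch, device: str = "cpu"):
            input_ids = torch.tensor(batch.input_ids, dtype=torch.int32, device=device)
            # Decode case: one new token per request, position = seq_len - 1.
            positions = (batch.seq_lens - 1).to(torch.int64)
            return cls(input_ids=input_ids, seq_lens=batch.seq_lens, positions=positions)


    mwb = ToyModelWorkerBatch(
        input_ids=[7, 9],
        seq_lens=torch.tensor([5, 3], dtype=torch.int32),
        out_cache_loc=torch.tensor([12, 13]),
    )
    fb = ToyForwardBatch.init_new(mwb)
    print(fb.positions.tolist())  # [4, 2]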
python/sglang/srt/managers/scheduler.py (+12, -11)

@@ -141,6 +141,9 @@ class Scheduler:
             nccl_port=port_args.nccl_ports[0],
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )

         # Get token and memory info from the tp worker
         (

@@ -292,7 +295,7 @@ class Scheduler:
                 if self.running_batch is None:
                     self.running_batch = new_batch
                 else:
-                    self.running_batch.merge(new_batch)
+                    self.running_batch.merge_batch(new_batch)
         else:
             # Run a decode batch
             if self.running_batch is not None:

@@ -370,7 +373,7 @@ class Scheduler:
                 req.image_inputs = ImageInputs.from_dict(
                     recv_req.image_inputs, self.model_config.vocab_size
                 )
-                req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
+                req.origin_input_ids = self.pad_input_ids_func(
                     req.origin_input_ids_unpadded, req.image_inputs
                 )

@@ -575,9 +578,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                forward_batch = batch.get_forward_batch()
+                model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    forward_batch, batch
+                    model_worker_batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids

@@ -641,8 +644,8 @@ class Scheduler:
             )
         else:
             assert batch.extend_num_tokens != 0
-            forward_batch = batch.get_forward_batch()
-            embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
+            model_worker_batch = batch.get_model_worker_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)

             # Check finish conditions
             for i, req in enumerate(batch.reqs):

@@ -759,9 +762,7 @@ class Scheduler:
         # Check for jump-forward
         if not self.disable_regex_jump_forward:
-            jump_forward_reqs = batch.check_for_jump_forward(
-                self.tp_worker.model_runner
-            )
+            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return

@@ -771,9 +772,9 @@ class Scheduler:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        forward_batch = batch.get_forward_batch()
+        model_worker_batch = batch.get_model_worker_batch()
         logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            forward_batch, batch
+            model_worker_batch
         )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
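The scheduler now resolves the model's optional pad_input_ids hook once instead of reaching through tp_worker.model_runner.model at every call site, and ScheduleBatch.check_for_jump_forward receives the same callable. A tiny standalone illustration of that getattr pattern (toy classes, not sglang code):

    class TextOnlyModel:
        pass


    class VisionModel:
        def pad_input_ids(self, input_ids, image_inputs):
            # Toy padding: prepend one placeholder id per image.
            return [-1] * len(image_inputs) + input_ids


    def resolve_pad_func(model):
        # Same pattern as the diff: a missing hook resolves to None.
        return getattr(model, "pad_input_ids", None)


    pad = resolve_pad_func(VisionModel())
    print(pad([5, 6], ["img0"]))               # [-1, 5, 6]
    print(resolve_pad_func(TextOnlyModel()))   # None -> callers skip padding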
python/sglang/srt/managers/tp_worker.py (+6, -3)

@@ -21,6 +21,7 @@ import logging
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs

@@ -108,12 +109,14 @@ class TpModelWorker:
             self.random_seed,
         )

-    def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
+    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        next_token_ids = self.model_runner.sample(logits_output, batch)
+        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
         return logits_output, next_token_ids

-    def forward_batch_embedding(self, forward_batch: ForwardBatch):
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings.tolist()
         return embeddings
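With this change the worker, not the scheduler, owns ForwardBatch construction. A hedged sketch of one generation step as seen from the scheduler side after this commit (assumes an sglang checkout at this commit; scheduler_batch is a prepared ScheduleBatch and tp_worker a TpModelWorker):

    def generation_step(scheduler_batch, tp_worker):
        # ScheduleBatch -> ModelWorkerBatch: CPU-side, scheduler-owned hand-off object.
        model_worker_batch = scheduler_batch.get_model_worker_batch()
        # The worker builds the ForwardBatch internally via ForwardBatch.init_new(...)
        # and runs forward + sampling on its ModelRunner.
        logits_output, next_token_ids = tp_worker.forward_batch_generation(model_worker_batch)
        return logits_output, next_token_ids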
python/sglang/srt/model_executor/forward_batch_info.py (+64, -27)

@@ -15,18 +15,33 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Meta data for a forward pass."""
+"""
+Store information about a forward batch.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import torch

 if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):

@@ -69,25 +84,28 @@ class ForwardBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # For logprob
+    return_logprob: bool = False
+    top_logprobs_nums: Optional[List[int]] = None
+
     # Position information
     positions: torch.Tensor = None

     # For extend
-    extend_seq_lens: torch.Tensor = None
-    extend_prefix_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-
-    # For logprob
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-    extend_seq_lens_cpu: List[int] = None
-    extend_logprob_start_lens_cpu: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_prefix_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None

     # For multimodal
-    image_inputs: List[ImageInputs] = None
+    image_inputs: Optional[List[ImageInputs]] = None

     # For LoRA
-    lora_paths: List[str] = None
+    lora_paths: Optional[List[str]] = None
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo = None

     # Attention backend
     req_to_token_pool: ReqToTokenPool = None

@@ -95,42 +113,61 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None

     @classmethod
-    def from_schedule_batch(
+    def init_new(
         cls,
-        batch: ScheduleBatch,
+        batch: ModelWorkerBatch,
+        model_runner: ModelRunner,
     ):
+        device = "cuda"
+
         ret = cls(
             forward_mode=batch.forward_mode,
-            batch_size=batch.batch_size(),
-            input_ids=batch.input_ids,
+            batch_size=len(batch.seq_lens),
+            input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
-            lora_paths=[req.lora_path for req in batch.reqs],
+            lora_paths=batch.lora_paths,
+            sampling_info=batch.sampling_info,
         )

+        # Init position information
         if ret.forward_mode.is_decode():
             ret.positions = (ret.seq_lens - 1).to(torch.int64)
         else:
             ret.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                        for i, req in enumerate(batch.reqs)
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
                     ],
                     axis=0,
                 ),
-                device="cuda",
+                device=device,
             ).to(torch.int64)
-            ret.image_inputs = [r.image_inputs for r in batch.reqs]
-            ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-            ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            ret.image_inputs = batch.image_inputs
+            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_prefix_lens = torch.tensor(batch.extend_prefix_lens, device=device)
             ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
             ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
-            ret.extend_seq_lens_cpu = batch.extend_lens_cpu
-            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+            ret.extend_seq_lens_cpu = batch.extend_seq_lens
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
+
+        # Init attention information
+        ret.req_to_token_pool = model_runner.req_to_token_pool
+        ret.token_to_kv_pool = model_runner.token_to_kv_pool
+        ret.attn_backend = model_runner.attn_backend
+        model_runner.attn_backend.init_forward_metadata(ret)
+
+        # Init lora information
+        if model_runner.server_args.lora_paths is not None:
+            model_runner.lora_manager.prepare_lora_batch(ret)

         return ret
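ForwardBatch.init_new computes extend-mode positions by concatenating one arange per request, starting at that request's cached prefix length, and derives extend_start_loc from a cumulative sum of the extend lengths. A standalone numeric illustration of both computations (toy values, CPU only):

    import numpy as np
    import torch

    extend_prefix_lens = [3, 0]   # tokens already cached per request
    extend_seq_lens = [2, 4]      # new tokens to extend per request

    positions = torch.tensor(
        np.concatenate(
            [
                np.arange(prefix_len, prefix_len + extend_len)
                for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
            ],
            axis=0,
        )
    ).to(torch.int64)
    print(positions.tolist())  # [3, 4, 0, 1, 2, 3]

    # extend_start_loc: offset of each request's new tokens in the flattened input.
    extend_seq_lens_t = torch.tensor(extend_seq_lens)
    extend_start_loc = torch.zeros_like(extend_seq_lens_t)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens_t[:-1], dim=0)
    print(extend_start_loc.tolist())  # [0, 2]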
python/sglang/srt/model_executor/model_runner.py (+25, -30)

@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type

 import torch
 import torch.nn as nn

@@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constrained import disable_cache
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,

@@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    enable_show_time_cost,
     get_available_gpu_memory,
     is_generation_model,
     is_multimodal_model,

@@ -102,6 +104,12 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95

+        # Global vars
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+        if server_args.disable_disk_cache:
+            disable_cache()
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,

@@ -491,16 +499,6 @@ class ModelRunner:
         )

     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
-        # Attach attention information
-        forward_batch.req_to_token_pool = self.req_to_token_pool
-        forward_batch.token_to_kv_pool = self.token_to_kv_pool
-        forward_batch.attn_backend = self.attn_backend
-        forward_batch.attn_backend.init_forward_metadata(forward_batch)
-
-        # Attach lora information
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(forward_batch)
-
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():

@@ -508,16 +506,27 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")

-    def _apply_logits_bias(
-        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
-    ):
+    def sample(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+        sampling_info = forward_batch.sampling_info
+        sampling_info.update_regex_vocab_mask()
+        sampling_info.update_penalties()
+        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, sampling_info)
+        return next_token_ids
+
+    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Apply logit_bias
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)

         # min-token, presence, frequency
         if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
+            logits.add_(sampling_info.linear_penalties)

         # repetition
         if sampling_info.scaling_penalties is not None:

@@ -533,20 +542,6 @@ class ModelRunner:
         return logits

-    def sample(
-        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
-    ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-        batch.sampling_info.update_regex_vocab_mask(batch)
-        batch.sampling_info.update_penalties()
-        logits = self._apply_logits_bias(
-            logits_output.next_token_logits, batch.sampling_info
-        )
-
-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, batch.sampling_info)
-        return next_token_ids
-

 @lru_cache()
 def import_model_classes():
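sample() now sits next to forward() and pulls everything it needs from ForwardBatch.sampling_info: the regex vocab mask and penalty tensors are prepared on the CPU (overlapping the GPU forward pass), then apply_logits_bias adjusts the logits before the sampler runs. A self-contained toy version of the bias/penalty step; the in-place add_ calls mirror the diff, while the vocab-mask line is not visible in this hunk and is included only as an assumption about how the mask is consumed:

    import torch

    def apply_logits_bias(logits, logit_bias=None, linear_penalties=None, vocab_mask=None):
        if logit_bias is not None:
            logits.add_(logit_bias)          # per-token additive bias
        if linear_penalties is not None:
            logits.add_(linear_penalties)    # min-token / presence / frequency penalties
        if vocab_mask is not None:
            # Assumed consumption of the regex mask: True = token disallowed.
            logits.masked_fill_(vocab_mask, float("-inf"))
        return logits

    logits = torch.zeros(2, 5)
    mask = torch.tensor([[True, False, False, True, True],
                         [False, False, False, False, False]])
    print(apply_logits_bias(logits, vocab_mask=mask))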
python/sglang/srt/sampling/sampling_batch_info.py (+25, -16)

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
 import torch

 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.constrained import RegexGuide

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

@@ -22,13 +23,17 @@ class SamplingBatchInfo:
     top_ks: torch.Tensor = None
     min_ps: torch.Tensor = None

-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Bias Tensors
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None

+    # FSM states
+    regex_fsms: List[RegexGuide] = None
+    regex_fsm_states: List[int] = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None

@@ -54,6 +59,8 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         )

+        ret.regex_fsms = [r.regex_fsm for r in reqs]
+        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
         ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at

@@ -102,24 +109,22 @@ class SamplingBatchInfo:
             )
             self.linear_penalties = penalizer.apply(self.linear_penalties)

-    def update_regex_vocab_mask(self, batch: ScheduleBatch):
-        has_regex = any(req.regex_fsm is not None for req in batch.reqs)
-
+    def update_regex_vocab_mask(self):
         # Reset the vocab mask
         self.vocab_mask = None

-        if has_regex:
+        if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
             self.vocab_mask = torch.zeros(
-                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(batch.reqs):
-                if req.regex_fsm is not None:
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0

-    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

         for item in [

@@ -129,9 +134,11 @@ class SamplingBatchInfo:
             "min_ps",
             "logit_bias",
         ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+            value = getattr(self, item, None)
+            if value is not None:  # logit_bias can be None
+                setattr(self, item, value[new_indices])
+
+        self.regex_fsms = [self.regex_fsms[i] for i in new_indices]

     @staticmethod
     def merge_bias_tensor(

@@ -153,7 +160,7 @@ class SamplingBatchInfo:
         return None

-    def merge(self, other: "SamplingBatchInfo"):
+    def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

         for item in [

@@ -169,3 +176,5 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other)
         )
+
+        self.regex_fsms.extend(other.regex_fsms)
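SamplingBatchInfo now carries the regex FSMs and their per-request states itself (regex_fsms, regex_fsm_states), so update_regex_vocab_mask no longer needs a ScheduleBatch argument. A self-contained toy of the mask construction, with the FSM faked as a dict from state to allowed token ids (not sglang code):

    import torch

    vocab_size = 6
    fake_fsms = [{0: [1, 2]}, None]   # request 0 constrained, request 1 unconstrained
    fsm_states = [0, 0]

    vocab_mask = None
    if any(fsm is not None for fsm in fake_fsms):
        vocab_mask = torch.zeros(len(fake_fsms), vocab_size, dtype=torch.bool)
        for i, fsm in enumerate(fake_fsms):
            if fsm is not None:
                vocab_mask[i].fill_(1)                          # start fully masked
                vocab_mask[i][fake_fsms[i][fsm_states[i]]] = 0  # unmask allowed tokens

    print(vocab_mask.int().tolist())
    # [[1, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 0]]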
python/sglang/srt/server.py (+0, -11)

@@ -41,7 +41,6 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse

 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (

@@ -72,8 +71,6 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     configure_logger,
-    enable_show_time_cost,
-    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,

@@ -400,14 +397,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()

-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
     # Fix triton bugs
     if server_args.tp_size * server_args.dp_size > 1:
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.