Commit ffd20fcd (unverified), sglang
Authored Nov 19, 2024 by Lianmin Zheng; committed via GitHub on Nov 19, 2024.

Make constrained decoding work for overlap scheduler (#2095)

Parent: 55bd97f3

Showing 8 changed files with 119 additions and 95 deletions.
Changed files:

  python/sglang/srt/layers/sampler.py                     +0   -1
  python/sglang/srt/managers/schedule_batch.py            +12  -20
  python/sglang/srt/managers/scheduler.py                 +35  -24
  python/sglang/srt/managers/tp_worker_overlap_thread.py  +8   -11
  python/sglang/srt/model_executor/forward_batch_info.py  +10  -3
  python/sglang/srt/model_executor/model_runner.py        +11  -4
  python/sglang/srt/sampling/sampling_batch_info.py       +42  -20
  python/sglang/srt/server_args.py                        +1   -12
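Before the per-file diffs, here is a minimal, self-contained sketch of the synchronization idea this commit introduces (not the sglang code itself): the scheduler prepares the grammar-derived vocab mask for a batch on the CPU while the worker thread runs the forward pass, and a threading.Event (the role played by sampling_info_done in the diff) tells the sampler when the mask is ready. The names SamplingInfo, done, scheduler, and worker below are illustrative.

# Minimal sketch of the CPU/GPU overlap handshake; illustrative names only.
import threading
import queue
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class SamplingInfo:
    vocab_size: int
    vocab_mask: Optional[torch.Tensor] = None
    done: threading.Event = field(default_factory=threading.Event)  # analogous to sampling_info_done


def scheduler(work_q: "queue.Queue[Optional[SamplingInfo]]", num_batches: int):
    for step in range(num_batches):
        info = SamplingInfo(vocab_size=16)
        work_q.put(info)                 # hand the batch to the worker first
        mask = torch.zeros(info.vocab_size, dtype=torch.bool)
        mask[: 4 + step] = True          # pretend this is the CPU-heavy grammar mask
        info.vocab_mask = mask
        info.done.set()                  # signal: the mask for this batch is ready
    work_q.put(None)                     # shutdown sentinel


def worker(work_q: "queue.Queue[Optional[SamplingInfo]]"):
    while True:
        info = work_q.get()
        if info is None:
            break
        # ... the model forward would run here, overlapping with the scheduler's CPU work ...
        info.done.wait()                 # block only right before sampling
        logits = torch.randn(info.vocab_size)
        logits[~info.vocab_mask] = float("-inf")
        print("sampled token:", int(torch.argmax(logits)))


if __name__ == "__main__":
    q: "queue.Queue[Optional[SamplingInfo]]" = queue.Queue()
    t = threading.Thread(target=worker, args=(q,))
    t.start()
    scheduler(q, num_batches=3)
    t.join()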
python/sglang/srt/layers/sampler.py  (view @ ffd20fcd; +0, -1)

  import logging
- import os
  from typing import Union

  import torch
  ...
python/sglang/srt/managers/schedule_batch.py  (view @ ffd20fcd; +12, -20)

@@ -136,6 +136,7 @@ class ImageInputs:
      image_embeds: Optional[List[torch.Tensor]] = None
      aspect_ratio_ids: Optional[List[torch.Tensor]] = None
      aspect_ratio_mask: Optional[List[torch.Tensor]] = None

      # QWen2-VL related
      image_grid_thws: List[Tuple[int, int, int]] = None
      mrope_position_delta: Optional[torch.Tensor] = None

@@ -187,11 +188,10 @@ class Req:
          self.origin_input_ids = origin_input_ids
          self.output_ids = []  # Each decode stage's output ids
          self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
          self.sampling_params = sampling_params
          self.lora_path = lora_path

-         # Memory info
+         # Memory pool info
          self.req_pool_idx = None

          # Check finish

@@ -428,7 +428,7 @@ bid = 0
  @dataclasses.dataclass
  class ScheduleBatch:
-     """Store all inforamtion of a batch."""
+     """Store all inforamtion of a batch on the scheduler."""

      # Request, memory pool, and cache
      reqs: List[Req]

@@ -438,9 +438,9 @@ class ScheduleBatch:
      # For utility
      model_config: ModelConfig = None
      forward_mode: ForwardMode = None
      sampling_info: SamplingBatchInfo = None
+     next_batch_sampling_info: SamplingBatchInfo = None

      # Batched arguments to model runner
      input_ids: torch.Tensor = None

@@ -509,7 +509,7 @@ class ScheduleBatch:
      def is_empty(self):
          return len(self.reqs) == 0

-     def alloc_req_slots(self, num_reqs):
+     def alloc_req_slots(self, num_reqs: int):
          req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
          if req_pool_indices is None:
              raise RuntimeError(

@@ -610,7 +610,7 @@ class ScheduleBatch:
          assert len(self.out_cache_loc) == self.extend_num_tokens

-     def prepare_for_extend(self):
+     def prepare_for_extend(self, enable_overlap_schedule: bool = False):
          self.forward_mode = ForwardMode.EXTEND
          bs = len(self.reqs)

@@ -704,7 +704,7 @@ class ScheduleBatch:
          self.sampling_info = SamplingBatchInfo.from_schedule_batch(
              self,
              self.model_config.vocab_size,
-             global_server_args_dict["disable_penalizer"],
+             enable_overlap_schedule=enable_overlap_schedule,
          )

      def mix_with_running(self, running_batch: "ScheduleBatch"):

@@ -746,6 +746,7 @@ class ScheduleBatch:
          return False

      def retract_decode(self):
+         """Retract the decoding requests when there is not enough memory."""
          sorted_indices = [i for i in range(len(self.reqs))]

          # TODO(lsyin): improve retraction policy for radix cache

@@ -886,18 +887,10 @@ class ScheduleBatch:
      def prepare_for_idle(self):
          self.forward_mode = ForwardMode.IDLE
-         self.input_ids = torch.empty(0, dtype=torch.int32).to(
-             self.device, non_blocking=True
-         )
-         self.seq_lens = torch.empty(0, dtype=torch.int32).to(
-             self.device, non_blocking=True
-         )
-         self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
-             self.device, non_blocking=True
-         )
-         self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
-             self.device, non_blocking=True
-         )
+         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+         self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
          self.seq_lens_sum = 0
          self.extend_num_tokens = 0

@@ -1063,7 +1056,6 @@ class ScheduleBatch:
              out_cache_loc=self.out_cache_loc,
              return_logprob=self.return_logprob,
              decoding_reqs=self.decoding_reqs,
-             sampling_info=self.sampling_info,
          )

      def __str__(self):
  ...
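The prepare_for_idle hunk above swaps torch.empty(...).to(self.device, non_blocking=True) for torch.empty(..., device=self.device): the placeholder tensors are created directly on the target device instead of being created on the CPU and then copied. A small sketch of the two idioms, assuming only that a device string such as "cuda" or "cpu" is available:

# Sketch of the allocation idiom change in prepare_for_idle (illustrative only).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: allocate on CPU, then issue an extra (async) copy to the device.
input_ids_old = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)

# After: allocate directly on the target device; no separate transfer op.
input_ids_new = torch.empty(0, dtype=torch.int32, device=device)

assert input_ids_old.device == input_ids_new.device
assert input_ids_old.numel() == input_ids_new.numel() == 0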
python/sglang/srt/managers/scheduler.py  (view @ ffd20fcd; +35, -24)

@@ -15,6 +15,7 @@ limitations under the License.
  """A scheduler that manages a tensor parallel GPU worker."""

+ import dataclasses
  import logging
  import os
  import threading

@@ -63,6 +64,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
      broadcast_pyobj,

@@ -220,8 +222,12 @@ class Scheduler:
          # Init running status
          self.waiting_queue: List[Req] = []
+         # The running decoding batch for continuous batching
          self.running_batch: Optional[ScheduleBatch] = None
+         # The current forward batch
          self.cur_batch: Optional[ScheduleBatch] = None
+         # The current forward batch
+         self.last_batch: Optional[ScheduleBatch] = None
          self.forward_ct = 0
          self.forward_ct_decode = 0
          self.num_generated_tokens = 0

@@ -336,15 +342,12 @@ class Scheduler:
      @torch.no_grad()
      def event_loop_normal(self):
-         """A normal blocking scheduler loop."""
-         self.last_batch = None
+         """A normal scheduler loop."""

          while True:
              recv_reqs = self.recv_requests()
              self.process_input_requests(recv_reqs)

              batch = self.get_next_batch_to_run()

              if self.server_args.enable_dp_attention:
                  batch = self.prepare_dp_attn_batch(batch)

@@ -353,20 +356,8 @@ class Scheduler:
              if batch:
                  result = self.run_batch(batch)
                  self.process_batch_result(batch, result)
-
-                 # Decode multiple steps to reduce the overhead
-                 if batch.forward_mode.is_decode():
-                     for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                         if not self.running_batch:
-                             break
-                         self.update_running_batch()
-                         if not self.running_batch:
-                             break
-                         if self.server_args.enable_dp_attention:
-                             batch = self.prepare_dp_attn_batch(batch)
-                         result = self.run_batch(batch)
-                         self.process_batch_result(batch, result)
              else:
+                 # Self-check and re-init some states when the server is idle
                  self.check_memory()
                  self.new_token_ratio = self.init_new_token_ratio

@@ -377,9 +368,6 @@ class Scheduler:
          """A scheduler loop that overlaps the CPU processing and GPU computation."""
          result_queue = deque()
-         self.last_batch = None
-         self.running_batch = None

          while True:
              recv_reqs = self.recv_requests()
              self.process_input_requests(recv_reqs)

@@ -390,10 +378,24 @@ class Scheduler:
                  result = self.run_batch(batch)
                  result_queue.append((batch.copy(), result))

+                 if self.last_batch is None:
+                     # A dummy first batch to start the pipeline for overlap scheduler.
+                     # It is now used for triggering the sampling_info_done event.
+                     tmp_batch = ScheduleBatch(
+                         reqs=None,
+                         forward_mode=ForwardMode.DUMMY_FIRST,
+                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                     )
+                     self.process_batch_result(tmp_batch, None)
+
              if self.last_batch:
                  tmp_batch, tmp_result = result_queue.popleft()
+                 tmp_batch.next_batch_sampling_info = (
+                     self.tp_worker.cur_sampling_info if batch else None
+                 )
                  self.process_batch_result(tmp_batch, tmp_result)
              elif batch is None:
+                 # Self-check and re-init some states when the server is idle
                  self.check_memory()
                  self.new_token_ratio = self.init_new_token_ratio
  ...
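The overlap loop above launches batch i and only then processes the result of batch i-1 that was parked in result_queue, so the CPU-side bookkeeping (including the new next_batch_sampling_info wiring) runs while the GPU is busy. A toy sketch of that one-step-delayed structure; the function names mirror the diff but this is not the sglang API:

# Toy sketch of the one-step-delayed ("overlap") loop structure: launch batch i,
# then process the result of batch i-1. Names are illustrative, not sglang APIs.
from collections import deque


def run_batch(batch):
    # Stand-in for the asynchronous GPU launch; returns a "future-like" result.
    return {"tokens": [len(batch)] * len(batch)}


def process_batch_result(batch, result):
    print(f"processed {batch} -> {result['tokens']}")


def event_loop_overlap(batches):
    result_queue = deque()
    last_batch = None
    for batch in batches + [None]:               # final None flushes the pipeline
        if batch is not None:
            result = run_batch(batch)             # launch the current batch
            result_queue.append((batch, result))
        if last_batch is not None:
            tmp_batch, tmp_result = result_queue.popleft()
            process_batch_result(tmp_batch, tmp_result)  # process the previous batch
        last_batch = batch


event_loop_overlap([["req0"], ["req1", "req2"], ["req3"]])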
@@ -806,7 +808,7 @@ class Scheduler:
              self.tree_cache,
              self.model_config,
          )
-         new_batch.prepare_for_extend()
+         new_batch.prepare_for_extend(self.enable_overlap)

          # Mixed-style chunked prefill
          if self.is_mixed_chunk and self.running_batch is not None:

@@ -893,14 +895,15 @@ class Scheduler:
          return ret

      def process_batch_result(self, batch: ScheduleBatch, result):
-         if batch.forward_mode.is_idle():
-             return
          if batch.forward_mode.is_decode():
              self.process_batch_result_decode(batch, result)
              if batch.is_empty():
                  self.running_batch = None
-         else:
+         elif batch.forward_mode.is_extend():
              self.process_batch_result_prefill(batch, result)
+         elif batch.forward_mode.is_dummy_first():
+             batch.next_batch_sampling_info.update_regex_vocab_mask()
+             batch.next_batch_sampling_info.sampling_info_done.set()

      def process_batch_result_prefill(self, batch: ScheduleBatch, result):

@@ -953,6 +956,10 @@ class Scheduler:
                  else:
                      req.is_being_chunked -= 1

+             if batch.next_batch_sampling_info:
+                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                 batch.next_batch_sampling_info.sampling_info_done.set()
+
          else:  # embedding or reward model
              embeddings, bid = result
              embeddings = embeddings.tolist()

@@ -1022,6 +1029,10 @@ class Scheduler:
                  if req.top_logprobs_num > 0:
                      req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

+         if batch.next_batch_sampling_info:
+             batch.next_batch_sampling_info.update_regex_vocab_mask()
+             batch.next_batch_sampling_info.sampling_info_done.set()
+
          self.stream_output(batch.reqs)

          self.token_to_kv_pool.free_group_end()
  ...
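In the new flow, the vocab mask for batch i+1 is built and its sampling_info_done event is set while processing the result of batch i; that leaves nothing to unblock the very first batch, which is why the commit injects a DUMMY_FIRST batch whose only job is priming the pipeline. A small standalone sketch of that reasoning, with illustrative names (Batch, mask_ready, process_result are not sglang classes):

# Sketch of why the first batch needs a "dummy" predecessor: each batch's
# vocab-mask event is set while processing the *previous* batch's result, so
# batch 0 would stall without a priming step. Illustrative names only.
import threading


class Batch:
    def __init__(self, name):
        self.name = name
        self.mask_ready = threading.Event()   # analogous to sampling_info_done


def process_result(prev_batch, next_batch):
    # While the GPU would be busy with next_batch, the CPU-heavy mask work for
    # next_batch happens here; then its sampler is unblocked.
    if next_batch is not None:
        next_batch.mask_ready.set()


def sample(batch):
    batch.mask_ready.wait(timeout=1.0)
    return batch.mask_ready.is_set()


batches = [Batch("b0"), Batch("b1")]

# Without priming, nothing ever sets b0.mask_ready, so sample(b0) would time out.
# The DUMMY_FIRST batch plays the role of the "previous batch" for b0:
process_result(prev_batch=None, next_batch=batches[0])        # dummy-first step
process_result(prev_batch=batches[0], next_batch=batches[1])  # normal pipelining

print([sample(b) for b in batches])   # [True, True]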
python/sglang/srt/managers/tp_worker_overlap_thread.py  (view @ ffd20fcd; +8, -11)

@@ -18,7 +18,6 @@ limitations under the License.
  import dataclasses
  import logging
  import threading
- import time
  from queue import Queue
  from typing import Optional

@@ -96,9 +95,7 @@ class TpModelWorkerClient:
      @torch.no_grad()
      def forward_thread_func_(self):
          while True:
-             model_worker_batch, future_token_ids_ct, compute_info_done = (
-                 self.input_queue.get()
-             )
+             model_worker_batch, future_token_ids_ct = self.input_queue.get()
              if not model_worker_batch:
                  break
              self.launch_done = threading.Event()

@@ -109,7 +106,6 @@ class TpModelWorkerClient:
              resolve_future_token_ids(input_ids, self.future_token_ids_map)

              # Run forward
-             compute_info_done.wait()
              logits_output, next_token_ids = self.worker.forward_batch_generation(
                  model_worker_batch, self.launch_done
              )

@@ -160,15 +156,16 @@ class TpModelWorkerClient:
          return logits_output, next_token_ids

      def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+         # A cuda stream sync here to avoid the cuda illegal memory access error.
+         _ = model_worker_batch.seq_lens[0].item()
+
          # Push a new batch to the queue
          model_worker_batch.sampling_info = dataclasses.replace(
-             model_worker_batch.sampling_info
+             model_worker_batch.sampling_info,
+             sampling_info_done=threading.Event(),
          )
-         compute_info_done = torch.cuda.Event()
-         compute_info_done.record()
-         self.input_queue.put(
-             (model_worker_batch, self.future_token_ids_ct, compute_info_done)
-         )
+         self.cur_sampling_info = model_worker_batch.sampling_info
+         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

          # Allocate output future objects
          bs = len(model_worker_batch.seq_lens)
  ...
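The interesting pattern here is dataclasses.replace: the worker makes a shallow copy of the per-batch sampling info and attaches a fresh threading.Event to the copy before queuing it, so the scheduler's object and the worker thread's object do not alias. A minimal standalone sketch of that pattern; the SamplingInfo dataclass below is a stand-in, not sglang's:

# Minimal sketch of the dataclasses.replace(...) pattern: copy the per-batch
# sampling info and give the copy its own completion Event before handing it
# to the worker thread. Types are illustrative, not sglang's.
import dataclasses
import queue
import threading
from typing import Optional


@dataclasses.dataclass
class SamplingInfo:
    temperature: float
    sampling_info_done: Optional[threading.Event] = None


input_queue: "queue.Queue[SamplingInfo]" = queue.Queue()

original = SamplingInfo(temperature=0.7)

# Shallow copy with a fresh Event; the original object is left untouched.
per_batch = dataclasses.replace(original, sampling_info_done=threading.Event())
input_queue.put(per_batch)

assert original.sampling_info_done is None
assert per_batch.sampling_info_done is not None and per_batch.temperature == 0.7

# Later, whoever finishes the CPU-side preparation signals the worker:
per_batch.sampling_info_done.set()
print(input_queue.get().sampling_info_done.is_set())   # True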
python/sglang/srt/model_executor/forward_batch_info.py  (view @ ffd20fcd; +10, -3)

@@ -52,15 +52,19 @@ if TYPE_CHECKING:
  class ForwardMode(IntEnum):
      # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
      PREFILL = auto()
-     # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
      EXTEND = auto()
      # Decode one token.
      DECODE = auto()
-     # Contains both EXTEND and DECODE.
+     # Contains both EXTEND and DECODE when doing chunked prefill.
      MIXED = auto()
-     # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence allocated.
+     # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
      IDLE = auto()
+     # A dummy first batch to start the pipeline for overlap scheduler.
+     # It is now used for triggering the sampling_info_done event for the first prefill batch.
+     DUMMY_FIRST = auto()

      def is_prefill(self):
          return self == ForwardMode.PREFILL

@@ -76,6 +80,9 @@ class ForwardMode(IntEnum):
      def is_idle(self):
          return self == ForwardMode.IDLE

+     def is_dummy_first(self):
+         return self == ForwardMode.DUMMY_FIRST
+
  @dataclass
  class ForwardBatch:
  ...
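The new DUMMY_FIRST member follows the existing ForwardMode pattern: an IntEnum value created with auto() plus a small is_*() predicate. A condensed sketch of that pattern, showing only a subset of the real members:

# Condensed sketch of the ForwardMode pattern: IntEnum members created with
# auto() and small predicate helpers. Only a subset of the real members is shown.
from enum import IntEnum, auto


class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()
    # A dummy first batch to start the pipeline for the overlap scheduler.
    DUMMY_FIRST = auto()

    def is_decode(self):
        return self == ForwardMode.DECODE

    def is_idle(self):
        return self == ForwardMode.IDLE

    def is_dummy_first(self):
        return self == ForwardMode.DUMMY_FIRST


mode = ForwardMode.DUMMY_FIRST
print(mode.is_dummy_first(), mode.is_decode())   # True False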
python/sglang/srt/model_executor/model_runner.py  (view @ ffd20fcd; +11, -4)

@@ -142,7 +142,6 @@ class ModelRunner:
                  "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                  "disable_mla": server_args.disable_mla,
                  "torchao_config": server_args.torchao_config,
-                 "disable_penalizer": server_args.disable_penalizer,
                  "enable_nan_detection": server_args.enable_nan_detection,
                  "enable_dp_attention": server_args.enable_dp_attention,
              }

@@ -636,8 +635,16 @@ class ModelRunner:
      def sample(
          self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
      ) -> torch.Tensor:
-         # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
          sampling_info = forward_batch.sampling_info
-         sampling_info.update_regex_vocab_mask()
-         sampling_info.update_penalties()
+         if sampling_info.sampling_info_done:
+             # Overlap mode: the function update_regex_vocab_mask was executed
+             # in process_batch_result of the last batch.
+             if sampling_info.grammars:
+                 sampling_info.sampling_info_done.wait()
+             sampling_info.update_penalties()
+         else:
+             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+             sampling_info.update_regex_vocab_mask()
+             sampling_info.update_penalties()
          logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
  ...
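The new branch in sample() distinguishes the two modes: in overlap mode the mask was already prepared while processing the previous batch's result, so the sampler only waits on the event when a grammar is actually attached; in normal mode the mask is computed inline. A simplified sketch of the same control flow against a stub sampling-info object; the names mirror the diff, but the SamplingInfo class below is a stand-in, not sglang's SamplingBatchInfo:

# Simplified sketch of the control flow added to ModelRunner.sample().
import threading
from typing import List, Optional

import torch


class SamplingInfo:
    def __init__(self, grammars: Optional[List] = None, overlap: bool = False):
        self.grammars = grammars
        self.vocab_mask: Optional[torch.Tensor] = None
        self.sampling_info_done = threading.Event() if overlap else None

    def update_regex_vocab_mask(self):
        if not self.grammars:
            self.vocab_mask = None
            return
        self.vocab_mask = torch.zeros(8, dtype=torch.bool)
        self.vocab_mask[:3] = True          # pretend the grammar allows tokens 0-2


def sample(logits: torch.Tensor, info: SamplingInfo) -> int:
    if info.sampling_info_done:
        # Overlap mode: the mask was built while processing the previous
        # batch's result; wait only if a grammar is in play.
        if info.grammars:
            info.sampling_info_done.wait()
    else:
        # Normal mode: do the CPU-heavy mask work right here.
        info.update_regex_vocab_mask()
    if info.vocab_mask is not None:
        logits = logits.masked_fill(~info.vocab_mask, float("-inf"))
    return int(torch.argmax(logits))


# Normal (non-overlap) path:
print(sample(torch.randn(8), SamplingInfo(grammars=["g"], overlap=False)))

# Overlap path: someone else builds the mask and sets the event first.
info = SamplingInfo(grammars=["g"], overlap=True)
info.update_regex_vocab_mask()
info.sampling_info_done.set()
print(sample(torch.randn(8), info))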
python/sglang/srt/sampling/sampling_batch_info.py  (view @ ffd20fcd; +42, -20)

  from __future__ import annotations

  import dataclasses
+ import logging
+ import threading
  from typing import TYPE_CHECKING, Callable, List, Optional

  import torch

  import sglang.srt.sampling.penaltylib as penaltylib

+ logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import ScheduleBatch
  ...

@@ -28,6 +33,7 @@ class SamplingBatchInfo:
      # Bias Tensors
      vocab_size: int
      grammars: Optional[List] = None
+     sampling_info_done: Optional[threading.Event] = None
      logit_bias: torch.Tensor = None
      vocab_mask: Optional[torch.Tensor] = None
      apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None

@@ -42,10 +48,7 @@ class SamplingBatchInfo:
      @classmethod
      def from_schedule_batch(
-         cls,
-         batch: ScheduleBatch,
-         vocab_size: int,
-         disable_penalizer: bool,
+         cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
      ):
          reqs = batch.reqs
          device = batch.device

@@ -79,6 +82,33 @@ class SamplingBatchInfo:
          )

          # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
+         if enable_overlap_schedule:
+             # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
+             # so it is kind of tricky to make it work with overlap scheduler.
+             # It requires correcly updating the penalty logits before the sampling and syncing the events.
+             # We will support them later.
+             penalizers = {
+                 penaltylib.BatchedMinNewTokensPenalizer,
+             }
+             if (
+                 any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
+                 or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
+                 or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
+             ):
+                 logger.warning(
+                     "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
+                     "when using the default overlap scheduler. They will be ignored. "
+                     "Please add `--disable-overlap` when launching the server if you need these features. "
+                     "The speed will be slower in that case."
+                 )
+         else:
+             penalizers = {
+                 penaltylib.BatchedFrequencyPenalizer,
+                 penaltylib.BatchedMinNewTokensPenalizer,
+                 penaltylib.BatchedPresencePenalizer,
+                 penaltylib.BatchedRepetitionPenalizer,
+             }

          # Each penalizers will do nothing if they evaluate themselves as not required by looking at
          # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
          # should not add hefty computation overhead other than simple checks.

@@ -86,19 +116,11 @@ class SamplingBatchInfo:
          # While we choose not to even create the class instances if they are not required, this
          # could add additional complexity to the {ScheduleBatch} class, especially we need to
          # handle {filter_batch()} and {merge_batch()} cases as well.
-         if disable_penalizer:
-             ret.penalizer_orchestrator = None
-         else:
-             ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-                 vocab_size=vocab_size,
-                 batch=batch,
-                 device=batch.device,
-                 Penalizers={
-                     penaltylib.BatchedFrequencyPenalizer,
-                     penaltylib.BatchedMinNewTokensPenalizer,
-                     penaltylib.BatchedPresencePenalizer,
-                     penaltylib.BatchedRepetitionPenalizer,
-                 },
-             )
+         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+             vocab_size=vocab_size,
+             batch=batch,
+             device=batch.device,
+             Penalizers=penalizers,
+         )

          # Handle logit bias but only allocate when needed
  ...

@@ -133,13 +155,13 @@ class SamplingBatchInfo:
          self.linear_penalties = penalizer.apply(self.linear_penalties)

      def update_regex_vocab_mask(self):
-         if not self.grammars or not any(grammar for grammar in self.grammars):
+         if not self.grammars:
              self.vocab_mask = None
              self.apply_mask = None
              return

          # find a grammar from the list
-         grammar = next(grammar for grammar in self.grammars if grammar is not None)
+         grammar = next(grammar for grammar in self.grammars if grammar)

          # maybe we can reuse the existing mask?
          self.vocab_mask = grammar.allocate_vocab_mask(
  ...
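from_schedule_batch now chooses the penalizer set from enable_overlap_schedule and warns when requests ask for penalties the overlap scheduler cannot honor yet. A reduced sketch of that selection logic; the penalizer classes below are placeholders for the ones in sglang.srt.sampling.penaltylib, and the warning text is paraphrased:

# Reduced sketch of the penalizer selection added to from_schedule_batch().
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


class BatchedMinNewTokensPenalizer: ...
class BatchedFrequencyPenalizer: ...
class BatchedPresencePenalizer: ...
class BatchedRepetitionPenalizer: ...


@dataclass
class SamplingParams:
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    repetition_penalty: float = 1.0


def select_penalizers(reqs, enable_overlap_schedule: bool):
    if enable_overlap_schedule:
        # Output-dependent penalizers are skipped under the overlap scheduler.
        penalizers = {BatchedMinNewTokensPenalizer}
        if any(
            r.frequency_penalty != 0.0
            or r.presence_penalty != 0.0
            or r.repetition_penalty != 1.0
            for r in reqs
        ):
            logger.warning(
                "frequency/presence/repetition penalties are ignored with the "
                "overlap scheduler; launch with --disable-overlap to use them."
            )
    else:
        penalizers = {
            BatchedFrequencyPenalizer,
            BatchedMinNewTokensPenalizer,
            BatchedPresencePenalizer,
            BatchedRepetitionPenalizer,
        }
    return penalizers


reqs = [SamplingParams(frequency_penalty=0.5), SamplingParams()]
print(len(select_penalizers(reqs, enable_overlap_schedule=True)))    # 1
print(len(select_penalizers(reqs, enable_overlap_schedule=False)))   # 4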
python/sglang/srt/server_args.py  (view @ ffd20fcd; +1, -12)

@@ -123,7 +123,6 @@ class ServerArgs:
      disable_disk_cache: bool = False
      disable_custom_all_reduce: bool = False
      disable_mla: bool = False
-     disable_penalizer: bool = False
      enable_overlap_schedule: bool = False
      enable_mixed_chunk: bool = False
      enable_dp_attention: bool = False

@@ -200,12 +199,7 @@ class ServerArgs:
          )

          if self.enable_overlap_schedule:
-             logger.warning(
-                 "Overlap scheduler mode is enabled. This is an experimental feature. "
-                 "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                 "and embedding APIs are not supported and will lead to wrong results. "
-             )
-             self.disable_penalizer = True
+             self.disable_jump_forward = True

      @staticmethod
      def add_cli_args(parser: argparse.ArgumentParser):

@@ -622,11 +616,6 @@ class ServerArgs:
              action="store_true",
              help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
          )
-         parser.add_argument(
-             "--disable-penalizer",
-             action="store_true",
-             help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
-         )
          parser.add_argument(
              "--disable-nan-detection",
              action="store_true",
  ...
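With constrained decoding now supported under the overlap scheduler, the --disable-penalizer flag and the blanket warning go away; enabling the overlap schedule only forces jump-forward decoding off. A tiny sketch of that flag implication in a reduced, hypothetical ServerArgs (the real class sets this in its argument post-processing, not necessarily in __post_init__):

# Tiny sketch of the remaining flag implication: enabling the overlap schedule
# forces jump-forward decoding off. Reduced/hypothetical ServerArgs.
import argparse
from dataclasses import dataclass


@dataclass
class ServerArgs:
    enable_overlap_schedule: bool = False
    disable_jump_forward: bool = False

    def __post_init__(self):
        if self.enable_overlap_schedule:
            self.disable_jump_forward = True


parser = argparse.ArgumentParser()
parser.add_argument("--enable-overlap-schedule", action="store_true")
parser.add_argument("--disable-jump-forward", action="store_true")
ns = parser.parse_args(["--enable-overlap-schedule"])

args = ServerArgs(**vars(ns))
print(args.disable_jump_forward)   # True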