change / sglang / Commits / b121bc03

Commit b121bc03 (Unverified)
Authored Oct 20, 2024 by Lianmin Zheng; committed by GitHub on Oct 20, 2024
Parent: e12358dc

Simplify batch result resolution (#1735)
Showing 5 changed files with 64 additions and 90 deletions (+64, -90):
- python/sglang/srt/managers/schedule_batch.py (+5, -22)
- python/sglang/srt/managers/scheduler.py (+17, -17)
- python/sglang/srt/managers/tp_worker_overlap_thread.py (+31, -43)
- python/sglang/srt/model_executor/model_runner.py (+1, -8)
- python/sglang/srt/server_args.py (+10, -0)
python/sglang/srt/managers/schedule_batch.py (+5, -22)

```diff
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
 }


-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""
@@ -407,7 +407,7 @@ class Req:
         bid = 0


-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
@@ -902,7 +902,7 @@ class ScheduleBatch:
     )


-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
     mrope_positions_delta: List[List[int]]

     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
-            out_cache_loc=self.out_cache_loc,
-            req_to_token_pool_records=self.req_to_token_pool_records,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-            mrope_positions_delta=self.mrope_positions_delta,
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

     def to(self, device: str):
         self.input_ids = self.input_ids.to(device, non_blocking=True)
```
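The `copy` rewrite above swaps sixteen hand-maintained keyword arguments for `dataclasses.replace`, which shallow-copies every field of an existing instance and overrides only the ones named. A minimal sketch of the idiom with a hypothetical two-field dataclass (not the real `ModelWorkerBatch`):

```python
import dataclasses


@dataclasses.dataclass
class Batch:
    # Hypothetical stand-in for ModelWorkerBatch: one shared field,
    # one mutable field that must be copied per worker.
    bid: int
    sampling_info: list

    def copy(self):
        # Shallow-copies every field except the ones overridden here,
        # exactly like the one-liner introduced in this commit.
        return dataclasses.replace(self, sampling_info=list(self.sampling_info))


b1 = Batch(bid=7, sampling_info=[0.9, 40])
b2 = b1.copy()
b2.sampling_info.append(1.0)
assert b1.sampling_info == [0.9, 40]  # the override isolated the mutable field
```

The practical win is that new fields added to the dataclass are picked up automatically instead of silently falling out of sync with a hand-written constructor call.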
python/sglang/srt/managers/scheduler.py (+17, -17)

```diff
@@ -149,12 +149,8 @@ class Scheduler:
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
         else:
             TpWorkerClass = TpModelWorker
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()

         self.tp_worker = TpWorkerClass(
             server_args=server_args,
@@ -756,9 +752,12 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
-            if batch.return_logprob:
-                # Move logprobs to cpu
-                if logits_output.next_token_logprobs is not None:
+
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
                             torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@ class Scheduler:
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()

             # Check finish conditions
             logprob_pt = 0
@@ -825,14 +823,16 @@ class Scheduler:
             logits_output, next_token_ids, bid = result
             self.num_generated_tokens += len(batch.reqs)

-            # Move logprobs to cpu
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=self.device),
-                    next_token_ids,
-                ].tolist()
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
+                    next_token_logprobs = logits_output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=self.device),
+                        next_token_ids,
+                    ].tolist()
+                next_token_ids = next_token_ids.tolist()

             self.token_to_kv_pool.free_group_begin()
```
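Two things change in these hunks: the per-mode `resolve_next_token_ids` lambda installed in `__init__` is gone, replaced by an explicit `if self.enable_overlap:` branch at each result-processing site, and the overlap path now fetches the finished batch from the worker via `resulve_batch_result` (the misspelling is in the committed code). That call is just a blocking `Queue.get()`, which is correct because batches are resolved in the same FIFO order they were submitted. A minimal sketch of that handoff, with illustrative names rather than sglang's real classes:

```python
import threading
from queue import Queue

# Minimal sketch of the FIFO handoff the overlap scheduler now relies on:
# batches come back in the same order they were submitted, so a plain
# Queue.get() replaces the per-batch-id event map. Names are illustrative,
# not sglang's real API.
input_queue: Queue = Queue()
output_queue: Queue = Queue()


def worker_loop():
    while True:
        bid, token_ids = input_queue.get()
        # Stand-in for the real forward pass on the background thread.
        output_queue.put((bid, [t + 1 for t in token_ids]))


threading.Thread(target=worker_loop, daemon=True).start()

input_queue.put((0, [1, 2, 3]))
input_queue.put((1, [4, 5]))

# resulve_batch_result(bid) boils down to a blocking get(); FIFO ordering
# does the bookkeeping that events and dicts used to do.
print(output_queue.get())  # (0, [2, 3, 4])
print(output_queue.get())  # (1, [5, 6])
```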
python/sglang/srt/managers/tp_worker_overlap_thread.py (+31, -43)

```diff
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device

-        # Create future mappings
-        self.future_logits_output_dict = dict()
-        self.future_logits_output_ct = 0
+        # Init future mappings
         self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_output = dict()

         # Launch a thread
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
+        self.input_queue = Queue()
+        self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
     def forward_thread_func_(self):
         while True:
             tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()

             # Resolve future tokens in the input
             tic2 = time.time()
@@ -107,17 +102,22 @@ class TpModelWorkerClient:
                 model_worker_batch
             )

-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
-
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            future_next_token_ids = torch.arange(
+                -(future_token_ids_ct + bs),
+                -(future_token_ids_ct),
+                dtype=torch.int32,
+                device=self.device,
+            )
             self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
                 torch.int32
             )
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
+
+            # Set the result
+            next_token_ids = next_token_ids.tolist()
+            assert logits_output.next_token_logprobs is None, "Not supported"
+            self.output_queue.put((None, next_token_ids))

             if False:
                 tic3 = time.time()
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
                     f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
                 )

-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))

+        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + bs),
+            -(self.future_token_ids_ct),
+            dtype=torch.int32,
+            device=self.device,
+        )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit

-        ret = future_logits_output, future_next_token_ids
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
+        return None, future_next_token_ids

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
```
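The worker thread above keeps sglang's future-token trick but drops the event/dict bookkeeping: `forward_batch_generation` returns negative placeholder ids immediately, the background thread later writes the real tokens into `future_token_ids_map` at the negated slots, and a modulo counter recycles the id space. A self-contained CPU sketch of the scheme with toy sizes (the real map lives on the GPU and is sized from `max_running_requests`):

```python
import torch

# Toy-sized stand-ins for the real class attributes.
future_token_ids_limit = 6
future_token_ids_map = torch.empty(future_token_ids_limit + 1, dtype=torch.int32)
future_token_ids_ct = 0


def allocate_placeholders(bs: int) -> torch.Tensor:
    """Scheduler side: hand out negative ids as IOUs for a not-yet-run batch."""
    global future_token_ids_ct
    ids = torch.arange(-(future_token_ids_ct + bs), -future_token_ids_ct)
    # Recycle the id space with a modulo counter, as forward_batch_generation does.
    future_token_ids_ct = (future_token_ids_ct + bs) % future_token_ids_limit
    return ids


def fulfill(placeholders: torch.Tensor, real_token_ids: torch.Tensor) -> None:
    """Worker-thread side: write real tokens at the negated placeholder slots."""
    future_token_ids_map[-placeholders] = real_token_ids.to(torch.int32)


def resolve(placeholders: torch.Tensor) -> torch.Tensor:
    """Turn placeholders back into real token ids once they are fulfilled."""
    return future_token_ids_map[-placeholders]


ious = allocate_placeholders(bs=3)         # tensor([-3, -2, -1])
fulfill(ious, torch.tensor([11, 22, 33]))  # the forward finishes later
print(resolve(ious).tolist())              # [11, 22, 33]
```

Negative ids can never collide with real vocabulary ids (which are non-negative), which is what lets the scheduler feed placeholders straight into the next batch's input ids before the previous forward has finished.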
python/sglang/srt/model_executor/model_runner.py (+1, -8)

```diff
@@ -120,7 +120,7 @@ class ModelRunner:
             )
         if self.is_multimodal_model:
-            logger.info(
+            logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -131,13 +131,6 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True

-        if self.server_args.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results."
-            )
-
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
```
python/sglang/srt/server_args.py (+10, -0)

```diff
@@ -177,6 +177,16 @@ class ServerArgs:
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"

+        if self.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler mode is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results. "
+                "The NaN detection is also disabled."
+            )
+            self.disable_penalizer = True
+            self.disable_nan_detection = True
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(
```
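With this hunk the experimental-feature warning moves from `ModelRunner` (deleted above) into `ServerArgs` post-initialization, which also force-disables the incompatible features in one place rather than at each consumer. A generic sketch of that pattern, with illustrative field names rather than the full `ServerArgs`:

```python
from dataclasses import dataclass

# Generic sketch of the pattern used here: derive dependent flags once, in
# __post_init__, instead of re-checking at every consumer. Field names are
# taken from this diff; the class itself is illustrative.
@dataclass
class Args:
    enable_overlap_schedule: bool = False
    disable_penalizer: bool = False
    disable_nan_detection: bool = False

    def __post_init__(self):
        if self.enable_overlap_schedule:
            # Overlap mode cannot support these features yet, so force them off.
            self.disable_penalizer = True
            self.disable_nan_detection = True


args = Args(enable_overlap_schedule=True)
assert args.disable_penalizer and args.disable_nan_detection
```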