change / sglang · Commits · b121bc03

Commit b121bc03 (unverified), authored Oct 20, 2024 by Lianmin Zheng, committed via GitHub on Oct 20, 2024.
Parent: e12358dc

Simplify batch result resolution (#1735)

Showing 5 changed files with 64 additions and 90 deletions (+64 / -90).
Files changed:

- python/sglang/srt/managers/schedule_batch.py (+5, -22)
- python/sglang/srt/managers/scheduler.py (+17, -17)
- python/sglang/srt/managers/tp_worker_overlap_thread.py (+31, -43)
- python/sglang/srt/model_executor/model_runner.py (+1, -8)
- python/sglang/srt/server_args.py (+10, -0)
python/sglang/srt/managers/schedule_batch.py (view file @ b121bc03)

```diff
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
 
+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
 }
 
 
-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""
@@ -407,7 +407,7 @@ class Req:
         bid = 0
 
 
-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
@@ -902,7 +902,7 @@ class ScheduleBatch:
         )
 
 
-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
     mrope_positions_delta: List[List[int]]
 
     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
-            out_cache_loc=self.out_cache_loc,
-            req_to_token_pool_records=self.req_to_token_pool_records,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-            mrope_positions_delta=self.mrope_positions_delta,
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
     def to(self, device: str):
         self.input_ids = self.input_ids.to(device, non_blocking=True)
```
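The rewritten `copy()` is the heart of this file's change: instead of enumerating all sixteen constructor arguments, `dataclasses.replace` builds the new instance from the old one, and only the fields that must not be shared get an explicit override, so fields added to `ModelWorkerBatch` later are carried over automatically. A minimal sketch of the same pattern (the `Batch`/`tags` names are illustrative, not sglang's):

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class Batch:
    bid: int
    tags: List[str]  # a mutable field that must not be shared between copies

    def copy(self) -> "Batch":
        # dataclasses.replace() re-uses every existing field value and only
        # applies the overrides given, so `tags` gets a fresh list while all
        # other fields (including ones added later) are carried over as-is.
        return dataclasses.replace(self, tags=list(self.tags))


b1 = Batch(bid=0, tags=["prefill"])
b2 = b1.copy()
b2.tags.append("decode")
assert b1.tags == ["prefill"]  # the two batches no longer share the list
```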
python/sglang/srt/managers/scheduler.py (view file @ b121bc03)

```diff
@@ -149,12 +149,8 @@ class Scheduler:
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
         else:
             TpWorkerClass = TpModelWorker
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
 
         self.tp_worker = TpWorkerClass(
             server_args=server_args,
```
```diff
@@ -756,9 +752,12 @@
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
 
-            if batch.return_logprob:
-                # Move logprobs to cpu
-                if logits_output.next_token_logprobs is not None:
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
                             torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()
 
             # Check finish conditions
             logprob_pt = 0
```
```diff
@@ -825,14 +823,16 @@
         logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
-        # Move logprobs to cpu
-        if batch.return_logprob:
-            next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=self.device),
-                next_token_ids,
-            ].tolist()
-
-        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+        if self.enable_overlap:
+            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+        else:
+            # Move next_token_ids and logprobs to cpu
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].tolist()
+            next_token_ids = next_token_ids.tolist()
 
         self.token_to_kv_pool.free_group_begin()
```
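Both result handlers now share one shape: branch once on `self.enable_overlap`, then either block on the overlap worker for the real result of batch `bid` or move the live tensors to CPU inline. This is what makes the `resolve_next_token_ids` lambdas from the first hunk unnecessary. A toy version of the control flow, assuming an illustrative `ToyOverlapWorker` stand-in (not sglang's API):

```python
import torch


class ToyOverlapWorker:
    """Stand-in for TpModelWorkerClient: results are published per batch id."""

    def __init__(self):
        self.results = {}  # bid -> (logits_output, next_token_ids)

    def resolve(self, bid: int):
        # The real worker blocks on an output queue; a dict keeps this toy simple.
        return self.results.pop(bid)


def process_batch_result(enable_overlap, worker, result):
    logits_output, next_token_ids, bid = result
    if enable_overlap:
        # Overlap mode: the ids in `result` are placeholders; fetch the
        # real ones produced by the background forward thread.
        logits_output, next_token_ids = worker.resolve(bid)
    else:
        # Synchronous mode: the ids are a live tensor; move them to CPU here.
        next_token_ids = next_token_ids.tolist()
    return logits_output, next_token_ids


worker = ToyOverlapWorker()
worker.results[7] = (None, [11, 12, 13])
print(process_batch_result(True, worker, (None, torch.tensor([-1, -2, -3]), 7)))
```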
python/sglang/srt/managers/tp_worker_overlap_thread.py (view file @ b121bc03)

```diff
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
 
-        # Create future mappings
-        self.future_logits_output_dict = dict()
-        self.future_logits_output_ct = 0
+        # Init future mappings
         self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_output = dict()
 
         # Launch a thread
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
+        self.input_queue = Queue()
+        self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
```
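The constructor change swaps per-batch `threading.Event`s and result dictionaries for a plain `input_queue`/`output_queue` pair: the background thread consumes batches and publishes results strictly in FIFO order, so no per-batch bookkeeping is needed to match them up. A self-contained sketch of that producer/consumer shape (the doubling "forward pass" is a stand-in, not sglang code):

```python
import threading
from queue import Queue

input_queue: Queue = Queue()
output_queue: Queue = Queue()


def forward_thread_func():
    # Consume batches and publish results in the same (FIFO) order, so the
    # consumer can simply pair the n-th get() with the n-th put().
    while True:
        batch = input_queue.get()
        if batch is None:  # shutdown sentinel
            return
        output_queue.put([x * 2 for x in batch])  # stand-in for a forward pass


threading.Thread(target=forward_thread_func, daemon=True).start()

input_queue.put([1, 2, 3])
print(output_queue.get())  # [2, 4, 6] -- blocks until the worker has finished
input_queue.put(None)
```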
```diff
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
     def forward_thread_func_(self):
         while True:
             tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
 
             # Resolve future tokens in the input
             tic2 = time.time()
```
```diff
@@ -107,17 +102,22 @@ class TpModelWorkerClient:
                 model_worker_batch
             )
 
-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            future_next_token_ids = torch.arange(
+                -(future_token_ids_ct + bs),
+                -(future_token_ids_ct),
+                dtype=torch.int32,
+                device=self.device,
+            )
             self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
                 torch.int32
             )
 
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
+            # Set the result
+            next_token_ids = next_token_ids.tolist()
+            assert logits_output.next_token_logprobs is None, "Not supported"
+            self.output_queue.put((None, next_token_ids))
 
             if False:
                 tic3 = time.time()
```
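This hunk is where the future-token trick lives: token ids that have not been sampled yet circulate through the system as negative indices into `future_token_ids_map`, and the worker thread writes the real ids into the map as soon as sampling finishes, so the next batch can dereference its inputs with ordinary tensor indexing. A self-contained sketch of the encoding (CPU tensors, toy sizes, and simplified slot arithmetic so allocation and resolution are visibly consistent; not sglang's exact index math):

```python
import torch

# Toy stand-ins for future_token_ids_map and future_token_ids_ct.
future_token_ids_map = torch.zeros(16, dtype=torch.long)
ct, bs = 0, 3

# Allocation side: hand out bs placeholder ids, -(ct+1) .. -(ct+bs).
placeholders = -torch.arange(ct + 1, ct + 1 + bs)

# Worker side: once real ids are sampled, write them into the positive
# slots the placeholders point at (negate to recover the slot index).
sampled = torch.tensor([101, 102, 103])
future_token_ids_map[-placeholders] = sampled

# Resolution side: any later input that still holds placeholders
# (negative values) is patched with one indexed read from the map.
input_ids = torch.tensor([5, -1, -3])
mask = input_ids < 0
input_ids[mask] = future_token_ids_map[-input_ids[mask]]
print(input_ids)  # tensor([  5, 101, 103])
```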
```diff
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
                     f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
                 )
 
-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
 
+        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = torch.arange(
-                -(self.future_token_ids_ct + bs),
-                -(self.future_token_ids_ct),
-                dtype=torch.int32,
-                device=self.device,
-            )
+        future_next_token_ids = -torch.arange(
+            self.future_token_ids_ct + 1,
+            self.future_token_ids_ct + 1 + bs,
+            dtype=torch.int32,
+            device=self.device,
+        )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
+        return None, future_next_token_ids
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
```
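`future_token_ids_ct` is a ring counter: it advances by the batch size modulo `future_token_ids_limit` (3x `max_running_requests`), while the map itself is sized at 5x `max_running_requests`, so a batch allocated just before the wrap still fits without colliding with in-flight placeholders. A compact sketch of that allocator (pure Python, sizes illustrative):

```python
def make_allocator(limit: int):
    """Hand out `bs` consecutive negative placeholder ids, wrapping mod `limit`."""
    state = {"ct": 0}

    def allocate(bs: int):
        start = state["ct"]
        slots = [-(start + 1 + i) for i in range(bs)]  # placeholder ids
        state["ct"] = (start + bs) % limit             # ring-buffer advance
        return slots

    return allocate


allocate = make_allocator(limit=8)
print(allocate(3))  # [-1, -2, -3]
print(allocate(3))  # [-4, -5, -6]
print(allocate(3))  # [-7, -8, -9]; the counter itself has wrapped to 1
```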
python/sglang/srt/model_executor/model_runner.py (view file @ b121bc03)

```diff
@@ -120,7 +120,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal_model:
-            logger.info(
+            logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -131,13 +131,6 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True
 
-        if self.server_args.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results."
-            )
-
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
```
python/sglang/srt/server_args.py (view file @ b121bc03)

```diff
@@ -177,6 +177,16 @@ class ServerArgs:
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"
 
+        if self.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler mode is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results. "
+                "The NaN detection is also disabled."
+            )
+            self.disable_penalizer = True
+            self.disable_nan_detection = True
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(
```
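Moving the warning out of `ModelRunner` and into `ServerArgs` also lets the dependent flags be forced in one place: enabling the experimental overlap schedule now implies `disable_penalizer` and `disable_nan_detection`, so the settings cannot drift apart. A minimal sketch of that kind of flag coupling (the dataclass and `__post_init__` hook are illustrative, not sglang's `ServerArgs`):

```python
import dataclasses


@dataclasses.dataclass
class ToyServerArgs:
    enable_overlap_schedule: bool = False
    disable_penalizer: bool = False
    disable_nan_detection: bool = False

    def __post_init__(self):
        if self.enable_overlap_schedule:
            # Features the overlap schedule cannot support are switched off
            # together with the flag that makes them unsupported.
            self.disable_penalizer = True
            self.disable_nan_detection = True


args = ToyServerArgs(enable_overlap_schedule=True)
assert args.disable_penalizer and args.disable_nan_detection
```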