zhaoyu6 / sglang / Commits / b121bc03
Unverified commit b121bc03, authored Oct 20, 2024 by Lianmin Zheng, committed by GitHub on Oct 20, 2024

Simplify batch result resolution (#1735)

Parent: e12358dc
Showing 5 changed files with 64 additions and 90 deletions (+64, -90)
python/sglang/srt/managers/schedule_batch.py (+5, -22)
python/sglang/srt/managers/scheduler.py (+17, -17)
python/sglang/srt/managers/tp_worker_overlap_thread.py (+31, -43)
python/sglang/srt/model_executor/model_runner.py (+1, -8)
python/sglang/srt/server_args.py (+10, -0)
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
 }


-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""
@@ -407,7 +407,7 @@ class Req:
         bid = 0


-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
@@ -902,7 +902,7 @@ class ScheduleBatch:
         )


-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
     mrope_positions_delta: List[List[int]]

     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
-            out_cache_loc=self.out_cache_loc,
-            req_to_token_pool_records=self.req_to_token_pool_records,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-            mrope_positions_delta=self.mrope_positions_delta,
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

     def to(self, device: str):
         self.input_ids = self.input_ids.to(device, non_blocking=True)
```
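The rewritten copy() above leans on dataclasses.replace() from the standard library, which rebuilds a dataclass instance from its existing fields and overrides only the keyword arguments you pass. A minimal, self-contained sketch of why the one-liner is equivalent to the long field-by-field constructor call (a toy Batch class, not the real ModelWorkerBatch):

```python
# Toy illustration, not sglang code: dataclasses.replace() shallow-copies all
# fields and lets you override just the ones that must not be shared.
import dataclasses

@dataclasses.dataclass
class Batch:
    bid: int
    input_ids: list
    sampling_info: dict

    def copy(self):
        # Everything is carried over as-is, except the one mutable field we
        # deep-copy (a plain dict stands in for the real sampling_info object).
        return dataclasses.replace(self, sampling_info=dict(self.sampling_info))

b1 = Batch(bid=1, input_ids=[1, 2, 3], sampling_info={"temperature": 0.7})
b2 = b1.copy()
b2.sampling_info["temperature"] = 1.0
print(b1.sampling_info["temperature"])  # 0.7 -- the copy did not alias
print(b1.input_ids is b2.input_ids)     # True -- other fields are shared refs
```

The upside of this pattern, as the diff shows, is that newly added dataclass fields are picked up by copy() automatically instead of silently being dropped from a hand-written constructor call.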
python/sglang/srt/managers/scheduler.py

```diff
@@ -149,12 +149,8 @@ class Scheduler:
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
         else:
             TpWorkerClass = TpModelWorker
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()

         self.tp_worker = TpWorkerClass(
             server_args=server_args,
@@ -756,9 +752,12 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
-            if batch.return_logprob:
-                # Move logprobs to cpu
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     if logits_output.next_token_logprobs is not None:
                         logits_output.next_token_logprobs = (
                             logits_output.next_token_logprobs[
                                 torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@ class Scheduler:
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()

             # Check finish conditions
             logprob_pt = 0
@@ -825,14 +823,16 @@ class Scheduler:
             logits_output, next_token_ids, bid = result
             self.num_generated_tokens += len(batch.reqs)

-            # Move logprobs to cpu
-            if batch.return_logprob:
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     next_token_logprobs = logits_output.next_token_logprobs[
                         torch.arange(len(next_token_ids), device=self.device),
                         next_token_ids,
                     ].tolist()
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()

             self.token_to_kv_pool.free_group_begin()
```
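The torch.arange indexing kept as context in both hunks pairs row i of a [batch, vocab] logprob tensor with column next_token_ids[i], extracting one logprob per sampled token before moving the result to CPU with .tolist(). A toy illustration of that gather-then-tolist pattern (made-up sizes and values, not sglang code):

```python
# Toy version of the indexing pattern above: pick each sampled token's logprob
# out of a [batch, vocab] tensor with torch.arange row indices.
import torch

vocab_size = 5
next_token_logprobs_full = torch.log_softmax(torch.randn(3, vocab_size), dim=-1)
next_token_ids = torch.tensor([4, 0, 2])  # one sampled token id per row

# Advanced indexing: (row i, column next_token_ids[i]) -> shape (3,)
picked = next_token_logprobs_full[
    torch.arange(len(next_token_ids)), next_token_ids
].tolist()
print(picked)  # three floats, e.g. [-1.93, -0.87, -2.41]
```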
python/sglang/srt/managers/tp_worker_overlap_thread.py

```diff
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device

-        # Create future mappings
-        self.future_logits_output_dict = dict()
-        self.future_logits_output_ct = 0
+        # Init future mappings
         self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_output = dict()

         # Launch a thread
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
+        self.input_queue = Queue()
+        self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
     def forward_thread_func_(self):
         while True:
             tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()

             # Resolve future tokens in the input
             tic2 = time.time()
@@ -107,17 +102,22 @@ class TpModelWorkerClient:
                 model_worker_batch
             )

-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            future_next_token_ids = torch.arange(
+                -(future_token_ids_ct + bs),
+                -(future_token_ids_ct),
+                dtype=torch.int32,
+                device=self.device,
+            )
             self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
                 torch.int32
             )
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
+
+            # Set the result
+            next_token_ids = next_token_ids.tolist()
+            assert logits_output.next_token_logprobs is None, "Not supported"
+            self.output_queue.put((None, next_token_ids))

             if False:
                 tic3 = time.time()
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
                     f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
                 )

-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))

+        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + bs),
+            -(self.future_token_ids_ct),
+            dtype=torch.int32,
+            device=self.device,
+        )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
+        return None, future_next_token_ids

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
```
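The simplification here replaces per-batch events and dicts with two queues plus the preallocated future_token_ids_map: forward_batch_generation hands out negative placeholder ids immediately, the worker thread later recomputes the same range from the counter snapshot it received through the queue and writes the real token ids at the negated indices, and any negative ids appearing in a later batch's input are looked up in the map. A minimal CPU-only sketch of that scheme under our own simplified names (allocate/publish/resolve are illustrative helpers, not sglang APIs; the real code uses int32 tensors on the GPU and real threads):

```python
# Toy model of the future-token-id map, single-threaded for clarity.
import torch

FUTURE_LIMIT = 8                                     # wrap point for the counter
future_map = torch.empty((16,), dtype=torch.int64)   # real code: int32 on GPU
future_ct = 0                                        # running counter

def allocate(bs: int) -> torch.Tensor:
    """Scheduler side: hand out bs negative placeholder ids right away."""
    global future_ct
    ids = torch.arange(-(future_ct + bs), -future_ct)  # e.g. [-3, -2, -1]
    future_ct = (future_ct + bs) % FUTURE_LIMIT
    return ids

def publish(ct_snapshot: int, next_token_ids: torch.Tensor) -> None:
    """Worker side: recompute the same range and fill in the real token ids."""
    bs = len(next_token_ids)
    placeholders = torch.arange(-(ct_snapshot + bs), -ct_snapshot)
    future_map[-placeholders] = next_token_ids

def resolve(input_ids: torch.Tensor) -> torch.Tensor:
    """Replace negative placeholders with the published token ids."""
    neg = input_ids < 0
    input_ids[neg] = future_map[-input_ids[neg]]
    return input_ids

futures = allocate(3)                      # tensor([-3, -2, -1])
publish(0, torch.tensor([101, 102, 103]))  # worker finishes the batch later
print(resolve(futures))                    # tensor([101, 102, 103])
```

Because the counter snapshot travels through the queue with the batch, the scheduler can enqueue the next batch (whose inputs reference the placeholders) before the previous one has finished, which is the overlap this worker exists to provide.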
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -120,7 +120,7 @@ class ModelRunner:
         )
         if self.is_multimodal_model:
-            logger.info(
+            logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -131,13 +131,6 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True

-        if self.server_args.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results."
-            )
-
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
```
python/sglang/srt/server_args.py

```diff
@@ -177,6 +177,16 @@ class ServerArgs:
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"

+        if self.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler mode is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results. "
+                "The NaN detection is also disabled."
+            )
+            self.disable_penalizer = True
+            self.disable_nan_detection = True
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(
```
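The warning and flag flips added here move from ModelRunner into ServerArgs' initialization hook, so the dependent switches are reconciled once, at argument-parsing time. A hedged toy of that pattern (a made-up dataclass, not the real ServerArgs):

```python
# Toy illustration: a dataclass __post_init__ that keeps dependent flags
# consistent, as the ServerArgs change above does for the overlap schedule.
from dataclasses import dataclass

@dataclass
class ToyServerArgs:
    enable_overlap_schedule: bool = False
    disable_penalizer: bool = False
    disable_nan_detection: bool = False

    def __post_init__(self):
        if self.enable_overlap_schedule:
            # The experimental mode does not support these features,
            # so force them off rather than trusting the caller.
            self.disable_penalizer = True
            self.disable_nan_detection = True

args = ToyServerArgs(enable_overlap_schedule=True)
print(args.disable_penalizer, args.disable_nan_detection)  # True True
```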