Commit f724f1f1 (unverified), authored Aug 07, 2024 by Liangsheng Yin, committed by GitHub on Aug 07, 2024

PrefillAdder abstraction (#968)

Parent: 6db27f7b

Changes: 5 files, with 151 additions and 135 deletions (+151, -135)
  python/sglang/srt/managers/policy_scheduler.py    +122  -0
  python/sglang/srt/managers/tp_worker.py            +27  -121
  python/sglang/srt/model_executor/model_runner.py    +2  -2
  python/sglang/srt/models/gemma2.py                  +0  -1
  python/sglang/srt/models/qwen2_moe.py               +0  -11
python/sglang/srt/managers/policy_scheduler.py (+122, -0)

@@ -17,6 +17,9 @@ limitations under the License.

import random
from collections import defaultdict
from contextlib import contextmanager

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch


class PolicyScheduler:

@@ -83,3 +86,122 @@ class PolicyScheduler:

        for child in childs:
            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
        q.extend(last_node_to_reqs[cur_node])


class PrefillAdder:
    def __init__(
        self,
        tree_cache,
        rem_total_tokens,
        rem_input_tokens,
        rem_chunk_tokens,
    ):
        self.tree_cache = tree_cache
        self.rem_total_tokens = rem_total_tokens
        self.rem_input_tokens = rem_input_tokens
        self.rem_chunk_tokens = rem_chunk_tokens

        self.can_run_list = []
        self.new_inflight_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

    def no_remaining_tokens(self):
        return (
            self.rem_total_tokens <= 0
            or self.rem_input_tokens <= 0
            or (
                self.rem_chunk_tokens <= 0
                if self.rem_chunk_tokens is not None
                else False
            )
        )

    def remove_running_tokens(
        self, running_batch: ScheduleBatch, new_token_ratio: float
    ):
        self.rem_total_tokens -= sum(
            [
                (r.sampling_params.max_new_tokens - len(r.output_ids))
                * new_token_ratio
                for r in running_batch.reqs
            ]
        )

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
        self.rem_total_tokens -= extend_input_len + max_new_tokens
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_inflight_req(self, req: Req):
        req.input_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)

        self._prefill_one_req(
            len(req.prefix_indices),
            req.extend_input_len,
            req.sampling_params.max_new_tokens if not truncated else 0,
        )

        # Return if chunked prefill not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node):
        try:
            delta = self.tree_cache.inc_lock_ref(last_node)
            self.rem_total_tokens += delta
            yield None
        finally:
            delta = self.tree_cache.dec_lock_ref(last_node)
            self.rem_total_tokens += delta

    def add_one_req(self, req: Req):
        total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return False

        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
            return False

        with self._lock_node(req.last_node):
            if total_tokens > self.rem_total_tokens:
                return False

            if (
                self.rem_chunk_tokens is None
                or input_tokens <= self.rem_chunk_tokens
                or (req.return_logprob and req.normalized_prompt_logprob is None)
            ):
                # Non-chunked prefill
                self.can_run_list.append(req)
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(
                    prefix_len, input_tokens, req.sampling_params.max_new_tokens
                )
            else:
                # Chunked prefill
                trunc_len = self.rem_chunk_tokens
                if trunc_len == 0:
                    return False

                req.extend_input_len = trunc_len
                req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
                self.can_run_list.append(req)
                self.new_inflight_req = req
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(prefix_len, trunc_len, 0)

        return True
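For orientation, here is a minimal sketch of how the new class could be exercised on its own. The DummyTreeCache and make_req helpers are hypothetical stand-ins invented for this sketch (the real scheduler passes its radix tree_cache and Req objects); only the PrefillAdder import comes from this commit, and the sketch assumes sglang is importable.

```python
from types import SimpleNamespace

from sglang.srt.managers.policy_scheduler import PrefillAdder


class DummyTreeCache:
    # Hypothetical stand-in for the radix tree cache: locking a node
    # neither frees nor reserves any KV-cache tokens here.
    def inc_lock_ref(self, node):
        return 0

    def dec_lock_ref(self, node):
        return 0


def make_req(input_len, max_new_tokens):
    # Minimal request-like object carrying only the fields PrefillAdder reads.
    return SimpleNamespace(
        input_ids=list(range(input_len)),
        extend_input_len=input_len,
        prefix_indices=[],           # no cached prefix
        last_node=None,
        return_logprob=False,
        normalized_prompt_logprob=None,
        sampling_params=SimpleNamespace(max_new_tokens=max_new_tokens),
    )


adder = PrefillAdder(
    tree_cache=DummyTreeCache(),
    rem_total_tokens=4096,   # free KV-cache tokens (inputs plus planned outputs)
    rem_input_tokens=2048,   # per-batch prefill token budget
    rem_chunk_tokens=None,   # chunked prefill disabled
)

waiting_queue = [make_req(512, 128), make_req(1024, 256), make_req(4096, 16)]
for req in waiting_queue:
    if not adder.add_one_req(req) or adder.no_remaining_tokens():
        break

# The first two requests fit the budgets; the third (4096 + 16 tokens) does not.
print(len(adder.can_run_list), adder.log_input_tokens, adder.log_hit_tokens)
# -> 2 1536 0
```

With rem_chunk_tokens=None chunked prefill is off, so the adder only enforces the KV-cache and per-batch input budgets and never sets new_inflight_req.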
python/sglang/srt/managers/tp_worker.py (+27, -121)

@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,

@@ -377,151 +377,57 @@ class ModelTpServer:
         # Get priority queue
         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)

-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
-
-        # Handle the current inflight request
-        take_inflight = 0
-        if self.current_inflight_req:
-            take_inflight = 1
-            r = self.current_inflight_req
-            r.input_ids = r.origin_input_ids + r.output_ids
-            truncated = (
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
-            )
-            r.extend_input_len = min(
-                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
-            )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
+        )
+
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
+            )

         for req in self.waiting_queue:
             if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
                     req.extend_input_len += delta
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
             if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                 # Need at least one token to compute logits
                 req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
                 if req.image_offset is not None:
                     req.image_offset += 1

-            if (
-                req.extend_input_len
-                + req.sampling_params.max_new_tokens
-                + new_batch_total_tokens
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
-            ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob
-                            and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
-                break
-
-            if running_bs + len(can_run_list) >= self.max_running_requests:
-                break
+            res = adder.add_one_req(req)
+            if (
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
+            ):
+                break
+
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req

         if len(can_run_list) == 0:
             return None

         # Print stats
         if self.tp_rank == 0:
-            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
             self.tree_cache_metrics["total"] += (
-                hit_tokens + new_batch_input_tokens
+                adder.log_input_tokens + adder.log_hit_tokens
             ) / 10**9
-            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
             tree_cache_hit_rate = (
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {new_batch_input_tokens}, "
-                f"#cached-token: {hit_tokens}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )

         # Return the new batch
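Continuing the sketch after policy_scheduler.py (reusing its hypothetical DummyTreeCache and make_req stand-ins), the snippet below illustrates the chunked-prefill handoff that the rewritten loop delegates to the adder: add_one_req truncates an oversized prompt and records it as new_inflight_req, and later scheduling rounds finish it via add_inflight_req, which is what has_inflight and current_inflight_req track above. The prefix_indices lists stand in for cached KV indices and are sized by hand for this illustration.

```python
# Round 1: a 600-token prompt against a 256-token chunk budget gets truncated.
adder = PrefillAdder(DummyTreeCache(), 10_000, 10_000, 256)
req = make_req(600, 32)
req.origin_input_ids = list(req.input_ids)   # add_inflight_req rebuilds from these
req.output_ids = []
assert adder.add_one_req(req)
assert adder.new_inflight_req is req and req.extend_input_len == 256

# Round 2: the request comes back as the current inflight request; 344 tokens
# remain, still more than one chunk, so it stays inflight (non-None return).
req.prefix_indices = list(range(256))        # the first chunk is cached now
adder = PrefillAdder(DummyTreeCache(), 10_000, 10_000, 256)
assert adder.add_inflight_req(req) is req and req.extend_input_len == 256

# Round 3: only 88 tokens are left, the chunk fits, and the handoff ends.
req.prefix_indices = list(range(512))
adder = PrefillAdder(DummyTreeCache(), 10_000, 10_000, 256)
assert adder.add_inflight_req(req) is None and req.extend_input_len == 88
```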
python/sglang/srt/model_executor/model_runner.py (+2, -2)

@@ -130,7 +130,7 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flash_infer()
+        self.init_flashinfer()

         # Capture cuda graphs
         self.init_cuda_graphs()

@@ -287,7 +287,7 @@ class ModelRunner:
         c = a @ b
         return c

-    def init_flash_infer(self):
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
python/sglang/srt/models/gemma2.py (+0, -1)

@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
python/sglang/srt/models/qwen2_moe.py (+0, -11)

@@ -46,8 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention

@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(

@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         )
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)