Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6da1ab6b
Unverified
Commit
6da1ab6b
authored
Sep 24, 2024
by
Archit Patke
Committed by
GitHub
Sep 24, 2024
Browse files
[Core] Adding Priority Scheduling (#5958)
parent
01b6f9e1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
410 additions
and
8 deletions
+410
-8
benchmarks/benchmark_prioritization.py
benchmarks/benchmark_prioritization.py
+295
-0
vllm/config.py
vllm/config.py
+4
-2
vllm/core/scheduler.py
vllm/core/scheduler.py
+77
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+20
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+10
-2
vllm/sequence.py
vllm/sequence.py
+4
-0
No files found.
benchmarks/benchmark_prioritization.py
0 → 100644
View file @
6da1ab6b
"""Benchmark offline prioritization."""
import
argparse
import
json
import
random
import
time
from
typing
import
List
,
Optional
,
Tuple
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int, int]]:
    """Sample prioritized benchmark requests from a conversation dataset.

    The dataset is a JSON list whose entries carry a ``"conversations"``
    list of turns, each turn exposing its text under ``"value"``. Only the
    first two turns are used: the first as the prompt and the second as the
    reference completion.

    Args:
        dataset_path: Path to the JSON dataset file.
        num_requests: Maximum number of requests to return.
        tokenizer: Callable whose result exposes ``input_ids``; used to
            measure prompt and completion lengths in tokens.
        fixed_output_len: If not None, used as the output length of every
            request instead of the tokenized completion length.

    Returns:
        Up to ``num_requests`` tuples of
        ``(prompt, prompt_len, output_len, priority)`` where ``priority`` is
        0 or 1, each chosen with equal probability.

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. JSON text is UTF-8, so do not depend on the locale.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompt; the completion is tokenized only when its
        # length is actually needed.
        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        # Select an equiprobable random priority.
        priority = 0 if random.random() < 0.5 else 1

        filtered_dataset.append((prompt, prompt_len, output_len, priority))

    return filtered_dataset
def run_vllm(
    requests: List[Tuple[str, int, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    """Run all requests through an offline vLLM engine and time generation.

    Args:
        requests: Tuples of ``(prompt, prompt_len, output_len, priority)``.
            ``prompt_len`` is not used here.
        model .. download_dir: Forwarded unchanged to the ``LLM``
            constructor under the same keyword names.
        n: Number of sequences generated per prompt.
        use_beam_search: Forwarded to ``SamplingParams``; also selects a
            temperature of 0.0 (otherwise 1.0).

    Returns:
        Seconds spent inside ``llm.generate``, measured with
        ``time.perf_counter``. Engine construction is not included.
    """
    from vllm import LLM, SamplingParams

    # NOTE(review): the engine's add_request raises ValueError for
    # priority > 0 unless scheduler policy is "priority"; no policy is set
    # here, so confirm the LLM is configured for priority scheduling.
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        disable_log_stats=False,
    )

    # Build the parallel per-request lists handed to the engine.
    prompts = []
    sampling_params = []
    priority = []
    for prompt, _, output_len, _priority in requests:
        prompts.append(prompt)
        priority.append(_priority)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=0.0 if use_beam_search else 1.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                # Always generate exactly ``output_len`` tokens.
                ignore_eos=True,
                max_tokens=output_len,
            ))

    start = time.perf_counter()
    llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
    end = time.perf_counter()
    return end - start
def main(args: argparse.Namespace):
    """Build the request set, run the benchmark and report throughput.

    When ``args.dataset`` is None a synthetic prompt is repeated
    ``args.num_prompts`` times; otherwise requests are sampled from the
    dataset. Throughput is printed and, if ``args.output_json`` is set,
    also written to that path as JSON.

    Args:
        args: Parsed command-line arguments.

    Raises:
        ValueError: If ``args.backend`` is not ``"vllm"``.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        # Every request must be a 4-tuple ending in a priority: both
        # run_vllm and the token count below unpack four fields. Use the
        # same equiprobable 0/1 priority as dataset-sampled requests.
        requests = [(prompt, args.input_len, args.output_len,
                     0 if random.random() < 0.5 else 1)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n,
            args.use_beam_search, args.trust_remote_code, args.dtype,
            args.max_model_len, args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.gpu_memory_utilization,
            args.download_dir)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len, _ in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
if __name__ == "__main__":
    # Command-line entry point: parse the options, validate the
    # dataset/synthetic-input combination, then run the benchmark.
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=200,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    args = parser.parse_args()

    # The tokenizer defaults to the model name.
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Validate with parser.error rather than assert: asserts are stripped
    # under ``python -O`` and give the user no usage message.
    if args.dataset is None:
        # Synthetic prompts need both lengths spelled out.
        if args.input_len is None:
            parser.error("--input-len is required when --dataset is not set")
        if args.output_len is None:
            parser.error("--output-len is required when --dataset is not set")
    elif args.input_len is not None:
        # Prompt lengths come from the dataset itself.
        parser.error("--input-len cannot be used together with --dataset")

    main(args)
vllm/config.py
View file @
6da1ab6b
...
@@ -961,7 +961,7 @@ class SchedulerConfig:
...
@@ -961,7 +961,7 @@ class SchedulerConfig:
workers instead of an entire data. It should be enabled only
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1
VLLM_USE_RAY_SPMD_WORKER=1
policy: The scheduling policy to use. "fcfs" (default) or "priority".
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -977,7 +977,8 @@ class SchedulerConfig:
...
@@ -977,7 +977,8 @@ class SchedulerConfig:
preemption_mode
:
Optional
[
str
]
=
None
,
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
,
num_scheduler_steps
:
int
=
1
,
multi_step_stream_outputs
:
bool
=
False
,
multi_step_stream_outputs
:
bool
=
False
,
send_delta_data
:
bool
=
False
)
->
None
:
send_delta_data
:
bool
=
False
,
policy
:
str
=
"fcfs"
)
->
None
:
if
max_num_batched_tokens
is
None
:
if
max_num_batched_tokens
is
None
:
if
enable_chunked_prefill
:
if
enable_chunked_prefill
:
# It is the values that have the best balance between ITL
# It is the values that have the best balance between ITL
...
@@ -1019,6 +1020,7 @@ class SchedulerConfig:
...
@@ -1019,6 +1020,7 @@ class SchedulerConfig:
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
multi_step_stream_outputs
=
multi_step_stream_outputs
self
.
multi_step_stream_outputs
=
multi_step_stream_outputs
self
.
send_delta_data
=
send_delta_data
self
.
send_delta_data
=
send_delta_data
self
.
policy
=
policy
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
...
vllm/core/scheduler.py
View file @
6da1ab6b
...
@@ -766,6 +766,79 @@ class Scheduler:
...
@@ -766,6 +766,79 @@ class Scheduler:
else
:
else
:
return
prompt_limit
return
prompt_limit
def
_get_priority
(
self
,
seq_group
:
SequenceGroup
)
->
Tuple
[
Optional
[
int
],
float
]:
""" Get the priority of the sequence group.
Highest preference to user-defined priority, followed by arrival time.
Args:
seq_group: The sequence group input.
Returns:
The priority of the sequence group.
"""
return
seq_group
.
priority
,
seq_group
.
arrival_time
def _schedule_priority_preemption(
    self,
    budget: SchedulingBudget,
) -> int:
    """Re-sort both queues and force-preempt lower-preference running work.

    The group at the head of the waiting queue is compared against the
    running groups, starting from the one whose ``_get_priority`` key is
    largest. While that running group's key is greater than the waiting
    group's key and the waiting group cannot yet be scheduled, the running
    group is preempted (recompute mode) and moved to the waiting queue.

    Args:
        budget: The scheduling budget. Updated in place: tokens and
            sequences of every preempted group are subtracted from it.

    Returns:
        The number of running groups that were force-preempted.
    """
    pending = self.waiting
    # Ascending by key, so the right end holds the largest key.
    active = deque(sorted(self.running, key=self._get_priority))
    swap_out_blocks: List[Tuple[int, int]] = []
    preempted = 0

    if pending:
        candidate = pending.popleft()
        candidate_seqs = candidate.get_max_num_running_seqs()
        candidate_tokens = self._get_num_new_tokens(candidate,
                                                    SequenceStatus.WAITING,
                                                    False, budget)

        while active:
            # Only preempt while a priority inversion exists.
            if not (self._get_priority(active[-1]) >
                    self._get_priority(candidate)):
                break

            # Only preempt if the waiting group cannot be allocated.
            alloc_status = self.block_manager.can_allocate(candidate)
            if (candidate_tokens and alloc_status == AllocStatus.OK
                    and budget.can_schedule(num_new_tokens=candidate_tokens,
                                            num_new_seqs=candidate_seqs)):
                break

            # Remove the victim's share from the budget.
            victim = active.pop()
            victim_tokens = self._get_num_new_tokens(victim,
                                                     SequenceStatus.RUNNING,
                                                     False, budget)
            budget.subtract_num_batched_tokens(victim.request_id,
                                               victim_tokens)
            victim_seqs = victim.get_max_num_running_seqs()
            budget.subtract_num_seqs(victim.request_id, victim_seqs)

            # Preempt the victim and queue it for rescheduling.
            self._preempt(victim, swap_out_blocks, PreemptionMode.RECOMPUTE)
            pending.appendleft(victim)
            preempted += 1

        # The candidate itself goes back to the waiting queue; it is not
        # scheduled here.
        pending.appendleft(candidate)

    pending = deque(sorted(pending, key=self._get_priority))

    self.waiting = pending
    self.running = active
    return preempted
def
_schedule_prefills
(
def
_schedule_prefills
(
self
,
self
,
budget
:
SchedulingBudget
,
budget
:
SchedulingBudget
,
...
@@ -917,6 +990,10 @@ class Scheduler:
...
@@ -917,6 +990,10 @@ class Scheduler:
curr_loras
,
curr_loras
,
enable_chunking
=
False
)
enable_chunking
=
False
)
if
len
(
prefills
.
seq_groups
)
==
0
and
self
.
scheduler_config
.
policy
==
"priority"
:
self
.
_schedule_priority_preemption
(
budget
)
# Don't schedule decodes if prefills are scheduled.
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
# only contains decode requests, not chunked prefills.
...
...
vllm/engine/llm_engine.py
View file @
6da1ab6b
...
@@ -631,6 +631,7 @@ class LLMEngine:
...
@@ -631,6 +631,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
self
.
_validate_model_inputs
(
processed_inputs
)
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
# Create the sequences.
...
@@ -661,7 +662,8 @@ class LLMEngine:
...
@@ -661,7 +662,8 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
elif
isinstance
(
params
,
PoolingParams
):
elif
isinstance
(
params
,
PoolingParams
):
seq_group
=
self
.
_create_sequence_group_with_pooling
(
seq_group
=
self
.
_create_sequence_group_with_pooling
(
request_id
,
request_id
,
...
@@ -670,7 +672,8 @@ class LLMEngine:
...
@@ -670,7 +672,8 @@ class LLMEngine:
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Either SamplingParams or PoolingParams must be provided."
)
"Either SamplingParams or PoolingParams must be provided."
)
...
@@ -695,6 +698,7 @@ class LLMEngine:
...
@@ -695,6 +698,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
...
@@ -713,6 +717,8 @@ class LLMEngine:
...
@@ -713,6 +717,8 @@ class LLMEngine:
arrival_time: The arrival time of the request. If None, we use
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
the current monotonic time.
trace_headers: OpenTelemetry trace headers.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Details:
Details:
- Set arrival_time to the current time if it is None.
- Set arrival_time to the current time if it is None.
...
@@ -741,6 +747,11 @@ class LLMEngine:
...
@@ -741,6 +747,11 @@ class LLMEngine:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
if
priority
>
0
and
not
self
.
scheduler_config
.
policy
==
"priority"
:
raise
ValueError
(
f
"Got priority
{
priority
}
but "
"Priority scheduling is not enabled."
)
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
...
@@ -760,6 +771,7 @@ class LLMEngine:
...
@@ -760,6 +771,7 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
priority
,
)
)
def
_create_sequence_group_with_sampling
(
def
_create_sequence_group_with_sampling
(
...
@@ -772,6 +784,7 @@ class LLMEngine:
...
@@ -772,6 +784,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
)
->
SequenceGroup
:
"""Creates a SequenceGroup with SamplingParams."""
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs
=
self
.
get_model_config
().
max_logprobs
max_logprobs
=
self
.
get_model_config
().
max_logprobs
...
@@ -798,7 +811,8 @@ class LLMEngine:
...
@@ -798,7 +811,8 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
return
seq_group
return
seq_group
...
@@ -811,6 +825,7 @@ class LLMEngine:
...
@@ -811,6 +825,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
encoder_seq
:
Optional
[
Sequence
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
)
->
SequenceGroup
:
"""Creates a SequenceGroup with PoolingParams."""
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
# Defensive copy of PoolingParams, which are used by the pooler
...
@@ -823,7 +838,8 @@ class LLMEngine:
...
@@ -823,7 +838,8 @@ class LLMEngine:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
pooling_params
=
pooling_params
,
pooling_params
=
pooling_params
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
encoder_seq
=
encoder_seq
,
priority
=
priority
)
return
seq_group
return
seq_group
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
...
...
vllm/entrypoints/llm.py
View file @
6da1ab6b
...
@@ -320,7 +320,8 @@ class LLM:
...
@@ -320,7 +320,8 @@ class LLM:
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
GuidedDecodingRequest
]]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -339,6 +340,8 @@ class LLM:
...
@@ -339,6 +340,8 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
Returns:
Returns:
A list of ``RequestOutput`` objects containing the
A list of ``RequestOutput`` objects containing the
...
@@ -379,7 +382,8 @@ class LLM:
...
@@ -379,7 +382,8 @@ class LLM:
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
guided_options
=
guided_options_request
)
guided_options
=
guided_options_request
,
priority
=
priority
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
...
@@ -782,6 +786,7 @@ class LLM:
...
@@ -782,6 +786,7 @@ class LLM:
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
...
@@ -811,6 +816,7 @@ class LLM:
...
@@ -811,6 +816,7 @@ class LLM:
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
lora_request
,
Sequence
)
else
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
[
i
]
if
priority
else
0
,
)
)
def
_add_request
(
def
_add_request
(
...
@@ -819,6 +825,7 @@ class LLM:
...
@@ -819,6 +825,7 @@ class LLM:
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
self
.
llm_engine
.
add_request
(
...
@@ -827,6 +834,7 @@ class LLM:
...
@@ -827,6 +834,7 @@ class LLM:
params
,
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
)
def
_add_guided_processor
(
def
_add_guided_processor
(
...
...
vllm/sequence.py
View file @
6da1ab6b
...
@@ -646,6 +646,7 @@ class SequenceGroup:
...
@@ -646,6 +646,7 @@ class SequenceGroup:
unless you are working with an encoder/decoder model.
unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -660,9 +661,11 @@ class SequenceGroup:
...
@@ -660,9 +661,11 @@ class SequenceGroup:
encoder_seq
:
Optional
[
Sequence
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
seqs
=
seqs
self
.
seqs
=
seqs
self
.
arrival_time
=
arrival_time
self
.
is_single_seq
=
len
(
seqs
)
==
1
self
.
is_single_seq
=
len
(
seqs
)
==
1
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
...
@@ -680,6 +683,7 @@ class SequenceGroup:
...
@@ -680,6 +683,7 @@ class SequenceGroup:
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
encoder_seq
=
encoder_seq
self
.
encoder_seq
=
encoder_seq
self
.
trace_headers
=
trace_headers
self
.
trace_headers
=
trace_headers
self
.
priority
=
priority
self
.
cached_request_output
=
None
self
.
cached_request_output
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment