Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
303 additions
and
306 deletions
+303
-306
tests/v1/sample/utils.py
tests/v1/sample/utils.py
+2
-3
tests/v1/test_utils.py
tests/v1/test_utils.py
+2
-4
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+11
-11
tests/vllm_test_utils/vllm_test_utils/blame.py
tests/vllm_test_utils/vllm_test_utils/blame.py
+2
-1
tests/vllm_test_utils/vllm_test_utils/monitor.py
tests/vllm_test_utils/vllm_test_utils/monitor.py
+2
-1
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+10
-11
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+5
-6
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+9
-11
tools/profiler/print_layerwise_table.py
tools/profiler/print_layerwise_table.py
+1
-2
tools/profiler/visualize_layerwise_profile.py
tools/profiler/visualize_layerwise_profile.py
+7
-7
vllm/_custom_ops.py
vllm/_custom_ops.py
+28
-28
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+4
-4
vllm/beam_search.py
vllm/beam_search.py
+9
-9
vllm/config.py
vllm/config.py
+71
-70
vllm/connections.py
vllm/connections.py
+2
-1
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+2
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+25
-24
vllm/entrypoints/cli/openai.py
vllm/entrypoints/cli/openai.py
+5
-5
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+1
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+105
-105
No files found.
tests/v1/sample/utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
re
from
typing
import
List
,
Tuple
from
vllm
import
CompletionOutput
def
get_test_batch
(
batch_logprobs_composition
:
str
)
->
L
ist
[
T
uple
]:
def
get_test_batch
(
batch_logprobs_composition
:
str
)
->
l
ist
[
t
uple
]:
"""Generate logprobs configs for a batch of requests
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
...
...
@@ -32,7 +31,7 @@ def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]:
Returns:
L
ist of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
l
ist of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
tuples
"""
if
batch_logprobs_composition
==
"NONE"
:
...
...
tests/v1/test_utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
torch
from
vllm.v1.utils
import
bind_kv_cache
...
...
@@ -22,7 +20,7 @@ def test_bind_kv_cache():
'layers.2.self_attn'
:
torch
.
zeros
((
1
,
)),
'layers.3.self_attn'
:
torch
.
zeros
((
1
,
)),
}
runner_kv_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
runner_kv_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
bind_kv_cache
(
kv_cache
,
ctx
,
runner_kv_caches
)
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
'layers.0.self_attn'
]
...
...
@@ -52,7 +50,7 @@ def test_bind_kv_cache_non_attention():
'model.layers.28.attn'
:
torch
.
zeros
((
1
,
)),
}
runner_kv_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
runner_kv_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
bind_kv_cache
(
kv_cache
,
ctx
,
runner_kv_caches
)
assert
ctx
[
'model.layers.20.attn'
].
kv_cache
[
0
]
is
kv_cache
[
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Optional
import
numpy
as
np
import
pytest
...
...
@@ -22,22 +22,22 @@ MAX_NUM_PROMPT_TOKENS = 64
def
_remove_requests
(
input_batch
:
InputBatch
,
batch_size
:
int
,
reqs
:
L
ist
[
CachedRequestState
])
->
T
uple
[
S
et
[
str
],
L
ist
[
int
]]:
reqs
:
l
ist
[
CachedRequestState
])
->
t
uple
[
s
et
[
str
],
l
ist
[
int
]]:
"""
Remove some requests randomly from the batch and returns a
T
uple
Remove some requests randomly from the batch and returns a
t
uple
of 1) set of request removed 2) indices of the requests removed
ordered in descending order
"""
num_reqs_to_remove
=
np
.
random
.
randint
(
0
,
batch_size
)
req_indices_to_remove
:
S
et
[
int
]
=
set
()
req_indices_to_remove
:
s
et
[
int
]
=
set
()
for
_
in
range
(
num_reqs_to_remove
):
req_index_to_remove
=
np
.
random
.
randint
(
0
,
batch_size
)
req_indices_to_remove
.
add
(
req_index_to_remove
)
req_indices_to_remove_list
=
list
(
req_indices_to_remove
)
req_indices_to_remove_list
.
sort
(
reverse
=
True
)
req_ids_to_remove
:
S
et
[
str
]
=
set
()
req_ids_to_remove
:
s
et
[
str
]
=
set
()
for
index
in
req_indices_to_remove
:
input_batch
.
remove_request
(
reqs
[
index
].
req_id
)
req_ids_to_remove
.
add
(
reqs
[
index
].
req_id
)
...
...
@@ -45,9 +45,9 @@ def _remove_requests(
def
_construct_expected_sampling_metadata
(
reqs
:
L
ist
[
CachedRequestState
],
req_ids_retained
:
S
et
[
int
],
req_id_index_in_input_batch
:
D
ict
[
str
,
int
],
reqs
:
l
ist
[
CachedRequestState
],
req_ids_retained
:
s
et
[
int
],
req_id_index_in_input_batch
:
d
ict
[
str
,
int
],
device
:
torch
.
device
,
)
->
SamplingMetadata
:
"""
...
...
@@ -55,8 +55,8 @@ def _construct_expected_sampling_metadata(
batch.
"""
num_reqs
=
len
(
req_ids_retained
)
output_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[
list
()
for
_
in
range
(
num_reqs
)]
prompt_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[
list
()
for
_
in
range
(
num_reqs
)]
output_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[
list
()
for
_
in
range
(
num_reqs
)]
prompt_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[
list
()
for
_
in
range
(
num_reqs
)]
presence_penalties
=
[
0.0
for
_
in
range
(
num_reqs
)]
frequency_penalties
=
[
0.0
for
_
in
range
(
num_reqs
)]
repetition_penalties
=
[
1.0
for
_
in
range
(
num_reqs
)]
...
...
@@ -191,7 +191,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
)
reqs
:
L
ist
[
CachedRequestState
]
=
[]
reqs
:
l
ist
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
req_id_output_token_ids
=
{}
# Add requests
...
...
tests/vllm_test_utils/vllm_test_utils/blame.py
View file @
cf069aa8
...
...
@@ -4,7 +4,8 @@ import contextlib
import
dataclasses
import
sys
import
traceback
from
typing
import
Callable
,
Generator
from
collections.abc
import
Generator
from
typing
import
Callable
@
dataclasses
.
dataclass
...
...
tests/vllm_test_utils/vllm_test_utils/monitor.py
View file @
cf069aa8
...
...
@@ -4,7 +4,8 @@ import contextlib
import
dataclasses
import
sys
import
traceback
from
typing
import
Callable
,
Generator
,
Generic
,
TypeVar
from
collections.abc
import
Generator
from
typing
import
Callable
,
Generic
,
TypeVar
_T
=
TypeVar
(
"_T"
)
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
itertools
from
typing
import
List
import
pytest
import
torch
...
...
@@ -43,7 +42,7 @@ def test_empty_seq_group():
enable_chunked_prefill
=
False
,
enforce_eager
=
True
,
)
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input_tensors
(
seq_group_metadata_list
)
(
...
...
@@ -103,9 +102,9 @@ def test_prepare_prompt(batch_size):
enforce_eager
=
True
,
)
seq_lens
:
L
ist
[
int
]
=
[]
encoder_seq_lens
:
L
ist
[
int
]
=
[]
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
encoder_seq_lens
:
l
ist
[
int
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
cross_block_table
=
[
2
]
for
i
in
range
(
batch_size
):
...
...
@@ -295,9 +294,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
enforce_eager
=
True
,
)
seq_lens
:
L
ist
[
int
]
=
[]
encoder_seq_lens
:
L
ist
[
int
]
=
[]
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
encoder_seq_lens
:
l
ist
[
int
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
],
1
:
[
3
]
...
...
@@ -503,9 +502,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
}
if
multiple_seqs_per_seq_group
else
{
0
:
[
1
]
}
seq_lens
:
L
ist
[
int
]
=
[]
encoder_seq_lens
:
L
ist
[
int
]
=
[]
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
encoder_seq_lens
:
l
ist
[
int
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
cross_block_table
=
[
2
]
expanded_batch_size
=
0
...
...
tests/worker/test_model_input.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
from
typing
import
List
,
Tuple
,
Type
import
torch
...
...
@@ -27,15 +26,15 @@ class MockAttentionBackend(AttentionBackend):
raise
NotImplementedError
@
staticmethod
def
get_metadata_cls
()
->
T
ype
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
t
ype
[
"AttentionMetadata"
]:
return
AttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
T
ype
[
"AttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
t
ype
[
"AttentionMetadataBuilder"
]:
return
AttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
T
ype
[
"CommonAttentionState"
]:
def
get_state_cls
()
->
t
ype
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
...
...
@@ -44,7 +43,7 @@ class MockAttentionBackend(AttentionBackend):
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
T
uple
[
int
,
...]:
)
->
t
uple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
...
...
@@ -57,7 +56,7 @@ class MockAttentionBackend(AttentionBackend):
@
staticmethod
def
copy_blocks
(
kv_caches
:
L
ist
[
torch
.
Tensor
],
kv_caches
:
l
ist
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
pass
...
...
tests/worker/test_model_runner.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
torch
...
...
@@ -42,8 +40,8 @@ def test_prepare_prompt(batch_size):
enable_chunked_prefill
=
False
,
)
seq_lens
:
L
ist
[
int
]
=
[]
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
...
...
@@ -159,8 +157,8 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill
=
False
,
)
context_lens
:
L
ist
[
int
]
=
[]
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
context_lens
:
l
ist
[
int
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
# Assume each seq group finishes prefill.
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
...
...
@@ -265,7 +263,7 @@ def test_empty_seq_group():
dtype
=
"float16"
,
enforce_eager
=
False
,
)
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input_tensors
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
attn_metadata
=
(
...
...
@@ -315,10 +313,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
)
# Add prefill requests.
seq_lens
:
L
ist
[
int
]
=
[]
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
prefill_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
decode_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
=
[]
seq_lens
:
l
ist
[
int
]
=
[]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
prefill_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
decode_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
prefill_batch_size
=
batch_size
//
2
decode_batch_size
=
batch_size
-
prefill_batch_size
...
...
tools/profiler/print_layerwise_table.py
View file @
cf069aa8
...
...
@@ -2,13 +2,12 @@
import
argparse
import
json
from
typing
import
Dict
from
vllm.profiler.layerwise_profile
import
ModelStatsEntry
,
SummaryStatsEntry
from
vllm.profiler.utils
import
TablePrinter
,
indent_string
def
flatten_entries
(
entry_cls
,
profile_dict
:
D
ict
):
def
flatten_entries
(
entry_cls
,
profile_dict
:
d
ict
):
entries_and_depth
=
[]
def
get_entries
(
node
,
curr_depth
=
0
):
...
...
tools/profiler/visualize_layerwise_profile.py
View file @
cf069aa8
...
...
@@ -6,7 +6,7 @@ import json
import
math
import
os
from
pathlib
import
Path
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Optional
import
matplotlib.pyplot
as
plt
import
pandas
as
pd
...
...
@@ -24,7 +24,7 @@ def largest_dist_from_leaf(node: dict, depth: int = 0):
def
get_entries_at_depth
(
depth
:
int
,
entries_and_traces
:
L
ist
[
T
uple
[
Any
,
Any
]],
entries_and_traces
:
l
ist
[
t
uple
[
Any
,
Any
]],
node
:
dict
,
curr_depth
:
int
=
0
,
trace
=
()):
...
...
@@ -48,9 +48,9 @@ def get_entries_at_depth(depth: int,
trace
=
trace
)
def
fold_nodes
(
root
:
dict
,
nodes_to_fold
:
L
ist
[
str
]):
def
fold_nodes
(
root
:
dict
,
nodes_to_fold
:
l
ist
[
str
]):
stack
:
L
ist
[
dict
]
=
[
root
]
stack
:
l
ist
[
dict
]
=
[
root
]
while
len
(
stack
)
!=
0
:
node
=
stack
.
pop
()
if
node
[
'entry'
][
'name'
]
in
nodes_to_fold
:
...
...
@@ -427,12 +427,12 @@ def main(
plot_metric
:
str
,
make_names_unique
:
bool
,
top_k
:
int
,
json_nodes_to_fold
:
L
ist
[
str
]):
json_nodes_to_fold
:
l
ist
[
str
]):
def
prepare_data
(
profile_json
:
dict
,
step_keys
:
L
ist
[
str
])
->
pd
.
DataFrame
:
def
prepare_data
(
profile_json
:
dict
,
step_keys
:
l
ist
[
str
])
->
pd
.
DataFrame
:
def
get_entries_and_traces
(
key
:
str
):
entries_and_traces
:
L
ist
[
T
uple
[
Any
,
Any
]]
=
[]
entries_and_traces
:
l
ist
[
t
uple
[
Any
,
Any
]]
=
[]
for
root
in
profile_json
[
key
][
"summary_stats"
]:
# Fold nodes in the traces as per user request. i.e. simply
# make the requested nodes leaf-nodes.
...
...
vllm/_custom_ops.py
View file @
cf069aa8
...
...
@@ -2,7 +2,7 @@
import
contextlib
import
importlib
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
import
torch.library
...
...
@@ -198,7 +198,7 @@ def rms_norm_dynamic_per_token_quant(
quant_dtype
:
torch
.
dtype
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
output
=
torch
.
empty_like
(
input
,
dtype
=
quant_dtype
)
scales
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
...
...
@@ -347,7 +347,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@
register_fake
(
"_C::aqlm_gemm"
)
def
_aqlm_gemm_fake
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
L
ist
[
int
],
codebook_partition_sizes
:
l
ist
[
int
],
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
out_features
=
codes
.
size
(
0
)
*
codebooks
.
size
(
2
)
flat_input
=
input
.
reshape
((
-
1
,
input
.
size
(
-
1
)))
...
...
@@ -363,7 +363,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@
register_fake
(
"_C::aqlm_dequant"
)
def
_aqlm_dequant_fake
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
L
ist
[
int
])
->
torch
.
Tensor
:
codebook_partition_sizes
:
l
ist
[
int
])
->
torch
.
Tensor
:
in_features
=
codes
.
size
(
1
)
*
8
out_features
=
codes
.
size
(
0
)
return
torch
.
empty
((
out_features
,
in_features
),
...
...
@@ -554,7 +554,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
def
cutlass_sparse_compress
(
a
:
torch
.
Tensor
)
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Compresses a sparse matrix for use with Cutlass sparse operations.
...
...
@@ -571,7 +571,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
- `torch.float16`
Returns:
T
uple[torch.Tensor, torch.Tensor]:
t
uple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
...
...
@@ -646,14 +646,14 @@ def cutlass_scaled_sparse_mm(
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
L
ist
[
int
],
codebook_partition_sizes
:
l
ist
[
int
],
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
L
ist
[
int
])
->
torch
.
Tensor
:
codebook_partition_sizes
:
l
ist
[
int
])
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
...
...
@@ -738,7 +738,7 @@ def machete_supported_schedules(
group_zeros_type
:
Optional
[
torch
.
dtype
]
=
None
,
channel_scales_type
:
Optional
[
torch
.
dtype
]
=
None
,
token_scales_type
:
Optional
[
torch
.
dtype
]
=
None
,
out_type
:
Optional
[
torch
.
dtype
]
=
None
)
->
L
ist
[
str
]:
out_type
:
Optional
[
torch
.
dtype
]
=
None
)
->
l
ist
[
str
]:
return
torch
.
ops
.
_C
.
machete_supported_schedules
(
a_type
,
b_type
.
id
,
group_scales_type
,
group_zeros_type
,
channel_scales_type
,
token_scales_type
,
out_type
)
...
...
@@ -783,7 +783,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
# fp4
def
scaled_fp4_quant
(
input
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
input_global_scale
:
torch
.
Tensor
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
...
...
@@ -798,7 +798,7 @@ def scaled_fp4_quant(
input_global_scale: A scalar scaling factor for the entire tensor.
Returns:
T
uple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
t
uple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
two values are packed into a uint8 and float8_e4m3 scaling factors
in the sizzled layout.
"""
...
...
@@ -845,7 +845,7 @@ def scaled_fp8_quant(
num_token_padding
:
Optional
[
int
]
=
None
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
...
...
@@ -866,12 +866,12 @@ def scaled_fp8_quant(
in the dynamic quantization case.
Returns:
T
uple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
t
uple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
T
uple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
shape
:
Union
[
t
uple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype
:
torch
.
dtype
=
torch
.
float8_e4m3fnuz
\
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
...
...
@@ -903,7 +903,7 @@ def allspark_repack_weight(
scale
:
torch
.
Tensor
,
zero_point
:
Optional
[
torch
.
Tensor
]
=
None
,
has_zp
:
bool
=
False
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
for Ampere W8A16 Fused Gemm kernel
...
...
@@ -917,7 +917,7 @@ def allspark_repack_weight(
if use asymmetric quantization, has_zp = True.
Returns:
T
uple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
t
uple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
rearranged weight, scale, and optionally zero_point.
"""
K
=
qweight
.
shape
[
0
]
...
...
@@ -964,7 +964,7 @@ def scaled_int8_quant(
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
azp
:
Optional
[
torch
.
Tensor
]
=
None
,
symmetric
:
bool
=
True
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
...
...
@@ -977,7 +977,7 @@ def scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
T
uple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
t
uple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
"""
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
int8
)
if
scale
is
not
None
:
...
...
@@ -1165,13 +1165,13 @@ def concat_and_cache_mla(
scale
)
def
copy_blocks
(
key_caches
:
L
ist
[
torch
.
Tensor
],
value_caches
:
L
ist
[
torch
.
Tensor
],
def
copy_blocks
(
key_caches
:
l
ist
[
torch
.
Tensor
],
value_caches
:
l
ist
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
def
copy_blocks_mla
(
kv_caches
:
L
ist
[
torch
.
Tensor
],
def
copy_blocks_mla
(
kv_caches
:
l
ist
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
copy_blocks_mla
(
kv_caches
,
block_mapping
)
...
...
@@ -1209,7 +1209,7 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# custom ar
def
init_custom_ar
(
ipc_tensors
:
L
ist
[
torch
.
Tensor
],
rank_data
:
torch
.
Tensor
,
def
init_custom_ar
(
ipc_tensors
:
l
ist
[
torch
.
Tensor
],
rank_data
:
torch
.
Tensor
,
rank
:
int
,
full_nvlink
:
bool
)
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
init_custom_ar
(
ipc_tensors
,
rank_data
,
rank
,
full_nvlink
)
...
...
@@ -1229,16 +1229,16 @@ def meta_size() -> int:
return
torch
.
ops
.
_C_custom_ar
.
meta_size
()
def
register_buffer
(
fa
:
int
,
ipc_tensors
:
L
ist
[
int
])
->
None
:
def
register_buffer
(
fa
:
int
,
ipc_tensors
:
l
ist
[
int
])
->
None
:
return
torch
.
ops
.
_C_custom_ar
.
register_buffer
(
fa
,
ipc_tensors
)
def
get_graph_buffer_ipc_meta
(
fa
:
int
)
->
T
uple
[
L
ist
[
int
],
L
ist
[
int
]]:
def
get_graph_buffer_ipc_meta
(
fa
:
int
)
->
t
uple
[
l
ist
[
int
],
l
ist
[
int
]]:
return
torch
.
ops
.
_C_custom_ar
.
get_graph_buffer_ipc_meta
(
fa
)
def
register_graph_buffers
(
fa
:
int
,
handles
:
L
ist
[
L
ist
[
int
]],
offsets
:
L
ist
[
L
ist
[
int
]])
->
None
:
def
register_graph_buffers
(
fa
:
int
,
handles
:
l
ist
[
l
ist
[
int
]],
offsets
:
l
ist
[
l
ist
[
int
]])
->
None
:
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
...
...
@@ -1246,7 +1246,7 @@ def get_flash_mla_metadata(
cache_seqlens
:
torch
.
Tensor
,
num_heads_per_head_k
:
int
,
num_heads_k
:
int
,
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
...
...
@@ -1272,7 +1272,7 @@ def flash_mla_with_kvcache(
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
...
...
vllm/_ipex_ops.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
torch
...
...
@@ -18,7 +18,7 @@ class ipex_ops:
@
staticmethod
def
_reshape_activation_tensor
(
x
:
torch
.
Tensor
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
:
torch
.
Tensor
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num
=
x
.
size
(
0
)
d
=
x
.
size
(
1
)
//
2
x
=
x
.
reshape
(
num
,
2
,
d
)
...
...
@@ -213,8 +213,8 @@ class ipex_ops:
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
@
staticmethod
def
copy_blocks
(
key_caches
:
L
ist
[
torch
.
Tensor
],
value_caches
:
L
ist
[
torch
.
Tensor
],
def
copy_blocks
(
key_caches
:
l
ist
[
torch
.
Tensor
],
value_caches
:
l
ist
[
torch
.
Tensor
],
block_mapping
:
torch
.
Tensor
)
->
None
:
torch
.
xpu
.
copy_blocks
(
# type: ignore
key_caches
,
...
...
vllm/beam_search.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
from
vllm.sequence
import
Logprob
...
...
@@ -17,14 +17,14 @@ class BeamSearchSequence:
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens
:
L
ist
[
int
]
logprobs
:
L
ist
[
D
ict
[
int
,
Logprob
]]
tokens
:
l
ist
[
int
]
logprobs
:
l
ist
[
d
ict
[
int
,
Logprob
]]
cum_logprob
:
float
=
0.0
text
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
mm_processor_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
mm_processor_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
@
dataclass
...
...
@@ -33,20 +33,20 @@ class BeamSearchOutput:
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences
:
L
ist
[
BeamSearchSequence
]
sequences
:
l
ist
[
BeamSearchSequence
]
class
BeamSearchInstance
:
def
__init__
(
self
,
prompt_tokens
:
L
ist
[
int
]):
self
.
beams
:
L
ist
[
BeamSearchSequence
]
=
[
def
__init__
(
self
,
prompt_tokens
:
l
ist
[
int
]):
self
.
beams
:
l
ist
[
BeamSearchSequence
]
=
[
BeamSearchSequence
(
tokens
=
prompt_tokens
,
logprobs
=
[])
]
self
.
completed
:
L
ist
[
BeamSearchSequence
]
=
[]
self
.
completed
:
l
ist
[
BeamSearchSequence
]
=
[]
def
get_beam_search_score
(
tokens
:
L
ist
[
int
],
tokens
:
l
ist
[
int
],
cumulative_logprob
:
float
,
eos_token_id
:
int
,
length_penalty
:
float
=
1.0
,
...
...
vllm/config.py
View file @
cf069aa8
...
...
@@ -7,13 +7,14 @@ import hashlib
import
json
import
sys
import
warnings
from
collections
import
Counter
from
collections.abc
import
Mapping
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
,
replace
from
importlib.util
import
find_spec
from
pathlib
import
Path
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Counter
,
Dict
,
Final
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Set
,
Tuple
,
Type
,
Union
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Final
,
Literal
,
Optional
,
Protocol
,
Union
)
import
torch
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
...
...
@@ -67,20 +68,20 @@ _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
RunnerType
=
Literal
[
"generate"
,
"pooling"
,
"draft"
,
"transcription"
]
_RUNNER_TASKS
:
D
ict
[
RunnerType
,
L
ist
[
_ResolvedTask
]]
=
{
_RUNNER_TASKS
:
d
ict
[
RunnerType
,
l
ist
[
_ResolvedTask
]]
=
{
"generate"
:
[
"generate"
],
"pooling"
:
[
"embed"
,
"classify"
,
"score"
,
"reward"
],
"draft"
:
[
"draft"
],
"transcription"
:
[
"transcription"
],
}
_TASK_RUNNER
:
D
ict
[
_ResolvedTask
,
RunnerType
]
=
{
_TASK_RUNNER
:
d
ict
[
_ResolvedTask
,
RunnerType
]
=
{
task
:
runner
for
runner
,
tasks
in
_RUNNER_TASKS
.
items
()
for
task
in
tasks
}
HfOverrides
=
Union
[
D
ict
[
str
,
Any
],
Callable
[[
PretrainedConfig
],
HfOverrides
=
Union
[
d
ict
[
str
,
Any
],
Callable
[[
PretrainedConfig
],
PretrainedConfig
]]
...
...
@@ -92,7 +93,7 @@ class SupportsHash(Protocol):
class
SupportsMetricsInfo
(
Protocol
):
def
metrics_info
(
self
)
->
D
ict
[
str
,
str
]:
def
metrics_info
(
self
)
->
d
ict
[
str
,
str
]:
...
...
...
@@ -209,7 +210,7 @@ class ModelConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
factors
.
append
(
self
.
model
)
factors
.
append
(
self
.
dtype
)
factors
.
append
(
self
.
quantization
)
...
...
@@ -233,7 +234,7 @@ class ModelConfig:
allowed_local_media_path
:
str
=
""
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
...
...
@@ -244,19 +245,19 @@ class ModelConfig:
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
L
ist
[
str
]]]
=
None
,
served_model_name
:
Optional
[
Union
[
str
,
l
ist
[
str
]]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
use_async_output_proc
:
bool
=
True
,
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
disable_mm_preprocessor_cache
:
bool
=
False
,
override_neuron_config
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
override_neuron_config
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
override_pooler_config
:
Optional
[
"PoolerConfig"
]
=
None
,
logits_processor_pattern
:
Optional
[
str
]
=
None
,
generation_config
:
Optional
[
str
]
=
None
,
enable_sleep_mode
:
bool
=
False
,
override_generation_config
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
override_generation_config
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
model_impl
:
Union
[
str
,
ModelImpl
]
=
ModelImpl
.
AUTO
,
)
->
None
:
self
.
model
=
model
...
...
@@ -283,7 +284,7 @@ class ModelConfig:
hf_overrides_fn
=
None
if
rope_scaling
is
not
None
:
hf_override
:
D
ict
[
str
,
Any
]
=
{
"rope_scaling"
:
rope_scaling
}
hf_override
:
d
ict
[
str
,
Any
]
=
{
"rope_scaling"
:
rope_scaling
}
hf_overrides_kw
.
update
(
hf_override
)
msg
=
(
"`--rope-scaling` will be removed in a future release. "
f
"'Please instead use `--hf-overrides '
{
hf_override
!
r
}
'`"
)
...
...
@@ -505,8 +506,8 @@ class ModelConfig:
def
_get_preferred_task
(
self
,
architectures
:
L
ist
[
str
],
supported_tasks
:
S
et
[
_ResolvedTask
],
architectures
:
l
ist
[
str
],
supported_tasks
:
s
et
[
_ResolvedTask
],
)
->
Optional
[
_ResolvedTask
]:
model_id
=
self
.
model
if
get_pooling_config
(
model_id
,
self
.
revision
):
...
...
@@ -516,7 +517,7 @@ class ModelConfig:
if
self
.
registry
.
is_transcription_model
(
architectures
):
return
"transcription"
suffix_to_preferred_task
:
L
ist
[
T
uple
[
str
,
_ResolvedTask
]]
=
[
suffix_to_preferred_task
:
l
ist
[
t
uple
[
str
,
_ResolvedTask
]]
=
[
# Other models follow this pattern
(
"ForCausalLM"
,
"generate"
),
(
"ForConditionalGeneration"
,
"generate"
),
...
...
@@ -537,27 +538,27 @@ class ModelConfig:
def
_resolve_task
(
self
,
task_option
:
Union
[
TaskOption
,
Literal
[
"draft"
]],
)
->
T
uple
[
S
et
[
_ResolvedTask
],
_ResolvedTask
]:
)
->
t
uple
[
s
et
[
_ResolvedTask
],
_ResolvedTask
]:
if
task_option
==
"draft"
:
return
{
"draft"
},
"draft"
registry
=
self
.
registry
architectures
=
self
.
architectures
runner_support
:
D
ict
[
RunnerType
,
bool
]
=
{
runner_support
:
d
ict
[
RunnerType
,
bool
]
=
{
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"transcription"
:
registry
.
is_transcription_model
(
architectures
),
"generate"
:
registry
.
is_text_generation_model
(
architectures
),
"pooling"
:
registry
.
is_pooling_model
(
architectures
),
}
supported_runner_types_lst
:
L
ist
[
RunnerType
]
=
[
supported_runner_types_lst
:
l
ist
[
RunnerType
]
=
[
runner_type
for
runner_type
,
is_supported
in
runner_support
.
items
()
if
is_supported
]
supported_tasks_lst
:
L
ist
[
_ResolvedTask
]
=
[
supported_tasks_lst
:
l
ist
[
_ResolvedTask
]
=
[
task
for
runner_type
in
supported_runner_types_lst
for
task
in
_RUNNER_TASKS
[
runner_type
]
]
...
...
@@ -767,7 +768,7 @@ class ModelConfig:
self
.
use_async_output_proc
=
False
def
get_hf_config_sliding_window
(
self
)
->
Union
[
Optional
[
int
],
L
ist
[
Optional
[
int
]]]:
self
)
->
Union
[
Optional
[
int
],
l
ist
[
Optional
[
int
]]]:
"""Get the sliding window size, or None if disabled."""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
...
...
@@ -778,7 +779,7 @@ class ModelConfig:
return
None
return
getattr
(
self
.
hf_text_config
,
"sliding_window"
,
None
)
def
get_sliding_window
(
self
)
->
Optional
[
Union
[
int
,
L
ist
[
Optional
[
int
]]]]:
def
get_sliding_window
(
self
)
->
Optional
[
Union
[
int
,
l
ist
[
Optional
[
int
]]]]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
...
...
@@ -888,7 +889,7 @@ class ModelConfig:
return
num_heads
//
parallel_config
.
tensor_parallel_size
def
get_layers_start_end_indices
(
self
,
parallel_config
:
"ParallelConfig"
)
->
T
uple
[
int
,
int
]:
self
,
parallel_config
:
"ParallelConfig"
)
->
t
uple
[
int
,
int
]:
from
vllm.distributed.utils
import
get_pp_indices
if
self
.
hf_text_config
.
model_type
==
"deepseek_mtp"
:
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
...
...
@@ -949,7 +950,7 @@ class ModelConfig:
return
self
.
multimodal_config
def
try_get_generation_config
(
self
)
->
D
ict
[
str
,
Any
]:
def
try_get_generation_config
(
self
)
->
d
ict
[
str
,
Any
]:
if
self
.
generation_config
is
None
or
self
.
generation_config
==
"auto"
:
config
=
try_get_generation_config
(
self
.
hf_config_path
or
self
.
model
,
...
...
@@ -967,7 +968,7 @@ class ModelConfig:
return
config
.
to_diff_dict
()
def
get_diff_sampling_param
(
self
)
->
D
ict
[
str
,
Any
]:
def
get_diff_sampling_param
(
self
)
->
d
ict
[
str
,
Any
]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
...
...
@@ -975,7 +976,7 @@ class ModelConfig:
set, an empty dictionary is returned.
Returns:
D
ict[str, Any]: A dictionary with the differing sampling
d
ict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
"""
...
...
@@ -1032,7 +1033,7 @@ class ModelConfig:
return
self
.
is_deepseek_mla
and
not
envs
.
VLLM_MLA_DISABLE
@
property
def
supported_runner_types
(
self
)
->
S
et
[
RunnerType
]:
def
supported_runner_types
(
self
)
->
s
et
[
RunnerType
]:
return
{
_TASK_RUNNER
[
task
]
for
task
in
self
.
supported_tasks
}
@
property
...
...
@@ -1075,7 +1076,7 @@ class CacheConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
factors
.
append
(
self
.
cache_dtype
)
# `cpu_offload_gb` does not use `torch.compile` yet.
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
...
...
@@ -1183,7 +1184,7 @@ class TokenizerPoolConfig:
pool type.
"""
pool_size
:
int
pool_type
:
Union
[
str
,
T
ype
[
"BaseTokenizerGroup"
]]
pool_type
:
Union
[
str
,
t
ype
[
"BaseTokenizerGroup"
]]
extra_config
:
dict
def
compute_hash
(
self
)
->
str
:
...
...
@@ -1200,7 +1201,7 @@ class TokenizerPoolConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -1214,7 +1215,7 @@ class TokenizerPoolConfig:
@
classmethod
def
create_config
(
cls
,
tokenizer_pool_size
:
int
,
tokenizer_pool_type
:
Union
[
str
,
T
ype
[
"BaseTokenizerGroup"
]],
tokenizer_pool_type
:
Union
[
str
,
t
ype
[
"BaseTokenizerGroup"
]],
tokenizer_pool_extra_config
:
Optional
[
Union
[
str
,
dict
]]
)
->
Optional
[
"TokenizerPoolConfig"
]:
"""Create a TokenizerPoolConfig from the given parameters.
...
...
@@ -1285,7 +1286,7 @@ class LoadConfig:
download_dir
:
Optional
[
str
]
=
None
model_loader_extra_config
:
Optional
[
Union
[
str
,
dict
]]
=
field
(
default_factory
=
dict
)
ignore_patterns
:
Optional
[
Union
[
L
ist
[
str
],
str
]]
=
None
ignore_patterns
:
Optional
[
Union
[
l
ist
[
str
],
str
]]
=
None
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -1301,7 +1302,7 @@ class LoadConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -1359,7 +1360,7 @@ class ParallelConfig:
# to "ray" if Ray is installed and fail otherwise. Note that tpu
# and hpu only support Ray for distributed inference.
distributed_executor_backend
:
Optional
[
Union
[
str
,
T
ype
[
"ExecutorBase"
]]]
=
None
t
ype
[
"ExecutorBase"
]]]
=
None
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
...
...
@@ -1423,7 +1424,7 @@ class ParallelConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
factors
.
append
(
self
.
pipeline_parallel_size
)
factors
.
append
(
self
.
tensor_parallel_size
)
return
hashlib
.
sha256
(
str
(
factors
).
encode
()).
hexdigest
()
...
...
@@ -1600,7 +1601,7 @@ class SchedulerConfig:
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls
:
Union
[
str
,
T
ype
[
object
]]
=
"vllm.core.scheduler.Scheduler"
scheduler_cls
:
Union
[
str
,
t
ype
[
object
]]
=
"vllm.core.scheduler.Scheduler"
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -1616,7 +1617,7 @@ class SchedulerConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -1752,7 +1753,7 @@ class DeviceConfig:
# no factors to consider.
# the device/platform information will be summarized
# by torch/vllm automatically.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -1798,7 +1799,7 @@ class SpeculativeConfig:
"""
# no factors to consider.
# spec decode does not use `torch.compile` yet.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2261,7 +2262,7 @@ class LoRAConfig:
lora_extra_vocab_size
:
int
=
256
# This is a constant.
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
long_lora_scaling_factors
:
Optional
[
T
uple
[
float
]]
=
None
long_lora_scaling_factors
:
Optional
[
t
uple
[
float
]]
=
None
bias_enabled
:
bool
=
False
def
compute_hash
(
self
)
->
str
:
...
...
@@ -2278,7 +2279,7 @@ class LoRAConfig:
"""
# no factors to consider.
# LoRA is not compatible with `torch.compile` .
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2350,7 +2351,7 @@ class PromptAdapterConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2395,7 +2396,7 @@ class MultiModalConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2431,7 +2432,7 @@ class PoolerConfig:
are returned.
"""
returned_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
returned_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
"""
A list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of ``good_token`` and ``bad_token`` in the
...
...
@@ -2452,7 +2453,7 @@ class PoolerConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2469,7 +2470,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16"
:
torch
.
bfloat16
,
}
_ROCM_NOT_SUPPORTED_DTYPE
:
L
ist
[
str
]
=
[]
#
_ROCM_NOT_SUPPORTED_DTYPE
:
l
ist
[
str
]
=
[]
#
def
_get_and_verify_dtype
(
...
...
@@ -2558,7 +2559,7 @@ def _get_and_verify_max_len(
hf_config
:
PretrainedConfig
,
max_model_len
:
Optional
[
int
],
disable_sliding_window
:
bool
,
sliding_window_len
:
Optional
[
Union
[
int
,
L
ist
[
Optional
[
int
]]]],
sliding_window_len
:
Optional
[
Union
[
int
,
l
ist
[
Optional
[
int
]]]],
spec_target_max_model_len
:
Optional
[
int
]
=
None
,
encoder_config
:
Optional
[
Any
]
=
None
,
)
->
int
:
...
...
@@ -2684,7 +2685,7 @@ def _get_and_verify_max_len(
def
get_min_sliding_window
(
sliding_window
:
Union
[
int
,
L
ist
[
Optional
[
int
]]])
->
int
:
sliding_window
:
Union
[
int
,
l
ist
[
Optional
[
int
]]])
->
int
:
if
isinstance
(
sliding_window
,
list
):
return
min
(
s
for
s
in
sliding_window
if
s
is
not
None
)
...
...
@@ -2692,7 +2693,7 @@ def get_min_sliding_window(
def
get_served_model_name
(
model
:
str
,
served_model_name
:
Optional
[
Union
[
str
,
L
ist
[
str
]]]):
served_model_name
:
Optional
[
Union
[
str
,
l
ist
[
str
]]]):
"""
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
...
...
@@ -2731,7 +2732,7 @@ class DecodingConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2774,7 +2775,7 @@ class ObservabilityConfig:
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2833,7 +2834,7 @@ class KVTransferConfig(BaseModel):
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
return
hash_str
...
...
@@ -2930,7 +2931,7 @@ class CompilationConfig(BaseModel):
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
-
L
ist[int]: capture sizes are specified as given.
-
l
ist[int]: capture sizes are specified as given.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
...
...
@@ -2972,17 +2973,17 @@ class CompilationConfig(BaseModel):
debug_dump_path
:
str
=
""
cache_dir
:
str
=
""
backend
:
str
=
""
custom_ops
:
L
ist
[
str
]
=
Field
(
default_factory
=
list
)
splitting_ops
:
L
ist
[
str
]
=
Field
(
default
=
None
)
# type: ignore
custom_ops
:
l
ist
[
str
]
=
Field
(
default_factory
=
list
)
splitting_ops
:
l
ist
[
str
]
=
Field
(
default
=
None
)
# type: ignore
use_inductor
:
bool
=
True
compile_sizes
:
Optional
[
L
ist
[
Union
[
int
,
str
]]]
=
Field
(
default
=
None
)
inductor_compile_config
:
D
ict
=
Field
(
default_factory
=
dict
)
inductor_passes
:
D
ict
[
str
,
str
]
=
Field
(
default_factory
=
dict
)
compile_sizes
:
Optional
[
l
ist
[
Union
[
int
,
str
]]]
=
Field
(
default
=
None
)
inductor_compile_config
:
d
ict
=
Field
(
default_factory
=
dict
)
inductor_passes
:
d
ict
[
str
,
str
]
=
Field
(
default_factory
=
dict
)
use_cudagraph
:
bool
=
False
cudagraph_num_of_warmups
:
int
=
0
cudagraph_capture_sizes
:
Optional
[
L
ist
[
int
]]
=
None
cudagraph_capture_sizes
:
Optional
[
l
ist
[
int
]]
=
None
cudagraph_copy_inputs
:
bool
=
False
class
PassConfig
(
BaseModel
):
...
...
@@ -2998,7 +2999,7 @@ class CompilationConfig(BaseModel):
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
"""
dump_graph_stages
:
L
ist
[
str
]
=
Field
(
default_factory
=
list
)
dump_graph_stages
:
l
ist
[
str
]
=
Field
(
default_factory
=
list
)
dump_graph_dir
:
Path
=
Field
(
default
=
Path
(
"."
))
enable_fusion
:
bool
=
True
enable_noop
:
bool
=
True
...
...
@@ -3026,20 +3027,20 @@ class CompilationConfig(BaseModel):
max_capture_size
:
int
=
PrivateAttr
local_cache_dir
:
str
=
PrivateAttr
# local cache dir for each rank
# optimization:
# Intuitively, bs_to_padded_graph_size should be
D
ict[int, int].
# Intuitively, bs_to_padded_graph_size should be
d
ict[int, int].
# since we know all keys are in a range [0, max_capture_size],
# we can optimize it to
L
ist[int] for better lookup performance.
bs_to_padded_graph_size
:
L
ist
[
int
]
=
PrivateAttr
# we can optimize it to
l
ist[int] for better lookup performance.
bs_to_padded_graph_size
:
l
ist
[
int
]
=
PrivateAttr
# keep track of enabled and disabled custom ops
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
traced_files
:
S
et
[
str
]
=
PrivateAttr
traced_files
:
s
et
[
str
]
=
PrivateAttr
compilation_time
:
float
=
PrivateAttr
# Per-model forward context
# Map from layer name to the attention cls
static_forward_context
:
D
ict
[
str
,
Any
]
=
PrivateAttr
static_forward_context
:
d
ict
[
str
,
Any
]
=
PrivateAttr
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -3053,7 +3054,7 @@ class CompilationConfig(BaseModel):
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
factors
.
append
(
self
.
level
)
factors
.
append
(
self
.
backend
)
factors
.
append
(
self
.
custom_ops
)
...
...
@@ -3150,7 +3151,7 @@ class CompilationConfig(BaseModel):
return
VllmBackend
(
vllm_config
)
def
init_with_cudagraph_sizes
(
self
,
cudagraph_capture_sizes
:
L
ist
[
int
])
->
None
:
cudagraph_capture_sizes
:
l
ist
[
int
])
->
None
:
"""To complete the initialization of config,
we need to know the cudagraph sizes."""
...
...
@@ -3243,10 +3244,10 @@ class VllmConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors
:
L
ist
[
Any
]
=
[]
factors
:
l
ist
[
Any
]
=
[]
# summarize vllm config
vllm_factors
:
L
ist
[
Any
]
=
[]
vllm_factors
:
l
ist
[
Any
]
=
[]
from
vllm
import
__version__
vllm_factors
.
append
(
__version__
)
if
self
.
model_config
:
...
...
vllm/connections.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Mapping
,
MutableMapping
from
pathlib
import
Path
from
typing
import
Mapping
,
MutableMapping
,
Optional
from
typing
import
Optional
from
urllib.parse
import
urlparse
import
aiohttp
...
...
vllm/entrypoints/api_server.py
View file @
cf069aa8
...
...
@@ -10,7 +10,8 @@ import asyncio
import
json
import
ssl
from
argparse
import
Namespace
from
typing
import
Any
,
AsyncGenerator
,
Optional
from
collections.abc
import
AsyncGenerator
from
typing
import
Any
,
Optional
from
fastapi
import
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
...
...
vllm/entrypoints/chat_utils.py
View file @
cf069aa8
...
...
@@ -5,10 +5,11 @@ import codecs
import
json
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Awaitable
,
Iterable
from
functools
import
cache
,
lru_cache
,
partial
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
I
tera
ble
,
List
,
Literal
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
from
typing
import
(
Any
,
Callable
,
Generic
,
Li
tera
l
,
Optional
,
TypeVar
,
Union
,
cast
)
import
jinja2.nodes
import
transformers.utils.chat_template_utils
as
hf_chat_utils
...
...
@@ -117,7 +118,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Union
[
str
,
L
ist
[
ChatCompletionContentPartParam
]]
content
:
Union
[
str
,
l
ist
[
ChatCompletionContentPartParam
]]
"""The contents of the message."""
name
:
str
...
...
@@ -143,7 +144,7 @@ class ConversationMessage(TypedDict, total=False):
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Union
[
Optional
[
str
],
L
ist
[
D
ict
[
str
,
str
]]]
content
:
Union
[
Optional
[
str
],
l
ist
[
d
ict
[
str
,
str
]]]
"""The contents of the message"""
tool_call_id
:
Optional
[
str
]
...
...
@@ -495,13 +496,13 @@ class BaseMultiModalContentParser(ABC):
super
().
__init__
()
# multimodal placeholder_string : count
self
.
_placeholder_counts
:
D
ict
[
str
,
int
]
=
defaultdict
(
lambda
:
0
)
self
.
_placeholder_counts
:
d
ict
[
str
,
int
]
=
defaultdict
(
lambda
:
0
)
def
_add_placeholder
(
self
,
placeholder
:
Optional
[
str
]):
if
placeholder
:
self
.
_placeholder_counts
[
placeholder
]
+=
1
def
mm_placeholder_counts
(
self
)
->
D
ict
[
str
,
int
]:
def
mm_placeholder_counts
(
self
)
->
d
ict
[
str
,
int
]:
return
dict
(
self
.
_placeholder_counts
)
@
abstractmethod
...
...
@@ -652,12 +653,12 @@ def load_chat_template(
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def
_get_full_multimodal_text_prompt
(
placeholder_counts
:
D
ict
[
str
,
int
],
def
_get_full_multimodal_text_prompt
(
placeholder_counts
:
d
ict
[
str
,
int
],
text_prompt
:
str
)
->
str
:
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
missing_placeholders
:
L
ist
[
str
]
=
[]
missing_placeholders
:
l
ist
[
str
]
=
[]
for
placeholder
in
placeholder_counts
:
# For any existing placeholder in the text prompt, we leave it as is
...
...
@@ -684,10 +685,10 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_VideoParser
=
partial
(
cast
,
ChatCompletionContentPartVideoParam
)
_ContentPart
:
TypeAlias
=
Union
[
str
,
D
ict
[
str
,
str
],
InputAudio
]
_ContentPart
:
TypeAlias
=
Union
[
str
,
d
ict
[
str
,
str
],
InputAudio
]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP
:
D
ict
[
MM_PARSER_MAP
:
d
ict
[
str
,
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
...
...
@@ -749,7 +750,7 @@ def _parse_chat_message_content_mm_part(
part
)
return
"audio_url"
,
audio_params
.
get
(
"audio_url"
,
""
)
if
part
.
get
(
"input_audio"
)
is
not
None
:
input_audio_params
=
cast
(
D
ict
[
str
,
str
],
part
)
input_audio_params
=
cast
(
d
ict
[
str
,
str
],
part
)
return
"input_audio"
,
input_audio_params
if
part
.
get
(
"video_url"
)
is
not
None
:
video_params
=
cast
(
CustomChatCompletionContentSimpleVideoParam
,
...
...
@@ -773,7 +774,7 @@ def _parse_chat_message_content_parts(
mm_tracker
:
BaseMultiModalItemTracker
,
*
,
wrap_dicts
:
bool
,
)
->
L
ist
[
ConversationMessage
]:
)
->
l
ist
[
ConversationMessage
]:
content
=
list
[
_ContentPart
]()
mm_parser
=
mm_tracker
.
create_parser
()
...
...
@@ -791,7 +792,7 @@ def _parse_chat_message_content_parts(
# Parsing wraps images and texts as interleaved dictionaries
return
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
# type: ignore
texts
=
cast
(
L
ist
[
str
],
content
)
texts
=
cast
(
l
ist
[
str
],
content
)
text_prompt
=
"
\n
"
.
join
(
texts
)
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
mm_placeholder_counts
:
...
...
@@ -866,7 +867,7 @@ def _parse_chat_message_content(
message
:
ChatCompletionMessageParam
,
mm_tracker
:
BaseMultiModalItemTracker
,
content_format
:
_ChatTemplateContentFormat
,
)
->
L
ist
[
ConversationMessage
]:
)
->
l
ist
[
ConversationMessage
]:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
...
...
@@ -900,7 +901,7 @@ def _parse_chat_message_content(
return
result
def
_postprocess_messages
(
messages
:
L
ist
[
ConversationMessage
])
->
None
:
def
_postprocess_messages
(
messages
:
l
ist
[
ConversationMessage
])
->
None
:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
...
...
@@ -916,12 +917,12 @@ def _postprocess_messages(messages: List[ConversationMessage]) -> None:
def
parse_chat_messages
(
messages
:
L
ist
[
ChatCompletionMessageParam
],
messages
:
l
ist
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
content_format
:
_ChatTemplateContentFormat
,
)
->
T
uple
[
L
ist
[
ConversationMessage
],
Optional
[
MultiModalDataDict
]]:
conversation
:
L
ist
[
ConversationMessage
]
=
[]
)
->
t
uple
[
l
ist
[
ConversationMessage
],
Optional
[
MultiModalDataDict
]]:
conversation
:
l
ist
[
ConversationMessage
]
=
[]
mm_tracker
=
MultiModalItemTracker
(
model_config
,
tokenizer
)
for
msg
in
messages
:
...
...
@@ -939,12 +940,12 @@ def parse_chat_messages(
def
parse_chat_messages_futures
(
messages
:
L
ist
[
ChatCompletionMessageParam
],
messages
:
l
ist
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
content_format
:
_ChatTemplateContentFormat
,
)
->
T
uple
[
L
ist
[
ConversationMessage
],
Awaitable
[
Optional
[
MultiModalDataDict
]]]:
conversation
:
L
ist
[
ConversationMessage
]
=
[]
)
->
t
uple
[
l
ist
[
ConversationMessage
],
Awaitable
[
Optional
[
MultiModalDataDict
]]]:
conversation
:
l
ist
[
ConversationMessage
]
=
[]
mm_tracker
=
AsyncMultiModalItemTracker
(
model_config
,
tokenizer
)
for
msg
in
messages
:
...
...
@@ -963,7 +964,7 @@ def parse_chat_messages_futures(
def
apply_hf_chat_template
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
conversation
:
L
ist
[
ConversationMessage
],
conversation
:
l
ist
[
ConversationMessage
],
chat_template
:
Optional
[
str
],
*
,
tokenize
:
bool
=
False
,
# Different from HF's default
...
...
@@ -985,10 +986,10 @@ def apply_hf_chat_template(
def
apply_mistral_chat_template
(
tokenizer
:
MistralTokenizer
,
messages
:
L
ist
[
ChatCompletionMessageParam
],
messages
:
l
ist
[
ChatCompletionMessageParam
],
chat_template
:
Optional
[
str
]
=
None
,
**
kwargs
:
Any
,
)
->
L
ist
[
int
]:
)
->
l
ist
[
int
]:
if
chat_template
is
not
None
:
logger
.
warning_once
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
...
...
vllm/entrypoints/cli/openai.py
View file @
cf069aa8
...
...
@@ -5,7 +5,7 @@ import argparse
import
os
import
signal
import
sys
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
from
openai
import
OpenAI
from
openai.types.chat
import
ChatCompletionMessageParam
...
...
@@ -23,7 +23,7 @@ def _register_signal_handlers():
signal
.
signal
(
signal
.
SIGTSTP
,
signal_handler
)
def
_interactive_cli
(
args
:
argparse
.
Namespace
)
->
T
uple
[
str
,
OpenAI
]:
def
_interactive_cli
(
args
:
argparse
.
Namespace
)
->
t
uple
[
str
,
OpenAI
]:
_register_signal_handlers
()
base_url
=
args
.
url
...
...
@@ -43,7 +43,7 @@ def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]:
def
chat
(
system_prompt
:
Optional
[
str
],
model_name
:
str
,
client
:
OpenAI
)
->
None
:
conversation
:
L
ist
[
ChatCompletionMessageParam
]
=
[]
conversation
:
l
ist
[
ChatCompletionMessageParam
]
=
[]
if
system_prompt
is
not
None
:
conversation
.
append
({
"role"
:
"system"
,
"content"
:
system_prompt
})
...
...
@@ -100,7 +100,7 @@ class ChatCommand(CLISubcommand):
def
cmd
(
args
:
argparse
.
Namespace
)
->
None
:
model_name
,
client
=
_interactive_cli
(
args
)
system_prompt
=
args
.
system_prompt
conversation
:
L
ist
[
ChatCompletionMessageParam
]
=
[]
conversation
:
l
ist
[
ChatCompletionMessageParam
]
=
[]
if
system_prompt
is
not
None
:
conversation
.
append
({
"role"
:
"system"
,
"content"
:
system_prompt
})
...
...
@@ -168,5 +168,5 @@ class CompleteCommand(CLISubcommand):
return
complete_parser
def
cmd_init
()
->
L
ist
[
CLISubcommand
]:
def
cmd_init
()
->
l
ist
[
CLISubcommand
]:
return
[
ChatCommand
(),
CompleteCommand
()]
vllm/entrypoints/cli/serve.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
argparse
from
typing
import
List
import
uvloop
...
...
@@ -59,5 +58,5 @@ class ServeSubcommand(CLISubcommand):
return
make_arg_parser
(
serve_parser
)
def
cmd_init
()
->
L
ist
[
CLISubcommand
]:
def
cmd_init
()
->
l
ist
[
CLISubcommand
]:
return
[
ServeSubcommand
()]
vllm/entrypoints/llm.py
View file @
cf069aa8
...
...
@@ -2,9 +2,9 @@
import
itertools
import
warnings
from
collections.abc
import
Sequence
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
Callable
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
from
typing
import
Any
,
Callable
,
ClassVar
,
Optional
,
Union
,
cast
,
overload
import
cloudpickle
import
torch.nn
as
nn
...
...
@@ -177,11 +177,11 @@ class LLM:
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
# After positional args are removed, move this right below `model`
task
:
TaskOption
=
"auto"
,
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
,
compilation_config
:
Optional
[
Union
[
int
,
D
ict
[
str
,
Any
]]]
=
None
,
compilation_config
:
Optional
[
Union
[
int
,
d
ict
[
str
,
Any
]]]
=
None
,
**
kwargs
,
)
->
None
:
'''
...
...
@@ -246,7 +246,7 @@ class LLM:
self
.
request_counter
=
Counter
()
@
staticmethod
def
get_engine_class
()
->
T
ype
[
LLMEngine
]:
def
get_engine_class
()
->
t
ype
[
LLMEngine
]:
if
envs
.
VLLM_USE_V1
:
# Lazy import: the v1 package isn't distributed
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
...
...
@@ -283,11 +283,11 @@ class LLM:
Sequence
[
SamplingParams
]]]
=
None
,
*
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
)
->
l
ist
[
RequestOutput
]:
...
@
overload
# LEGACY: single (prompt + optional token ids)
...
...
@@ -296,30 +296,30 @@ class LLM:
self
,
prompts
:
str
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
L
ist
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
l
ist
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
)
->
l
ist
[
RequestOutput
]:
...
@
overload
# LEGACY: multi (prompt + optional token ids)
@
deprecated
(
"'prompt_token_ids' will become part of 'prompts'"
)
def
generate
(
self
,
prompts
:
L
ist
[
str
],
prompts
:
l
ist
[
str
],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
L
ist
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
L
ist
[
L
ist
[
int
]]]
=
None
,
l
ist
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
l
ist
[
l
ist
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
)
->
l
ist
[
RequestOutput
]:
...
@
overload
# LEGACY: single (token ids + optional prompt)
...
...
@@ -328,32 +328,32 @@ class LLM:
self
,
prompts
:
Optional
[
str
]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
L
ist
[
SamplingParams
]]]
=
None
,
l
ist
[
SamplingParams
]]]
=
None
,
*
,
prompt_token_ids
:
L
ist
[
int
],
prompt_token_ids
:
l
ist
[
int
],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
)
->
l
ist
[
RequestOutput
]:
...
@
overload
# LEGACY: multi (token ids + optional prompt)
@
deprecated
(
"'prompt_token_ids' will become part of 'prompts'"
)
def
generate
(
self
,
prompts
:
Optional
[
L
ist
[
str
]]
=
None
,
prompts
:
Optional
[
l
ist
[
str
]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
L
ist
[
SamplingParams
]]]
=
None
,
l
ist
[
SamplingParams
]]]
=
None
,
*
,
prompt_token_ids
:
L
ist
[
L
ist
[
int
]],
prompt_token_ids
:
l
ist
[
l
ist
[
int
]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
)
->
l
ist
[
RequestOutput
]:
...
@
overload
# LEGACY: single or multi token ids [pos-only]
...
...
@@ -362,13 +362,13 @@ class LLM:
self
,
prompts
:
None
,
sampling_params
:
None
,
prompt_token_ids
:
Union
[
L
ist
[
int
],
L
ist
[
L
ist
[
int
]]],
prompt_token_ids
:
Union
[
l
ist
[
int
],
l
ist
[
l
ist
[
int
]]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
)
->
l
ist
[
RequestOutput
]:
...
@
deprecate_kwargs
(
...
...
@@ -379,17 +379,17 @@ class LLM:
def
generate
(
self
,
prompts
:
Union
[
Union
[
PromptType
,
Sequence
[
PromptType
]],
Optional
[
Union
[
str
,
L
ist
[
str
]]]]
=
None
,
Optional
[
Union
[
str
,
l
ist
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
L
ist
[
int
],
L
ist
[
L
ist
[
int
]]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
l
ist
[
int
],
l
ist
[
l
ist
[
int
]]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
,
priority
:
Optional
[
L
ist
[
int
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
priority
:
Optional
[
l
ist
[
int
]]
=
None
,
)
->
l
ist
[
RequestOutput
]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
...
...
@@ -440,7 +440,7 @@ class LLM:
if
prompt_token_ids
is
not
None
:
parsed_prompts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
L
ist
[
str
]]],
prompts
),
prompts
=
cast
(
Optional
[
Union
[
str
,
l
ist
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
...
...
@@ -473,8 +473,8 @@ class LLM:
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
[...,
_R
]],
timeout
:
Optional
[
float
]
=
None
,
args
:
T
uple
=
(),
kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
)
->
L
ist
[
_R
]:
args
:
t
uple
=
(),
kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
)
->
l
ist
[
_R
]:
"""
Execute an RPC call on all workers.
...
...
@@ -510,9 +510,9 @@ class LLM:
def
beam_search
(
self
,
prompts
:
L
ist
[
Union
[
TokensPrompt
,
TextPrompt
]],
prompts
:
l
ist
[
Union
[
TokensPrompt
,
TextPrompt
]],
params
:
BeamSearchParams
,
)
->
L
ist
[
BeamSearchOutput
]:
)
->
l
ist
[
BeamSearchOutput
]:
"""
Generate sequences using beam search.
...
...
@@ -543,7 +543,7 @@ class LLM:
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
instances
:
L
ist
[
BeamSearchInstance
]
=
[]
instances
:
l
ist
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
if
is_token_prompt
(
prompt
):
...
...
@@ -553,12 +553,12 @@ class LLM:
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
for
_
in
range
(
max_tokens
):
all_beams
:
L
ist
[
BeamSearchSequence
]
=
list
(
all_beams
:
l
ist
[
BeamSearchSequence
]
=
list
(
sum
((
instance
.
beams
for
instance
in
instances
),
[]))
pos
=
[
0
]
+
list
(
itertools
.
accumulate
(
len
(
instance
.
beams
)
for
instance
in
instances
))
instance_start_and_end
:
L
ist
[
T
uple
[
int
,
int
]]
=
list
(
instance_start_and_end
:
l
ist
[
t
uple
[
int
,
int
]]
=
list
(
zip
(
pos
[:
-
1
],
pos
[
1
:]))
if
len
(
all_beams
)
==
0
:
...
...
@@ -620,19 +620,19 @@ class LLM:
def
chat
(
self
,
messages
:
Union
[
L
ist
[
ChatCompletionMessageParam
],
L
ist
[
L
ist
[
ChatCompletionMessageParam
]]],
messages
:
Union
[
l
ist
[
ChatCompletionMessageParam
],
l
ist
[
l
ist
[
ChatCompletionMessageParam
]]],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
L
ist
[
SamplingParams
]]]
=
None
,
l
ist
[
SamplingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
tools
:
Optional
[
L
ist
[
D
ict
[
str
,
Any
]]]
=
None
,
mm_processor_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
)
->
L
ist
[
RequestOutput
]:
tools
:
Optional
[
l
ist
[
d
ict
[
str
,
Any
]]]
=
None
,
mm_processor_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
)
->
l
ist
[
RequestOutput
]:
"""
Generate responses for a chat conversation.
...
...
@@ -678,17 +678,17 @@ class LLM:
A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages.
"""
list_of_messages
:
L
ist
[
L
ist
[
ChatCompletionMessageParam
]]
list_of_messages
:
l
ist
[
l
ist
[
ChatCompletionMessageParam
]]
# Handle multi and single conversations
if
is_list_of
(
messages
,
list
):
# messages is
L
ist[
L
ist[...]]
list_of_messages
=
cast
(
L
ist
[
L
ist
[
ChatCompletionMessageParam
]],
# messages is
l
ist[
l
ist[...]]
list_of_messages
=
cast
(
l
ist
[
l
ist
[
ChatCompletionMessageParam
]],
messages
)
else
:
# messages is
L
ist[...]
# messages is
l
ist[...]
list_of_messages
=
[
cast
(
L
ist
[
ChatCompletionMessageParam
],
messages
)
cast
(
l
ist
[
ChatCompletionMessageParam
],
messages
)
]
tokenizer
=
self
.
get_tokenizer
()
...
...
@@ -699,7 +699,7 @@ class LLM:
tokenizer
,
)
prompts
:
L
ist
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
prompts
:
l
ist
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
for
msgs
in
list_of_messages
:
# NOTE: _parse_chat_message_content_parts() currently doesn't
...
...
@@ -712,7 +712,7 @@ class LLM:
content_format
=
resolved_content_format
,
)
prompt_data
:
Union
[
str
,
L
ist
[
int
]]
prompt_data
:
Union
[
str
,
l
ist
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt_data
=
apply_mistral_chat_template
(
tokenizer
,
...
...
@@ -762,9 +762,9 @@ class LLM:
Sequence
[
PoolingParams
]]]
=
None
,
*
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
...
@
overload
# LEGACY: single (prompt + optional token ids)
...
...
@@ -774,25 +774,25 @@ class LLM:
prompts
:
str
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
...
@
overload
# LEGACY: multi (prompt + optional token ids)
@
deprecated
(
"'prompt_token_ids' will become part of 'prompts'"
)
def
encode
(
self
,
prompts
:
L
ist
[
str
],
prompts
:
l
ist
[
str
],
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
L
ist
[
L
ist
[
int
]]]
=
None
,
prompt_token_ids
:
Optional
[
l
ist
[
l
ist
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
...
@
overload
# LEGACY: single (token ids + optional prompt)
...
...
@@ -803,26 +803,26 @@ class LLM:
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
*
,
prompt_token_ids
:
L
ist
[
int
],
prompt_token_ids
:
l
ist
[
int
],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
...
@
overload
# LEGACY: multi (token ids + optional prompt)
@
deprecated
(
"'prompt_token_ids' will become part of 'prompts'"
)
def
encode
(
self
,
prompts
:
Optional
[
L
ist
[
str
]]
=
None
,
prompts
:
Optional
[
l
ist
[
str
]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
*
,
prompt_token_ids
:
L
ist
[
L
ist
[
int
]],
prompt_token_ids
:
l
ist
[
l
ist
[
int
]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
...
@
overload
# LEGACY: single or multi token ids [pos-only]
...
...
@@ -831,11 +831,11 @@ class LLM:
self
,
prompts
:
None
,
pooling_params
:
None
,
prompt_token_ids
:
Union
[
L
ist
[
int
],
L
ist
[
L
ist
[
int
]]],
prompt_token_ids
:
Union
[
l
ist
[
int
],
l
ist
[
l
ist
[
int
]]],
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
...
@
deprecate_kwargs
(
...
...
@@ -846,14 +846,14 @@ class LLM:
def
encode
(
self
,
prompts
:
Union
[
Union
[
PromptType
,
Sequence
[
PromptType
]],
Optional
[
Union
[
str
,
L
ist
[
str
]]]]
=
None
,
Optional
[
Union
[
str
,
l
ist
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
L
ist
[
int
],
L
ist
[
L
ist
[
int
]]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
l
ist
[
int
],
l
ist
[
l
ist
[
int
]]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
"""Apply pooling to the hidden states corresponding to the input
prompts.
...
...
@@ -898,7 +898,7 @@ class LLM:
if
prompt_token_ids
is
not
None
:
parsed_prompts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
L
ist
[
str
]]],
prompts
),
prompts
=
cast
(
Optional
[
Union
[
str
,
l
ist
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
...
...
@@ -926,9 +926,9 @@ class LLM:
/
,
*
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
EmbeddingRequestOutput
]:
)
->
l
ist
[
EmbeddingRequestOutput
]:
"""
Generate an embedding vector for each prompt.
...
...
@@ -966,9 +966,9 @@ class LLM:
/
,
*
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
ClassificationRequestOutput
]:
)
->
l
ist
[
ClassificationRequestOutput
]:
"""
Generate class logits for each prompt.
...
...
@@ -1003,29 +1003,29 @@ class LLM:
def
_embedding_score
(
self
,
tokenizer
:
AnyTokenizer
,
text_1
:
L
ist
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
L
ist
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_1
:
l
ist
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
text_2
:
l
ist
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
ScoringRequestOutput
]:
)
->
l
ist
[
ScoringRequestOutput
]:
encoded_output
:
L
ist
[
PoolingRequestOutput
]
=
self
.
encode
(
encoded_output
:
l
ist
[
PoolingRequestOutput
]
=
self
.
encode
(
text_1
+
text_2
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
encoded_output_1
:
L
ist
[
PoolingRequestOutput
]
=
encoded_output
[
encoded_output_1
:
l
ist
[
PoolingRequestOutput
]
=
encoded_output
[
0
:
len
(
text_1
)]
encoded_output_2
:
L
ist
[
PoolingRequestOutput
]
=
encoded_output
[
encoded_output_2
:
l
ist
[
PoolingRequestOutput
]
=
encoded_output
[
len
(
text_1
):]
if
len
(
encoded_output_1
)
==
1
:
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
scores
:
L
ist
[
PoolingRequestOutput
]
=
[]
scores
:
l
ist
[
PoolingRequestOutput
]
=
[]
scores
=
_cosine_similarity
(
tokenizer
=
tokenizer
,
embed_1
=
encoded_output_1
,
...
...
@@ -1038,13 +1038,13 @@ class LLM:
def
_cross_encoding_score
(
self
,
tokenizer
:
AnyTokenizer
,
text_1
:
L
ist
[
str
],
text_2
:
L
ist
[
str
],
text_1
:
l
ist
[
str
],
text_2
:
l
ist
[
str
],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
ScoringRequestOutput
]:
)
->
l
ist
[
ScoringRequestOutput
]:
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
...
...
@@ -1057,7 +1057,7 @@ class LLM:
pooling_params
=
PoolingParams
()
tokenization_kwargs
:
D
ict
[
str
,
Any
]
=
{}
tokenization_kwargs
:
d
ict
[
str
,
Any
]
=
{}
if
truncate_prompt_tokens
is
not
None
:
tokenization_kwargs
[
"truncation"
]
=
True
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
...
...
@@ -1094,9 +1094,9 @@ class LLM:
*
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
L
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
l
ist
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
L
ist
[
ScoringRequestOutput
]:
)
->
l
ist
[
ScoringRequestOutput
]:
"""Generate similarity scores for all pairs ``<text,text_pair>``.
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
...
...
@@ -1162,12 +1162,12 @@ class LLM:
if
isinstance
(
text_1
,
(
str
,
dict
)):
# Convert a single prompt to a list.
text_1
=
[
text_1
]
input_text_1
:
L
ist
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_1
]
input_text_1
:
l
ist
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_1
]
if
isinstance
(
text_2
,
(
str
,
dict
)):
# Convert a single prompt to a list.
text_2
=
[
text_2
]
input_text_2
:
L
ist
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_2
]
input_text_2
:
l
ist
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_2
]
_validate_score_input_lens
(
input_text_1
,
input_text_2
)
...
...
@@ -1226,8 +1226,8 @@ class LLM:
# LEGACY
def
_convert_v1_inputs
(
self
,
prompts
:
Optional
[
Union
[
str
,
L
ist
[
str
]]],
prompt_token_ids
:
Optional
[
Union
[
L
ist
[
int
],
L
ist
[
L
ist
[
int
]]]],
prompts
:
Optional
[
Union
[
str
,
l
ist
[
str
]]],
prompt_token_ids
:
Optional
[
Union
[
l
ist
[
int
],
l
ist
[
l
ist
[
int
]]]],
):
# skip_tokenizer_init is now checked in engine
...
...
@@ -1252,7 +1252,7 @@ class LLM:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
parsed_prompts
:
L
ist
[
PromptType
]
=
[]
parsed_prompts
:
l
ist
[
PromptType
]
=
[]
for
i
in
range
(
num_requests
):
item
:
PromptType
...
...
@@ -1275,7 +1275,7 @@ class LLM:
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
L
ist
[
int
]]
=
None
,
priority
:
Optional
[
l
ist
[
int
]]
=
None
,
)
->
None
:
if
guided_options
is
not
None
:
warnings
.
warn
(
...
...
@@ -1357,7 +1357,7 @@ class LLM:
def
_run_engine
(
self
,
*
,
use_tqdm
:
bool
)
->
L
ist
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
)
->
l
ist
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
# Initialize tqdm.
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
...
...
@@ -1370,7 +1370,7 @@ class LLM:
)
# Run the engine.
outputs
:
L
ist
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]
=
[]
outputs
:
l
ist
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]
=
[]
total_in_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment