Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
776dbd74
Unverified
Commit
776dbd74
authored
Oct 16, 2024
by
Russell Bryant
Committed by
GitHub
Oct 16, 2024
Browse files
[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
83450458
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
109 additions
and
74 deletions
+109
-74
tools/mypy.sh
tools/mypy.sh
+1
-11
vllm/attention/layer.py
vllm/attention/layer.py
+1
-1
vllm/compilation/backends.py
vllm/compilation/backends.py
+2
-2
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+5
-3
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+1
-1
vllm/config.py
vllm/config.py
+6
-4
vllm/core/scheduler.py
vllm/core/scheduler.py
+4
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+7
-5
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+12
-8
vllm/engine/metrics.py
vllm/engine/metrics.py
+9
-5
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+12
-5
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+2
-4
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+19
-6
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+5
-3
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+2
-2
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+8
-5
vllm/inputs/parse.py
vllm/inputs/parse.py
+4
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+5
-2
vllm/outputs.py
vllm/outputs.py
+2
-1
vllm/sequence.py
vllm/sequence.py
+2
-2
No files found.
tools/mypy.sh
View file @
776dbd74
...
@@ -13,24 +13,14 @@ run_mypy() {
...
@@ -13,24 +13,14 @@ run_mypy() {
run_mypy
# Note that this is less strict than CI
run_mypy
# Note that this is less strict than CI
run_mypy tests
run_mypy tests
run_mypy vllm/assets
run_mypy vllm/attention
run_mypy vllm/attention
#run_mypy vllm/compilation
run_mypy vllm/compilation
#run_mypy vllm/core
run_mypy vllm/distributed
run_mypy vllm/distributed
run_mypy vllm/engine
run_mypy vllm/engine
run_mypy vllm/entrypoints
run_mypy vllm/executor
run_mypy vllm/executor
#run_mypy vllm/inputs
run_mypy vllm/logging
run_mypy vllm/lora
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/model_executor
run_mypy vllm/multimodal
run_mypy vllm/platforms
run_mypy vllm/plugins
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/spec_decode
run_mypy vllm/transformers_utils
run_mypy vllm/usage
#run_mypy vllm/vllm_flash_attn
run_mypy vllm/worker
run_mypy vllm/worker
vllm/attention/layer.py
View file @
776dbd74
...
@@ -92,7 +92,7 @@ class Attention(nn.Module):
...
@@ -92,7 +92,7 @@ class Attention(nn.Module):
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/compilation/backends.py
View file @
776dbd74
...
@@ -244,8 +244,8 @@ def vllm_backend(
...
@@ -244,8 +244,8 @@ def vllm_backend(
def
select_default_backend
(
level
:
int
)
->
Union
[
str
,
Callable
]:
def
select_default_backend
(
level
:
int
)
->
Union
[
str
,
Callable
]:
if
level
in
[
CompilationLevel
.
DYNAMO_AS_IS
,
CompilationLevel
.
DYNAMO_ONCE
]:
if
level
in
[
CompilationLevel
.
DYNAMO_AS_IS
,
CompilationLevel
.
DYNAMO_ONCE
]:
backend
=
"eager"
backend
_str
=
"eager"
return
backend
return
backend
_str
assert
level
in
[
assert
level
in
[
CompilationLevel
.
INDUCTOR
,
CompilationLevel
.
INDUCTOR_MAX_AUTOTUNE
CompilationLevel
.
INDUCTOR
,
CompilationLevel
.
INDUCTOR_MAX_AUTOTUNE
],
f
"Invalid level
{
level
}
"
],
f
"Invalid level
{
level
}
"
...
...
vllm/compilation/decorators.py
View file @
776dbd74
...
@@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
...
@@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def
cls_decorator_helper
(
cls
:
type
):
def
cls_decorator_helper
(
cls
:
type
):
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
if
not
hasattr
(
cls
,
'forward'
):
raise
TypeError
(
"decorated class should have a forward method."
)
sig
=
inspect
.
signature
(
cls
.
forward
)
sig
=
inspect
.
signature
(
cls
.
forward
)
for
k
in
dynamic_arg_dims
:
for
k
in
dynamic_arg_dims
:
if
k
not
in
sig
.
parameters
:
if
k
not
in
sig
.
parameters
:
...
@@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
...
@@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
# other than TorchCompileWrapperWithCustomDispatcher
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
old_init
=
cls
.
__init__
old_init
=
cls
.
__init__
# type: ignore
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
old_init
(
self
,
*
args
,
**
kwargs
)
old_init
(
self
,
*
args
,
**
kwargs
)
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
)
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
)
cls
.
__init__
=
__init__
cls
.
__init__
=
__init__
# type: ignore
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
...
@@ -109,5 +111,5 @@ def _support_torch_compile(cls: type,
...
@@ -109,5 +111,5 @@ def _support_torch_compile(cls: type,
model_output
=
self
.
forward
(
*
args
,
**
kwargs
)
model_output
=
self
.
forward
(
*
args
,
**
kwargs
)
return
model_output
return
model_output
cls
.
__call__
=
__call__
cls
.
__call__
=
__call__
# type: ignore
return
cls
return
cls
vllm/compilation/wrapper.py
View file @
776dbd74
...
@@ -73,7 +73,7 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -73,7 +73,7 @@ class TorchCompileWrapperWithCustomDispatcher:
return
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame
=
sys
.
_getframe
()
frame
=
sys
.
_getframe
()
while
True
:
while
frame
and
frame
.
f_back
:
frame
=
frame
.
f_back
frame
=
frame
.
f_back
code_name
=
frame
.
f_code
.
co_name
code_name
=
frame
.
f_code
.
co_name
file_name
=
frame
.
f_code
.
co_filename
.
split
(
os
.
path
.
sep
)[
-
1
]
file_name
=
frame
.
f_code
.
co_filename
.
split
(
os
.
path
.
sep
)[
-
1
]
...
...
vllm/config.py
View file @
776dbd74
...
@@ -626,13 +626,14 @@ class CacheConfig:
...
@@ -626,13 +626,14 @@ class CacheConfig:
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
self
.
enable_prefix_caching
=
enable_prefix_caching
self
.
enable_prefix_caching
=
enable_prefix_caching
self
.
cpu_offload_gb
=
cpu_offload_gb
self
.
cpu_offload_gb
=
cpu_offload_gb
self
.
_verify_args
()
self
.
_verify_args
()
self
.
_verify_cache_dtype
()
self
.
_verify_cache_dtype
()
self
.
_verify_prefix_caching
()
self
.
_verify_prefix_caching
()
# Will be set after profiling.
# Will be set after profiling.
self
.
num_gpu_blocks
=
None
self
.
num_gpu_blocks
:
Optional
[
int
]
=
None
self
.
num_cpu_blocks
=
None
self
.
num_cpu_blocks
:
Optional
[
int
]
=
None
def
metrics_info
(
self
):
def
metrics_info
(
self
):
# convert cache_config to dict(key: str, value: str) for prometheus
# convert cache_config to dict(key: str, value: str) for prometheus
...
@@ -709,7 +710,8 @@ class TokenizerPoolConfig:
...
@@ -709,7 +710,8 @@ class TokenizerPoolConfig:
@
classmethod
@
classmethod
def
create_config
(
def
create_config
(
cls
,
tokenizer_pool_size
:
int
,
tokenizer_pool_type
:
str
,
cls
,
tokenizer_pool_size
:
int
,
tokenizer_pool_type
:
Union
[
str
,
Type
[
"BaseTokenizerGroup"
]],
tokenizer_pool_extra_config
:
Optional
[
Union
[
str
,
dict
]]
tokenizer_pool_extra_config
:
Optional
[
Union
[
str
,
dict
]]
)
->
Optional
[
"TokenizerPoolConfig"
]:
)
->
Optional
[
"TokenizerPoolConfig"
]:
"""Create a TokenizerPoolConfig from the given parameters.
"""Create a TokenizerPoolConfig from the given parameters.
...
@@ -1544,7 +1546,7 @@ class LoRAConfig:
...
@@ -1544,7 +1546,7 @@ class LoRAConfig:
max_loras
:
int
max_loras
:
int
fully_sharded_loras
:
bool
=
False
fully_sharded_loras
:
bool
=
False
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
lora_dtype
:
Optional
[
torch
.
dtype
]
=
None
lora_dtype
:
Optional
[
Union
[
torch
.
dtype
,
str
]
]
=
None
lora_extra_vocab_size
:
int
=
256
lora_extra_vocab_size
:
int
=
256
# This is a constant.
# This is a constant.
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
...
...
vllm/core/scheduler.py
View file @
776dbd74
...
@@ -4,8 +4,9 @@ import random
...
@@ -4,8 +4,9 @@ import random
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
(
Callable
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
from
typing
import
Callable
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
Tuple
,
Union
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
...
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
...
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
class
SchedulerOutputs
:
class
SchedulerOutputs
:
"""The scheduling decision made from a scheduler."""
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
# Scheduled sequence groups.
scheduled_seq_groups
:
Iterabl
e
[
ScheduledSequenceGroup
]
scheduled_seq_groups
:
GenericSequenc
e
[
ScheduledSequenceGroup
]
# Number of prefill groups scheduled.
# Number of prefill groups scheduled.
num_prefill_groups
:
int
num_prefill_groups
:
int
# Total number of batched tokens.
# Total number of batched tokens.
...
...
vllm/engine/arg_utils.py
View file @
776dbd74
...
@@ -3,7 +3,7 @@ import dataclasses
...
@@ -3,7 +3,7 @@ import dataclasses
import
json
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
Tuple
,
Type
,
Union
,
cast
)
import
torch
import
torch
...
@@ -89,7 +89,7 @@ class EngineArgs:
...
@@ -89,7 +89,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
False
trust_remote_code
:
bool
=
False
download_dir
:
Optional
[
str
]
=
None
download_dir
:
Optional
[
str
]
=
None
load_format
:
str
=
'auto'
load_format
:
str
=
'auto'
config_format
:
str
=
'auto'
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
dtype
:
str
=
'auto'
dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
quantization_param_path
:
Optional
[
str
]
=
None
quantization_param_path
:
Optional
[
str
]
=
None
...
@@ -181,7 +181,7 @@ class EngineArgs:
...
@@ -181,7 +181,7 @@ class EngineArgs:
scheduling_policy
:
Literal
[
"fcfs"
,
"priority"
]
=
"fcfs"
scheduling_policy
:
Literal
[
"fcfs"
,
"priority"
]
=
"fcfs"
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
if
not
self
.
tokenizer
:
self
.
tokenizer
=
self
.
model
self
.
tokenizer
=
self
.
model
# Setup plugins
# Setup plugins
...
@@ -837,7 +837,8 @@ class EngineArgs:
...
@@ -837,7 +837,8 @@ class EngineArgs:
def
create_model_config
(
self
)
->
ModelConfig
:
def
create_model_config
(
self
)
->
ModelConfig
:
return
ModelConfig
(
return
ModelConfig
(
model
=
self
.
model
,
model
=
self
.
model
,
tokenizer
=
self
.
tokenizer
,
# We know this is not None because we set it in __post_init__
tokenizer
=
cast
(
str
,
self
.
tokenizer
),
tokenizer_mode
=
self
.
tokenizer_mode
,
tokenizer_mode
=
self
.
tokenizer_mode
,
trust_remote_code
=
self
.
trust_remote_code
,
trust_remote_code
=
self
.
trust_remote_code
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
@@ -908,8 +909,9 @@ class EngineArgs:
...
@@ -908,8 +909,9 @@ class EngineArgs:
self
.
enable_prefix_caching
=
False
self
.
enable_prefix_caching
=
False
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
# neuron needs block_size = max_model_len
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
(
self
.
max_model_len
if
self
.
max_model_len
is
not
None
else
0
),
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
swap_space
=
self
.
swap_space
,
swap_space
=
self
.
swap_space
,
cache_dtype
=
self
.
kv_cache_dtype
,
cache_dtype
=
self
.
kv_cache_dtype
,
...
...
vllm/engine/llm_engine.py
View file @
776dbd74
...
@@ -6,7 +6,7 @@ from functools import partial
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
,
overload
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
import
torch
import
torch
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
...
@@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupOutput
,
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.config
import
try_get_generation_config
...
@@ -188,7 +188,7 @@ class LLMEngine:
...
@@ -188,7 +188,7 @@ class LLMEngine:
raise
TypeError
(
f
"Expected output of type
{
output_type
}
, "
raise
TypeError
(
f
"Expected output of type
{
output_type
}
, "
f
"but found type
{
type
(
output
)
}
"
)
f
"but found type
{
type
(
output
)
}
"
)
return
output
return
cast
(
_O
,
output
)
@
classmethod
@
classmethod
def
validate_outputs
(
def
validate_outputs
(
...
@@ -1039,6 +1039,7 @@ class LLMEngine:
...
@@ -1039,6 +1039,7 @@ class LLMEngine:
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
has_multiple_outputs
:
bool
=
len
(
outputs
)
>
1
has_multiple_outputs
:
bool
=
len
(
outputs
)
>
1
outputs_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
if
has_multiple_outputs
:
if
has_multiple_outputs
:
assert
self
.
scheduler_config
.
is_multi_step
or
\
assert
self
.
scheduler_config
.
is_multi_step
or
\
self
.
speculative_config
self
.
speculative_config
...
@@ -1084,6 +1085,7 @@ class LLMEngine:
...
@@ -1084,6 +1085,7 @@ class LLMEngine:
finished_before
.
append
(
i
)
finished_before
.
append
(
i
)
continue
continue
output
:
List
[
SequenceGroupOutput
]
if
has_multiple_outputs
:
if
has_multiple_outputs
:
output
=
outputs_by_sequence_group
[
i
]
output
=
outputs_by_sequence_group
[
i
]
else
:
else
:
...
@@ -1096,7 +1098,7 @@ class LLMEngine:
...
@@ -1096,7 +1098,7 @@ class LLMEngine:
seq_group
,
seq_group_meta
,
is_first_step_output
)
seq_group
,
seq_group_meta
,
is_first_step_output
)
else
:
else
:
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
seq_group_meta
.
token_chunk_size
)
seq_group_meta
.
token_chunk_size
or
0
)
if
outputs
:
if
outputs
:
for
o
in
outputs
:
for
o
in
outputs
:
...
@@ -1104,13 +1106,13 @@ class LLMEngine:
...
@@ -1104,13 +1106,13 @@ class LLMEngine:
and
seq_group
.
metrics
is
not
None
):
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
seq_group
.
metrics
.
model_forward_time
+=
(
seq_group
.
metrics
.
model_forward_time
+=
(
o
.
model_forward_time
)
o
.
model_forward_time
or
0
)
else
:
else
:
seq_group
.
metrics
.
model_forward_time
=
(
seq_group
.
metrics
.
model_forward_time
=
(
o
.
model_forward_time
)
o
.
model_forward_time
)
if
seq_group
.
metrics
.
model_execute_time
is
not
None
:
if
seq_group
.
metrics
.
model_execute_time
is
not
None
:
seq_group
.
metrics
.
model_execute_time
+=
(
seq_group
.
metrics
.
model_execute_time
+=
(
o
.
model_execute_time
)
o
.
model_execute_time
or
0
)
else
:
else
:
seq_group
.
metrics
.
model_execute_time
=
(
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
o
.
model_execute_time
)
...
@@ -1236,8 +1238,10 @@ class LLMEngine:
...
@@ -1236,8 +1238,10 @@ class LLMEngine:
seq_group
,
seq_group_metadata
,
seq_group
,
seq_group_metadata
,
seq_group
.
state
.
num_steps
==
1
)
seq_group
.
state
.
num_steps
==
1
)
else
:
else
:
seq_group
.
update_num_computed_tokens
(
token_chunk_size
=
(
seq_group_metadata
.
token_chunk_size
seq_group_metadata
.
token_chunk_size
)
if
seq_group_metadata
.
token_chunk_size
is
not
None
else
0
)
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
if
seq_group_metadata
.
do_sample
:
if
seq_group_metadata
.
do_sample
:
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
...
...
vllm/engine/metrics.py
View file @
776dbd74
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Type
,
Union
,
cast
import
numpy
as
np
import
numpy
as
np
import
prometheus_client
import
prometheus_client
...
@@ -249,10 +249,11 @@ class _RayHistogramWrapper:
...
@@ -249,10 +249,11 @@ class _RayHistogramWrapper:
labelnames
:
Optional
[
List
[
str
]]
=
None
,
labelnames
:
Optional
[
List
[
str
]]
=
None
,
buckets
:
Optional
[
List
[
float
]]
=
None
):
buckets
:
Optional
[
List
[
float
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
boundaries
=
buckets
if
buckets
else
[]
self
.
_histogram
=
ray_metrics
.
Histogram
(
name
=
name
,
self
.
_histogram
=
ray_metrics
.
Histogram
(
name
=
name
,
description
=
documentation
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
,
tag_keys
=
labelnames_tuple
,
boundaries
=
b
ucket
s
)
boundaries
=
b
oundarie
s
)
def
labels
(
self
,
**
labels
):
def
labels
(
self
,
**
labels
):
self
.
_histogram
.
set_default_tags
(
labels
)
self
.
_histogram
.
set_default_tags
(
labels
)
...
@@ -267,9 +268,12 @@ class RayMetrics(Metrics):
...
@@ -267,9 +268,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
"""
_gauge_cls
=
_RayGaugeWrapper
_gauge_cls
:
Type
[
prometheus_client
.
Gauge
]
=
cast
(
_counter_cls
=
_RayCounterWrapper
Type
[
prometheus_client
.
Gauge
],
_RayGaugeWrapper
)
_histogram_cls
=
_RayHistogramWrapper
_counter_cls
:
Type
[
prometheus_client
.
Counter
]
=
cast
(
Type
[
prometheus_client
.
Counter
],
_RayCounterWrapper
)
_histogram_cls
:
Type
[
prometheus_client
.
Histogram
]
=
cast
(
Type
[
prometheus_client
.
Histogram
],
_RayHistogramWrapper
)
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
if
ray_metrics
is
None
:
if
ray_metrics
is
None
:
...
...
vllm/engine/multiprocessing/client.py
View file @
776dbd74
...
@@ -3,7 +3,7 @@ import copy
...
@@ -3,7 +3,7 @@ import copy
import
pickle
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
List
,
Mapping
,
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
List
,
Mapping
,
Optional
,
Union
,
overload
)
Optional
,
Union
,
cast
,
overload
)
import
cloudpickle
import
cloudpickle
import
zmq
import
zmq
...
@@ -513,9 +513,14 @@ class MQLLMEngineClient(EngineClient):
...
@@ -513,9 +513,14 @@ class MQLLMEngineClient(EngineClient):
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
and
request_id
is
not
None
)
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
return
cast
(
lora_request
,
trace_headers
,
None
,
AsyncGenerator
[
EmbeddingRequestOutput
,
None
],
priority
)
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
,
priority
=
priority
))
async
def
_process_request
(
async
def
_process_request
(
self
,
self
,
...
@@ -543,7 +548,9 @@ class MQLLMEngineClient(EngineClient):
...
@@ -543,7 +548,9 @@ class MQLLMEngineClient(EngineClient):
build_guided_decoding_logits_processor_async
(
build_guided_decoding_logits_processor_async
(
sampling_params
=
params
,
sampling_params
=
params
,
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
),
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
),
default_guided_backend
=
self
.
decoding_config
.
guided_decoding_backend
default_guided_backend
=
(
self
.
decoding_config
.
guided_decoding_backend
if
self
.
decoding_config
else
DecodingConfig
.
guided_decoding_backend
),
)
)
# 1) Create output queue for this requests.
# 1) Create output queue for this requests.
...
...
vllm/engine/multiprocessing/engine.py
View file @
776dbd74
...
@@ -73,11 +73,9 @@ class MQLLMEngine:
...
@@ -73,11 +73,9 @@ class MQLLMEngine:
# For MQLLMEngine, we can use cached outputs, since each new request
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
# the python object to be reused again.
use_cached_outputs
=
True
kwargs
[
'
use_cached_outputs
'
]
=
True
self
.
engine
=
LLMEngine
(
*
args
,
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
**
kwargs
,
use_cached_outputs
=
use_cached_outputs
)
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
self
.
use_async_sockets
=
use_async_sockets
self
.
use_async_sockets
=
use_async_sockets
...
...
vllm/engine/output_processor/multi_step.py
View file @
776dbd74
import
functools
import
functools
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
cast
from
vllm.core.scheduler
import
Scheduler
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
from
vllm.engine.output_processor.interfaces
import
(
...
@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
...
@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
CompletionSequenceGroupOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -57,6 +59,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -57,6 +59,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""
"""
for
output
in
outputs
:
for
output
in
outputs
:
# Concatenate single-step prompt logprob processing results.
# Concatenate single-step prompt logprob processing results.
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
@
staticmethod
@
staticmethod
...
@@ -100,8 +103,18 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -100,8 +103,18 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Beam search not supported in multi-step decoding."
)
"Beam search not supported in multi-step decoding."
)
seq
=
seqs
[
0
]
seq
=
seqs
[
0
]
seq_id
=
seq
.
seq_id
seq_id
=
seq
.
seq_id
assert
all
(
# This method is defined in the more generic
[
seq_id
==
output
.
samples
[
0
].
parent_seq_id
for
output
in
outputs
])
# SequenceGroupOutputProcessor, but here we assume that the outputs are
# of a more specific type.
assert
all
([
isinstance
(
output
,
CompletionSequenceGroupOutput
)
for
output
in
outputs
])
compl_outputs
=
cast
(
List
[
CompletionSequenceGroupOutput
],
outputs
)
assert
all
([
seq_id
==
output
.
samples
[
0
].
parent_seq_id
for
output
in
compl_outputs
])
if
is_async
:
if
is_async
:
# Async case: We process tokens one by one. Here, we know the token
# Async case: We process tokens one by one. Here, we know the token
...
@@ -113,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -113,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group,
# Since there's only one sequence per sequence group,
# we can take the first sample.
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
samples
=
[
output
.
samples
[
0
]
for
output
in
compl_
outputs
]
# entries in sample tokens may be invalid (eg. due to spec decode
# entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens).
# rejecting tokens).
...
...
vllm/engine/output_processor/single_step.py
View file @
776dbd74
...
@@ -6,8 +6,9 @@ from vllm.engine.output_processor.interfaces import (
...
@@ -6,8 +6,9 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor
)
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Sequence
,
SequenceOutput
,
SequenceStatus
)
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -16,7 +17,7 @@ logger = init_logger(__name__)
...
@@ -16,7 +17,7 @@ logger = init_logger(__name__)
def
single_step_process_prompt_logprob
(
def
single_step_process_prompt_logprob
(
sg_output_proc
:
SequenceGroupOutputProcessor
,
seq_group
:
SequenceGroup
,
sg_output_proc
:
SequenceGroupOutputProcessor
,
seq_group
:
SequenceGroup
,
output
:
SequenceGroupOutput
)
->
None
:
output
:
Completion
SequenceGroupOutput
)
->
None
:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
for a given step.
...
@@ -106,6 +107,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -106,6 +107,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
output
=
outputs
[
0
]
output
=
outputs
[
0
]
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
...
...
vllm/engine/output_processor/stop_checker.py
View file @
776dbd74
...
@@ -57,7 +57,7 @@ class StopChecker:
...
@@ -57,7 +57,7 @@ class StopChecker:
# Check if a stop token was encountered.
# Check if a stop token was encountered.
# This assumes a single token produced per step.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
sampling_params
.
stop_token_ids
:
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
())
:
if
new_char_count
and
(
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
# Remove last token
...
@@ -92,7 +92,7 @@ class StopChecker:
...
@@ -92,7 +92,7 @@ class StopChecker:
Returns the stop string if matched or else None.
Returns the stop string if matched or else None.
"""
"""
if
not
new_char_count
:
if
not
new_char_count
or
not
sampling_params
.
stop
:
return
None
return
None
for
stop_str
in
sampling_params
.
stop
:
for
stop_str
in
sampling_params
.
stop
:
...
...
vllm/engine/output_processor/util.py
View file @
776dbd74
from
typing
import
List
from
typing
import
List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
typing
import
cast
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
Pooler
Output
,
SequenceGroupOutput
from
vllm.sequence
import
CompletionSequenceGroup
Output
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
def
create_output_by_sequence_group
(
outputs
:
GenericSequence
[
Union
[
SamplerOutput
,
PoolerOutput
]
],
outputs
:
GenericSequence
[
SamplerOutput
],
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
"""Helper method which transforms a 2d list organized by
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
[step][sequence group] into [sequence group][step].
"""
"""
output_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
=
[
output_by_sequence_group
:
List
[
List
[
Completion
SequenceGroupOutput
]]
=
[
[]
for
_
in
range
(
num_seq_groups
)
[]
for
_
in
range
(
num_seq_groups
)
]
]
for
step
in
outputs
:
for
step
in
outputs
:
sequence_group_output
:
CompletionSequenceGroupOutput
for
i
,
sequence_group_output
in
enumerate
(
step
):
for
i
,
sequence_group_output
in
enumerate
(
step
):
output_by_sequence_group
[
i
].
append
(
sequence_group_output
)
output_by_sequence_group
[
i
].
append
(
sequence_group_output
)
return
output_by_sequence_group
# Cast to the more generic type that CompletionSequenceGroupOutput
# inherits from.
return
cast
(
List
[
List
[
SequenceGroupOutput
]],
output_by_sequence_group
)
vllm/inputs/parse.py
View file @
776dbd74
from
typing
import
List
,
Literal
,
Sequence
,
TypedDict
,
Union
,
overload
from
typing
import
List
,
Literal
,
Sequence
,
TypedDict
,
Union
,
cast
,
overload
from
typing_extensions
import
TypeIs
from
typing_extensions
import
TypeIs
...
@@ -44,13 +44,16 @@ def parse_and_batch_prompt(
...
@@ -44,13 +44,16 @@ def parse_and_batch_prompt(
if
is_list_of
(
prompt
,
str
):
if
is_list_of
(
prompt
,
str
):
# case 2: array of strings
# case 2: array of strings
prompt
=
cast
(
List
[
str
],
prompt
)
return
[
return
[
ParsedText
(
content
=
elem
,
is_tokens
=
False
)
for
elem
in
prompt
ParsedText
(
content
=
elem
,
is_tokens
=
False
)
for
elem
in
prompt
]
]
if
is_list_of
(
prompt
,
int
):
if
is_list_of
(
prompt
,
int
):
# case 3: array of tokens
# case 3: array of tokens
prompt
=
cast
(
List
[
int
],
prompt
)
return
[
ParsedTokens
(
content
=
prompt
,
is_tokens
=
True
)]
return
[
ParsedTokens
(
content
=
prompt
,
is_tokens
=
True
)]
if
is_list_of
(
prompt
,
list
):
if
is_list_of
(
prompt
,
list
):
prompt
=
cast
(
List
[
List
[
int
]],
prompt
)
if
len
(
prompt
[
0
])
==
0
:
if
len
(
prompt
[
0
])
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
raise
ValueError
(
"please provide at least one prompt"
)
...
...
vllm/model_executor/layers/sampler.py
View file @
776dbd74
...
@@ -4,7 +4,7 @@ import warnings
...
@@ -4,7 +4,7 @@ import warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
from
math
import
inf
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
import
msgspec
import
msgspec
import
torch
import
torch
...
@@ -117,12 +117,15 @@ class SamplerOutput(
...
@@ -117,12 +117,15 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time.
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time
:
Optional
[
float
]
=
None
model_execute_time
:
Optional
[
float
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
def
__getitem__
(
self
,
idx
:
int
)
->
CompletionSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
def
__setitem__
(
self
,
idx
:
int
,
value
):
self
.
outputs
[
idx
]
=
value
self
.
outputs
[
idx
]
=
value
def
__iter__
(
self
)
->
Iterator
[
CompletionSequenceGroupOutput
]:
return
iter
(
self
.
outputs
)
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
outputs
)
return
len
(
self
.
outputs
)
...
...
vllm/outputs.py
View file @
776dbd74
...
@@ -4,6 +4,7 @@ from typing import List, Optional
...
@@ -4,6 +4,7 @@ from typing import List, Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
typing
import
Union
from
vllm.inputs
import
PromptType
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
...
@@ -92,7 +93,7 @@ class RequestOutput:
...
@@ -92,7 +93,7 @@ class RequestOutput:
def
__init__
(
def
__init__
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
PromptType
],
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
outputs
:
List
[
CompletionOutput
],
...
...
vllm/sequence.py
View file @
776dbd74
...
@@ -788,7 +788,7 @@ class SequenceGroup:
...
@@ -788,7 +788,7 @@ class SequenceGroup:
assert
num_lookahead_slots
+
1
==
num_scheduler_steps
or
is_prefill
assert
num_lookahead_slots
+
1
==
num_scheduler_steps
or
is_prefill
self
.
init_multi_step
(
num_steps
=
num_lookahead_slots
+
1
)
self
.
init_multi_step
(
num_steps
=
num_lookahead_slots
+
1
)
def
get_last_latency
(
self
,
now
:
float
)
->
Optional
[
float
]
:
def
get_last_latency
(
self
,
now
:
float
)
->
float
:
"""Sets the last token time for Request level timings."""
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
# If still in prefill phase, raise Error.
if
self
.
is_prefill
():
if
self
.
is_prefill
():
...
@@ -1198,7 +1198,7 @@ class PoolerOutput(
...
@@ -1198,7 +1198,7 @@ class PoolerOutput(
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
def
__getitem__
(
self
,
idx
:
int
)
->
EmbeddingSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
def
__setitem__
(
self
,
idx
:
int
,
value
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment