Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
776dbd74
Unverified
Commit
776dbd74
authored
Oct 16, 2024
by
Russell Bryant
Committed by
GitHub
Oct 16, 2024
Browse files
[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
83450458
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
109 additions
and
74 deletions
+109
-74
tools/mypy.sh
tools/mypy.sh
+1
-11
vllm/attention/layer.py
vllm/attention/layer.py
+1
-1
vllm/compilation/backends.py
vllm/compilation/backends.py
+2
-2
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+5
-3
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+1
-1
vllm/config.py
vllm/config.py
+6
-4
vllm/core/scheduler.py
vllm/core/scheduler.py
+4
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+7
-5
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+12
-8
vllm/engine/metrics.py
vllm/engine/metrics.py
+9
-5
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+12
-5
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+2
-4
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+19
-6
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+5
-3
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+2
-2
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+8
-5
vllm/inputs/parse.py
vllm/inputs/parse.py
+4
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+5
-2
vllm/outputs.py
vllm/outputs.py
+2
-1
vllm/sequence.py
vllm/sequence.py
+2
-2
No files found.
tools/mypy.sh
View file @
776dbd74
...
...
@@ -13,24 +13,14 @@ run_mypy() {
run_mypy
# Note that this is less strict than CI
run_mypy tests
run_mypy vllm/assets
run_mypy vllm/attention
#run_mypy vllm/compilation
#run_mypy vllm/core
run_mypy vllm/compilation
run_mypy vllm/distributed
run_mypy vllm/engine
run_mypy vllm/entrypoints
run_mypy vllm/executor
#run_mypy vllm/inputs
run_mypy vllm/logging
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/multimodal
run_mypy vllm/platforms
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/transformers_utils
run_mypy vllm/usage
#run_mypy vllm/vllm_flash_attn
run_mypy vllm/worker
vllm/attention/layer.py
View file @
776dbd74
...
...
@@ -92,7 +92,7 @@ class Attention(nn.Module):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
...
...
vllm/compilation/backends.py
View file @
776dbd74
...
...
@@ -244,8 +244,8 @@ def vllm_backend(
def
select_default_backend
(
level
:
int
)
->
Union
[
str
,
Callable
]:
if
level
in
[
CompilationLevel
.
DYNAMO_AS_IS
,
CompilationLevel
.
DYNAMO_ONCE
]:
backend
=
"eager"
return
backend
backend
_str
=
"eager"
return
backend
_str
assert
level
in
[
CompilationLevel
.
INDUCTOR
,
CompilationLevel
.
INDUCTOR_MAX_AUTOTUNE
],
f
"Invalid level
{
level
}
"
...
...
vllm/compilation/decorators.py
View file @
776dbd74
...
...
@@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def
cls_decorator_helper
(
cls
:
type
):
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
if
not
hasattr
(
cls
,
'forward'
):
raise
TypeError
(
"decorated class should have a forward method."
)
sig
=
inspect
.
signature
(
cls
.
forward
)
for
k
in
dynamic_arg_dims
:
if
k
not
in
sig
.
parameters
:
...
...
@@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
old_init
=
cls
.
__init__
old_init
=
cls
.
__init__
# type: ignore
def
__init__
(
self
,
*
args
,
**
kwargs
):
old_init
(
self
,
*
args
,
**
kwargs
)
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
)
cls
.
__init__
=
__init__
cls
.
__init__
=
__init__
# type: ignore
def
__call__
(
self
,
*
args
,
**
kwargs
):
# torch.compiler.is_compiling() means we are inside the compilation
...
...
@@ -109,5 +111,5 @@ def _support_torch_compile(cls: type,
model_output
=
self
.
forward
(
*
args
,
**
kwargs
)
return
model_output
cls
.
__call__
=
__call__
cls
.
__call__
=
__call__
# type: ignore
return
cls
vllm/compilation/wrapper.py
View file @
776dbd74
...
...
@@ -73,7 +73,7 @@ class TorchCompileWrapperWithCustomDispatcher:
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame
=
sys
.
_getframe
()
while
True
:
while
frame
and
frame
.
f_back
:
frame
=
frame
.
f_back
code_name
=
frame
.
f_code
.
co_name
file_name
=
frame
.
f_code
.
co_filename
.
split
(
os
.
path
.
sep
)[
-
1
]
...
...
vllm/config.py
View file @
776dbd74
...
...
@@ -626,13 +626,14 @@ class CacheConfig:
self
.
sliding_window
=
sliding_window
self
.
enable_prefix_caching
=
enable_prefix_caching
self
.
cpu_offload_gb
=
cpu_offload_gb
self
.
_verify_args
()
self
.
_verify_cache_dtype
()
self
.
_verify_prefix_caching
()
# Will be set after profiling.
self
.
num_gpu_blocks
=
None
self
.
num_cpu_blocks
=
None
self
.
num_gpu_blocks
:
Optional
[
int
]
=
None
self
.
num_cpu_blocks
:
Optional
[
int
]
=
None
def
metrics_info
(
self
):
# convert cache_config to dict(key: str, value: str) for prometheus
...
...
@@ -709,7 +710,8 @@ class TokenizerPoolConfig:
@
classmethod
def
create_config
(
cls
,
tokenizer_pool_size
:
int
,
tokenizer_pool_type
:
str
,
cls
,
tokenizer_pool_size
:
int
,
tokenizer_pool_type
:
Union
[
str
,
Type
[
"BaseTokenizerGroup"
]],
tokenizer_pool_extra_config
:
Optional
[
Union
[
str
,
dict
]]
)
->
Optional
[
"TokenizerPoolConfig"
]:
"""Create a TokenizerPoolConfig from the given parameters.
...
...
@@ -1544,7 +1546,7 @@ class LoRAConfig:
max_loras
:
int
fully_sharded_loras
:
bool
=
False
max_cpu_loras
:
Optional
[
int
]
=
None
lora_dtype
:
Optional
[
torch
.
dtype
]
=
None
lora_dtype
:
Optional
[
Union
[
torch
.
dtype
,
str
]
]
=
None
lora_extra_vocab_size
:
int
=
256
# This is a constant.
lora_vocab_padding_size
:
ClassVar
[
int
]
=
256
...
...
vllm/core/scheduler.py
View file @
776dbd74
...
...
@@ -4,8 +4,9 @@ import random
import
time
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
typing
import
(
Callable
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
typing
import
Callable
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
...
...
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
class
SchedulerOutputs
:
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
scheduled_seq_groups
:
Iterabl
e
[
ScheduledSequenceGroup
]
scheduled_seq_groups
:
GenericSequenc
e
[
ScheduledSequenceGroup
]
# Number of prefill groups scheduled.
num_prefill_groups
:
int
# Total number of batched tokens.
...
...
vllm/engine/arg_utils.py
View file @
776dbd74
...
...
@@ -3,7 +3,7 @@ import dataclasses
import
json
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
Tuple
,
Type
,
Union
,
cast
)
import
torch
...
...
@@ -89,7 +89,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
False
download_dir
:
Optional
[
str
]
=
None
load_format
:
str
=
'auto'
config_format
:
str
=
'auto'
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
quantization_param_path
:
Optional
[
str
]
=
None
...
...
@@ -181,7 +181,7 @@ class EngineArgs:
scheduling_policy
:
Literal
[
"fcfs"
,
"priority"
]
=
"fcfs"
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
if
not
self
.
tokenizer
:
self
.
tokenizer
=
self
.
model
# Setup plugins
...
...
@@ -837,7 +837,8 @@ class EngineArgs:
def
create_model_config
(
self
)
->
ModelConfig
:
return
ModelConfig
(
model
=
self
.
model
,
tokenizer
=
self
.
tokenizer
,
# We know this is not None because we set it in __post_init__
tokenizer
=
cast
(
str
,
self
.
tokenizer
),
tokenizer_mode
=
self
.
tokenizer_mode
,
trust_remote_code
=
self
.
trust_remote_code
,
dtype
=
self
.
dtype
,
...
...
@@ -908,8 +909,9 @@ class EngineArgs:
self
.
enable_prefix_caching
=
False
cache_config
=
CacheConfig
(
# neuron needs block_size = max_model_len
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
(
self
.
max_model_len
if
self
.
max_model_len
is
not
None
else
0
),
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
swap_space
=
self
.
swap_space
,
cache_dtype
=
self
.
kv_cache_dtype
,
...
...
vllm/engine/llm_engine.py
View file @
776dbd74
...
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
,
overload
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
import
torch
from
typing_extensions
import
TypeVar
...
...
@@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupOutput
,
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
...
...
@@ -188,7 +188,7 @@ class LLMEngine:
raise
TypeError
(
f
"Expected output of type
{
output_type
}
, "
f
"but found type
{
type
(
output
)
}
"
)
return
output
return
cast
(
_O
,
output
)
@
classmethod
def
validate_outputs
(
...
...
@@ -1039,6 +1039,7 @@ class LLMEngine:
scheduler_outputs
.
scheduled_seq_groups
)
has_multiple_outputs
:
bool
=
len
(
outputs
)
>
1
outputs_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
if
has_multiple_outputs
:
assert
self
.
scheduler_config
.
is_multi_step
or
\
self
.
speculative_config
...
...
@@ -1084,6 +1085,7 @@ class LLMEngine:
finished_before
.
append
(
i
)
continue
output
:
List
[
SequenceGroupOutput
]
if
has_multiple_outputs
:
output
=
outputs_by_sequence_group
[
i
]
else
:
...
...
@@ -1096,7 +1098,7 @@ class LLMEngine:
seq_group
,
seq_group_meta
,
is_first_step_output
)
else
:
seq_group
.
update_num_computed_tokens
(
seq_group_meta
.
token_chunk_size
)
seq_group_meta
.
token_chunk_size
or
0
)
if
outputs
:
for
o
in
outputs
:
...
...
@@ -1104,13 +1106,13 @@ class LLMEngine:
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
seq_group
.
metrics
.
model_forward_time
+=
(
o
.
model_forward_time
)
o
.
model_forward_time
or
0
)
else
:
seq_group
.
metrics
.
model_forward_time
=
(
o
.
model_forward_time
)
if
seq_group
.
metrics
.
model_execute_time
is
not
None
:
seq_group
.
metrics
.
model_execute_time
+=
(
o
.
model_execute_time
)
o
.
model_execute_time
or
0
)
else
:
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
...
...
@@ -1236,8 +1238,10 @@ class LLMEngine:
seq_group
,
seq_group_metadata
,
seq_group
.
state
.
num_steps
==
1
)
else
:
seq_group
.
update_num_computed_tokens
(
seq_group_metadata
.
token_chunk_size
)
token_chunk_size
=
(
seq_group_metadata
.
token_chunk_size
if
seq_group_metadata
.
token_chunk_size
is
not
None
else
0
)
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
if
seq_group_metadata
.
do_sample
:
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
...
...
vllm/engine/metrics.py
View file @
776dbd74
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Type
,
Union
,
cast
import
numpy
as
np
import
prometheus_client
...
...
@@ -249,10 +249,11 @@ class _RayHistogramWrapper:
labelnames
:
Optional
[
List
[
str
]]
=
None
,
buckets
:
Optional
[
List
[
float
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
boundaries
=
buckets
if
buckets
else
[]
self
.
_histogram
=
ray_metrics
.
Histogram
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
,
boundaries
=
b
ucket
s
)
boundaries
=
b
oundarie
s
)
def
labels
(
self
,
**
labels
):
self
.
_histogram
.
set_default_tags
(
labels
)
...
...
@@ -267,9 +268,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_gauge_cls
=
_RayGaugeWrapper
_counter_cls
=
_RayCounterWrapper
_histogram_cls
=
_RayHistogramWrapper
_gauge_cls
:
Type
[
prometheus_client
.
Gauge
]
=
cast
(
Type
[
prometheus_client
.
Gauge
],
_RayGaugeWrapper
)
_counter_cls
:
Type
[
prometheus_client
.
Counter
]
=
cast
(
Type
[
prometheus_client
.
Counter
],
_RayCounterWrapper
)
_histogram_cls
:
Type
[
prometheus_client
.
Histogram
]
=
cast
(
Type
[
prometheus_client
.
Histogram
],
_RayHistogramWrapper
)
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
if
ray_metrics
is
None
:
...
...
vllm/engine/multiprocessing/client.py
View file @
776dbd74
...
...
@@ -3,7 +3,7 @@ import copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
List
,
Mapping
,
Optional
,
Union
,
overload
)
Optional
,
Union
,
cast
,
overload
)
import
cloudpickle
import
zmq
...
...
@@ -513,9 +513,14 @@ class MQLLMEngineClient(EngineClient):
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
,
None
,
priority
)
return
cast
(
AsyncGenerator
[
EmbeddingRequestOutput
,
None
],
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
,
priority
=
priority
))
async
def
_process_request
(
self
,
...
...
@@ -543,7 +548,9 @@ class MQLLMEngineClient(EngineClient):
build_guided_decoding_logits_processor_async
(
sampling_params
=
params
,
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
),
default_guided_backend
=
self
.
decoding_config
.
guided_decoding_backend
default_guided_backend
=
(
self
.
decoding_config
.
guided_decoding_backend
if
self
.
decoding_config
else
DecodingConfig
.
guided_decoding_backend
),
)
# 1) Create output queue for this requests.
...
...
vllm/engine/multiprocessing/engine.py
View file @
776dbd74
...
...
@@ -73,11 +73,9 @@ class MQLLMEngine:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
use_cached_outputs
=
True
kwargs
[
'
use_cached_outputs
'
]
=
True
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
,
use_cached_outputs
=
use_cached_outputs
)
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
self
.
log_requests
=
log_requests
self
.
use_async_sockets
=
use_async_sockets
...
...
vllm/engine/output_processor/multi_step.py
View file @
776dbd74
import
functools
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
cast
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
...
...
@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
...
...
@@ -57,6 +59,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""
for
output
in
outputs
:
# Concatenate single-step prompt logprob processing results.
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
@
staticmethod
...
...
@@ -100,8 +103,18 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Beam search not supported in multi-step decoding."
)
seq
=
seqs
[
0
]
seq_id
=
seq
.
seq_id
assert
all
(
[
seq_id
==
output
.
samples
[
0
].
parent_seq_id
for
output
in
outputs
])
# This method is defined in the more generic
# SequenceGroupOutputProcessor, but here we assume that the outputs are
# of a more specific type.
assert
all
([
isinstance
(
output
,
CompletionSequenceGroupOutput
)
for
output
in
outputs
])
compl_outputs
=
cast
(
List
[
CompletionSequenceGroupOutput
],
outputs
)
assert
all
([
seq_id
==
output
.
samples
[
0
].
parent_seq_id
for
output
in
compl_outputs
])
if
is_async
:
# Async case: We process tokens one by one. Here, we know the token
...
...
@@ -113,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
samples
=
[
output
.
samples
[
0
]
for
output
in
compl_
outputs
]
# entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens).
...
...
vllm/engine/output_processor/single_step.py
View file @
776dbd74
...
...
@@ -6,8 +6,9 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
...
...
@@ -16,7 +17,7 @@ logger = init_logger(__name__)
def
single_step_process_prompt_logprob
(
sg_output_proc
:
SequenceGroupOutputProcessor
,
seq_group
:
SequenceGroup
,
output
:
SequenceGroupOutput
)
->
None
:
output
:
Completion
SequenceGroupOutput
)
->
None
:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
...
...
@@ -106,6 +107,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
output
=
outputs
[
0
]
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
...
...
vllm/engine/output_processor/stop_checker.py
View file @
776dbd74
...
...
@@ -57,7 +57,7 @@ class StopChecker:
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
sampling_params
.
stop_token_ids
:
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
())
:
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
...
...
@@ -92,7 +92,7 @@ class StopChecker:
Returns the stop string if matched or else None.
"""
if
not
new_char_count
:
if
not
new_char_count
or
not
sampling_params
.
stop
:
return
None
for
stop_str
in
sampling_params
.
stop
:
...
...
vllm/engine/output_processor/util.py
View file @
776dbd74
from
typing
import
List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
typing
import
cast
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
Pooler
Output
,
SequenceGroupOutput
from
vllm.sequence
import
CompletionSequenceGroup
Output
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
outputs
:
GenericSequence
[
Union
[
SamplerOutput
,
PoolerOutput
]
],
outputs
:
GenericSequence
[
SamplerOutput
],
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
=
[
output_by_sequence_group
:
List
[
List
[
Completion
SequenceGroupOutput
]]
=
[
[]
for
_
in
range
(
num_seq_groups
)
]
for
step
in
outputs
:
sequence_group_output
:
CompletionSequenceGroupOutput
for
i
,
sequence_group_output
in
enumerate
(
step
):
output_by_sequence_group
[
i
].
append
(
sequence_group_output
)
return
output_by_sequence_group
# Cast to the more generic type that CompletionSequenceGroupOutput
# inherits from.
return
cast
(
List
[
List
[
SequenceGroupOutput
]],
output_by_sequence_group
)
vllm/inputs/parse.py
View file @
776dbd74
from
typing
import
List
,
Literal
,
Sequence
,
TypedDict
,
Union
,
overload
from
typing
import
List
,
Literal
,
Sequence
,
TypedDict
,
Union
,
cast
,
overload
from
typing_extensions
import
TypeIs
...
...
@@ -44,13 +44,16 @@ def parse_and_batch_prompt(
if
is_list_of
(
prompt
,
str
):
# case 2: array of strings
prompt
=
cast
(
List
[
str
],
prompt
)
return
[
ParsedText
(
content
=
elem
,
is_tokens
=
False
)
for
elem
in
prompt
]
if
is_list_of
(
prompt
,
int
):
# case 3: array of tokens
prompt
=
cast
(
List
[
int
],
prompt
)
return
[
ParsedTokens
(
content
=
prompt
,
is_tokens
=
True
)]
if
is_list_of
(
prompt
,
list
):
prompt
=
cast
(
List
[
List
[
int
]],
prompt
)
if
len
(
prompt
[
0
])
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
...
...
vllm/model_executor/layers/sampler.py
View file @
776dbd74
...
...
@@ -4,7 +4,7 @@ import warnings
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
import
msgspec
import
torch
...
...
@@ -117,12 +117,15 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time
:
Optional
[
float
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
def
__getitem__
(
self
,
idx
:
int
)
->
CompletionSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
self
.
outputs
[
idx
]
=
value
def
__iter__
(
self
)
->
Iterator
[
CompletionSequenceGroupOutput
]:
return
iter
(
self
.
outputs
)
def
__len__
(
self
):
return
len
(
self
.
outputs
)
...
...
vllm/outputs.py
View file @
776dbd74
...
...
@@ -4,6 +4,7 @@ from typing import List, Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
vllm.inputs
import
PromptType
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
...
...
@@ -92,7 +93,7 @@ class RequestOutput:
def
__init__
(
self
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
PromptType
],
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
...
...
vllm/sequence.py
View file @
776dbd74
...
...
@@ -788,7 +788,7 @@ class SequenceGroup:
assert
num_lookahead_slots
+
1
==
num_scheduler_steps
or
is_prefill
self
.
init_multi_step
(
num_steps
=
num_lookahead_slots
+
1
)
def
get_last_latency
(
self
,
now
:
float
)
->
Optional
[
float
]
:
def
get_last_latency
(
self
,
now
:
float
)
->
float
:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if
self
.
is_prefill
():
...
...
@@ -1198,7 +1198,7 @@ class PoolerOutput(
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
def
__getitem__
(
self
,
idx
:
int
)
->
EmbeddingSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment