Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
166 additions
and
173 deletions
+166
-173
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+5
-5
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+5
-5
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+3
-4
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+9
-9
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+10
-10
vllm/v1/outputs.py
vllm/v1/outputs.py
+10
-10
vllm/v1/request.py
vllm/v1/request.py
+12
-12
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+5
-5
vllm/v1/sample/ops/penalties.py
vllm/v1/sample/ops/penalties.py
+5
-7
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+5
-5
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+3
-4
vllm/v1/stats/common.py
vllm/v1/stats/common.py
+9
-9
vllm/v1/utils.py
vllm/v1/utils.py
+10
-10
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+2
-4
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+31
-31
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-17
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+2
-2
vllm/v1/worker/lora_model_runner_mixin.py
vllm/v1/worker/lora_model_runner_mixin.py
+8
-9
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+12
-12
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+3
-3
No files found.
vllm/v1/executor/abstract.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
typing
import
List
,
Type
,
Union
from
typing
import
Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -22,8 +22,8 @@ class Executor(ExecutorBase):
...
@@ -22,8 +22,8 @@ class Executor(ExecutorBase):
For methods shared by v0 and v1, define them in ExecutorBase"""
For methods shared by v0 and v1, define them in ExecutorBase"""
@
staticmethod
@
staticmethod
def
get_class
(
vllm_config
:
VllmConfig
)
->
T
ype
[
"Executor"
]:
def
get_class
(
vllm_config
:
VllmConfig
)
->
t
ype
[
"Executor"
]:
executor_class
:
T
ype
[
Executor
]
executor_class
:
t
ype
[
Executor
]
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
distributed_executor_backend
=
(
distributed_executor_backend
=
(
parallel_config
.
distributed_executor_backend
)
parallel_config
.
distributed_executor_backend
)
...
@@ -53,7 +53,7 @@ class Executor(ExecutorBase):
...
@@ -53,7 +53,7 @@ class Executor(ExecutorBase):
return
executor_class
return
executor_class
def
initialize_from_config
(
self
,
def
initialize_from_config
(
self
,
kv_cache_configs
:
L
ist
[
KVCacheConfig
])
->
None
:
kv_cache_configs
:
l
ist
[
KVCacheConfig
])
->
None
:
"""
"""
Initialize the KV caches and begin the model execution loop of the
Initialize the KV caches and begin the model execution loop of the
underlying workers.
underlying workers.
...
@@ -69,7 +69,7 @@ class Executor(ExecutorBase):
...
@@ -69,7 +69,7 @@ class Executor(ExecutorBase):
# operators can be applied to all workers.
# operators can be applied to all workers.
return
min
(
output
)
return
min
(
output
)
def
get_kv_cache_specs
(
self
)
->
L
ist
[
KVCacheSpec
]:
def
get_kv_cache_specs
(
self
)
->
l
ist
[
KVCacheSpec
]:
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
return
output
return
output
...
...
vllm/v1/executor/multiproc_executor.py
View file @
cf069aa8
...
@@ -10,7 +10,7 @@ from dataclasses import dataclass
...
@@ -10,7 +10,7 @@ from dataclasses import dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
functools
import
partial
from
functools
import
partial
from
multiprocessing.process
import
BaseProcess
from
multiprocessing.process
import
BaseProcess
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
cloudpickle
import
cloudpickle
import
psutil
import
psutil
...
@@ -77,7 +77,7 @@ class MultiprocExecutor(Executor):
...
@@ -77,7 +77,7 @@ class MultiprocExecutor(Executor):
scheduler_output_handle
=
self
.
rpc_broadcast_mq
.
export_handle
()
scheduler_output_handle
=
self
.
rpc_broadcast_mq
.
export_handle
()
# Create workers
# Create workers
self
.
workers
:
L
ist
[
WorkerProcHandle
]
=
[]
self
.
workers
:
l
ist
[
WorkerProcHandle
]
=
[]
for
rank
in
range
(
self
.
world_size
):
for
rank
in
range
(
self
.
world_size
):
worker
=
WorkerProc
.
make_worker_process
(
self
.
vllm_config
,
rank
,
worker
=
WorkerProc
.
make_worker_process
(
self
.
vllm_config
,
rank
,
rank
,
rank
,
...
@@ -94,8 +94,8 @@ class MultiprocExecutor(Executor):
...
@@ -94,8 +94,8 @@ class MultiprocExecutor(Executor):
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
method
:
Union
[
str
,
Callable
],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
T
uple
=
(),
args
:
t
uple
=
(),
kwargs
:
Optional
[
D
ict
]
=
None
)
->
L
ist
[
Any
]:
kwargs
:
Optional
[
d
ict
]
=
None
)
->
l
ist
[
Any
]:
start_time
=
time
.
monotonic
()
start_time
=
time
.
monotonic
()
kwargs
=
kwargs
or
{}
kwargs
=
kwargs
or
{}
...
@@ -208,7 +208,7 @@ class WorkerProc:
...
@@ -208,7 +208,7 @@ class WorkerProc:
self
.
rank
=
rank
self
.
rank
=
rank
wrapper
=
WorkerWrapperBase
(
vllm_config
=
vllm_config
,
rpc_rank
=
rank
)
wrapper
=
WorkerWrapperBase
(
vllm_config
=
vllm_config
,
rpc_rank
=
rank
)
# TODO: move `init_worker` to executor level as a collective rpc call
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs
:
L
ist
[
D
ict
]
=
[
all_kwargs
:
l
ist
[
d
ict
]
=
[
{}
for
_
in
range
(
vllm_config
.
parallel_config
.
world_size
)
{}
for
_
in
range
(
vllm_config
.
parallel_config
.
world_size
)
]
]
all_kwargs
[
rank
]
=
{
all_kwargs
[
rank
]
=
{
...
...
vllm/v1/kv_cache_interface.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
import
torch
import
torch
...
@@ -74,7 +73,7 @@ class FullAttentionSpec(KVCacheSpecBase):
...
@@ -74,7 +73,7 @@ class FullAttentionSpec(KVCacheSpecBase):
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
KVCacheSpec
=
D
ict
[
str
,
KVCacheSpecBase
]
KVCacheSpec
=
d
ict
[
str
,
KVCacheSpecBase
]
@
dataclass
@
dataclass
...
@@ -95,7 +94,7 @@ class KVCacheConfig:
...
@@ -95,7 +94,7 @@ class KVCacheConfig:
"""The number of KV cache blocks"""
"""The number of KV cache blocks"""
num_blocks
:
int
num_blocks
:
int
"""layer_name -> how to initialize KV cache for that layer"""
"""layer_name -> how to initialize KV cache for that layer"""
tensors
:
D
ict
[
str
,
KVCacheTensor
]
tensors
:
d
ict
[
str
,
KVCacheTensor
]
"""
"""
A list of kv-cache groups. Each group includes a set of layers with
A list of kv-cache groups. Each group includes a set of layers with
the same kv-cache spec, and the total page_size of layers inside a group
the same kv-cache spec, and the total page_size of layers inside a group
...
@@ -108,6 +107,6 @@ class KVCacheConfig:
...
@@ -108,6 +107,6 @@ class KVCacheConfig:
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
"""
"""
groups
:
L
ist
[
L
ist
[
str
]]
groups
:
l
ist
[
l
ist
[
str
]]
"""the KVCacheSpec of the model"""
"""the KVCacheSpec of the model"""
kv_cache_spec
:
KVCacheSpec
kv_cache_spec
:
KVCacheSpec
vllm/v1/metrics/loggers.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
time
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
import
prometheus_client
import
prometheus_client
...
@@ -35,8 +35,8 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -35,8 +35,8 @@ class LoggingStatLogger(StatLoggerBase):
self
.
last_log_time
=
now
self
.
last_log_time
=
now
# Tracked stats over current local logging interval.
# Tracked stats over current local logging interval.
self
.
num_prompt_tokens
:
L
ist
[
int
]
=
[]
self
.
num_prompt_tokens
:
l
ist
[
int
]
=
[]
self
.
num_generation_tokens
:
L
ist
[
int
]
=
[]
self
.
num_generation_tokens
:
l
ist
[
int
]
=
[]
# Prefix cache metrics. TODO: Make the interval configurable.
# Prefix cache metrics. TODO: Make the interval configurable.
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
...
@@ -52,7 +52,7 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -52,7 +52,7 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_generation_tokens
.
append
(
self
.
num_generation_tokens
.
append
(
iteration_stats
.
num_generation_tokens
)
iteration_stats
.
num_generation_tokens
)
def
_get_throughput
(
self
,
tracked_stats
:
L
ist
[
int
],
now
:
float
)
->
float
:
def
_get_throughput
(
self
,
tracked_stats
:
l
ist
[
int
],
now
:
float
)
->
float
:
# Compute summary metrics for tracked stats
# Compute summary metrics for tracked stats
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
...
@@ -147,7 +147,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -147,7 +147,7 @@ class PrometheusStatLogger(StatLoggerBase):
documentation
=
"Number of generation tokens processed."
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_request_success
:
D
ict
[
FinishReason
,
self
.
counter_request_success
:
d
ict
[
FinishReason
,
prometheus_client
.
Counter
]
=
{}
prometheus_client
.
Counter
]
=
{}
counter_request_success_base
=
prometheus_client
.
Counter
(
counter_request_success_base
=
prometheus_client
.
Counter
(
name
=
"vllm:request_success_total"
,
name
=
"vllm:request_success_total"
,
...
@@ -338,14 +338,14 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -338,14 +338,14 @@ class PrometheusStatLogger(StatLoggerBase):
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
def
build_buckets
(
mantissa_lst
:
L
ist
[
int
],
max_value
:
int
)
->
L
ist
[
int
]:
def
build_buckets
(
mantissa_lst
:
l
ist
[
int
],
max_value
:
int
)
->
l
ist
[
int
]:
"""
"""
Builds a list of buckets with increasing powers of 10 multiplied by
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values until the value exceeds the specified maximum.
mantissa values until the value exceeds the specified maximum.
"""
"""
exponent
=
0
exponent
=
0
buckets
:
L
ist
[
int
]
=
[]
buckets
:
l
ist
[
int
]
=
[]
while
True
:
while
True
:
for
m
in
mantissa_lst
:
for
m
in
mantissa_lst
:
value
=
m
*
10
**
exponent
value
=
m
*
10
**
exponent
...
@@ -356,7 +356,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
...
@@ -356,7 +356,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
exponent
+=
1
exponent
+=
1
def
build_1_2_5_buckets
(
max_value
:
int
)
->
L
ist
[
int
]:
def
build_1_2_5_buckets
(
max_value
:
int
)
->
l
ist
[
int
]:
"""
"""
Example:
Example:
>>> build_1_2_5_buckets(100)
>>> build_1_2_5_buckets(100)
...
@@ -365,7 +365,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
...
@@ -365,7 +365,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
return
build_buckets
([
1
,
2
,
5
],
max_value
)
return
build_buckets
([
1
,
2
,
5
],
max_value
)
def
build_cudagraph_buckets
(
vllm_config
:
VllmConfig
)
->
L
ist
[
int
]:
def
build_cudagraph_buckets
(
vllm_config
:
VllmConfig
)
->
l
ist
[
int
]:
if
not
vllm_config
.
model_config
.
enforce_eager
:
if
not
vllm_config
.
model_config
.
enforce_eager
:
buckets
=
vllm_config
.
compilation_config
.
\
buckets
=
vllm_config
.
compilation_config
.
\
cudagraph_capture_sizes
.
copy
()
cudagraph_capture_sizes
.
copy
()
...
...
vllm/v1/metrics/stats.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
time
import
time
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
from
typing
import
TYPE_CHECKING
,
Optional
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
...
@@ -39,8 +39,8 @@ class SchedulerStats:
...
@@ -39,8 +39,8 @@ class SchedulerStats:
@
dataclass
@
dataclass
class
LoRAStats
:
class
LoRAStats
:
waiting_requests
:
S
et
[
str
]
=
field
(
default_factory
=
set
)
waiting_requests
:
s
et
[
str
]
=
field
(
default_factory
=
set
)
running_requests
:
S
et
[
str
]
=
field
(
default_factory
=
set
)
running_requests
:
s
et
[
str
]
=
field
(
default_factory
=
set
)
@
dataclass
@
dataclass
...
@@ -81,11 +81,11 @@ class IterationStats:
...
@@ -81,11 +81,11 @@ class IterationStats:
self
.
num_generation_tokens
=
0
self
.
num_generation_tokens
=
0
self
.
num_prompt_tokens
=
0
self
.
num_prompt_tokens
=
0
self
.
num_preempted_reqs
=
0
self
.
num_preempted_reqs
=
0
self
.
finished_requests
:
L
ist
[
FinishedRequestStats
]
=
[]
self
.
finished_requests
:
l
ist
[
FinishedRequestStats
]
=
[]
self
.
time_to_first_tokens_iter
:
L
ist
[
float
]
=
[]
self
.
time_to_first_tokens_iter
:
l
ist
[
float
]
=
[]
self
.
time_per_output_tokens_iter
:
L
ist
[
float
]
=
[]
self
.
time_per_output_tokens_iter
:
l
ist
[
float
]
=
[]
self
.
waiting_lora_adapters
:
D
ict
[
str
,
int
]
=
{}
self
.
waiting_lora_adapters
:
d
ict
[
str
,
int
]
=
{}
self
.
running_lora_adapters
:
D
ict
[
str
,
int
]
=
{}
self
.
running_lora_adapters
:
d
ict
[
str
,
int
]
=
{}
def
_time_since
(
self
,
start
:
float
)
->
float
:
def
_time_since
(
self
,
start
:
float
)
->
float
:
"""Calculate an interval relative to this iteration's timestamp."""
"""Calculate an interval relative to this iteration's timestamp."""
...
@@ -132,7 +132,7 @@ class IterationStats:
...
@@ -132,7 +132,7 @@ class IterationStats:
if
num_new_generation_tokens
>
0
:
if
num_new_generation_tokens
>
0
:
req_stats
.
last_token_ts
=
engine_core_timestamp
req_stats
.
last_token_ts
=
engine_core_timestamp
def
update_from_events
(
self
,
req_id
:
str
,
events
:
L
ist
[
"EngineCoreEvent"
],
def
update_from_events
(
self
,
req_id
:
str
,
events
:
l
ist
[
"EngineCoreEvent"
],
is_prefilling
:
bool
,
req_stats
:
RequestStateStats
,
is_prefilling
:
bool
,
req_stats
:
RequestStateStats
,
lora_stats
:
Optional
[
LoRAStats
]):
lora_stats
:
Optional
[
LoRAStats
]):
# Avoid circular dependency
# Avoid circular dependency
...
@@ -185,7 +185,7 @@ class LoRARequestStates:
...
@@ -185,7 +185,7 @@ class LoRARequestStates:
"""Per-LoRA request state stats."""
"""Per-LoRA request state stats."""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
lora_name_to_stats
:
D
ict
[
str
,
LoRAStats
]
=
{}
self
.
lora_name_to_stats
:
d
ict
[
str
,
LoRAStats
]
=
{}
def
get_stats
(
self
,
req_state
:
'RequestState'
)
->
Optional
[
LoRAStats
]:
def
get_stats
(
self
,
req_state
:
'RequestState'
)
->
Optional
[
LoRAStats
]:
if
req_state
.
lora_name
is
None
:
if
req_state
.
lora_name
is
None
:
...
...
vllm/v1/outputs.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
from
typing
import
NamedTuple
,
Optional
import
torch
import
torch
...
@@ -9,11 +9,11 @@ import torch
...
@@ -9,11 +9,11 @@ import torch
class
LogprobsLists
(
NamedTuple
):
class
LogprobsLists
(
NamedTuple
):
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids
:
L
ist
[
L
ist
[
int
]]
logprob_token_ids
:
l
ist
[
l
ist
[
int
]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
logprobs
:
L
ist
[
L
ist
[
float
]]
logprobs
:
l
ist
[
l
ist
[
float
]]
# [num_reqs]
# [num_reqs]
sampled_token_ranks
:
L
ist
[
int
]
sampled_token_ranks
:
l
ist
[
int
]
def
slice
(
self
,
start
:
int
,
end
:
int
):
def
slice
(
self
,
start
:
int
,
end
:
int
):
return
LogprobsLists
(
return
LogprobsLists
(
...
@@ -52,23 +52,23 @@ class SamplerOutput:
...
@@ -52,23 +52,23 @@ class SamplerOutput:
# ModelRunnerOutput is serialized and sent to the scheduler process.
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use
L
ist instead.
# This is expensive for torch.Tensor so prefer to use
l
ist instead.
@
dataclass
@
dataclass
class
ModelRunnerOutput
:
class
ModelRunnerOutput
:
# [num_reqs]
# [num_reqs]
req_ids
:
L
ist
[
str
]
req_ids
:
l
ist
[
str
]
# req_id -> index
# req_id -> index
req_id_to_index
:
D
ict
[
str
,
int
]
req_id_to_index
:
d
ict
[
str
,
int
]
# num_reqs x num_generated_tokens
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
# each request due to speculative/jump decoding.
sampled_token_ids
:
L
ist
[
L
ist
[
int
]]
sampled_token_ids
:
l
ist
[
l
ist
[
int
]]
# num_reqs x num_spec_tokens
# num_reqs x num_spec_tokens
spec_token_ids
:
Optional
[
L
ist
[
L
ist
[
int
]]]
spec_token_ids
:
Optional
[
l
ist
[
l
ist
[
int
]]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
...
@@ -79,4 +79,4 @@ class ModelRunnerOutput:
...
@@ -79,4 +79,4 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
# [prompt_len]
prompt_logprobs_dict
:
D
ict
[
str
,
Optional
[
LogprobsTensors
]]
prompt_logprobs_dict
:
d
ict
[
str
,
Optional
[
LogprobsTensors
]]
vllm/v1/request.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
enum
import
enum
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -20,10 +20,10 @@ class Request:
...
@@ -20,10 +20,10 @@ class Request:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
L
ist
[
int
],
prompt_token_ids
:
l
ist
[
int
],
multi_modal_inputs
:
Optional
[
L
ist
[
"MultiModalKwargs"
]],
multi_modal_inputs
:
Optional
[
l
ist
[
"MultiModalKwargs"
]],
multi_modal_hashes
:
Optional
[
L
ist
[
str
]],
multi_modal_hashes
:
Optional
[
l
ist
[
str
]],
multi_modal_placeholders
:
Optional
[
L
ist
[
"PlaceholderRange"
]],
multi_modal_placeholders
:
Optional
[
l
ist
[
"PlaceholderRange"
]],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
eos_token_id
:
Optional
[
int
],
eos_token_id
:
Optional
[
int
],
arrival_time
:
float
,
arrival_time
:
float
,
...
@@ -36,7 +36,7 @@ class Request:
...
@@ -36,7 +36,7 @@ class Request:
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
status
=
RequestStatus
.
WAITING
self
.
status
=
RequestStatus
.
WAITING
self
.
events
:
L
ist
[
EngineCoreEvent
]
=
[]
self
.
events
:
l
ist
[
EngineCoreEvent
]
=
[]
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
assert
sampling_params
.
max_tokens
is
not
None
assert
sampling_params
.
max_tokens
is
not
None
self
.
max_tokens
=
sampling_params
.
max_tokens
self
.
max_tokens
=
sampling_params
.
max_tokens
...
@@ -44,15 +44,15 @@ class Request:
...
@@ -44,15 +44,15 @@ class Request:
self
.
prompt
=
prompt
self
.
prompt
=
prompt
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
self
.
num_prompt_tokens
=
len
(
self
.
prompt_token_ids
)
self
.
_output_token_ids
:
L
ist
[
int
]
=
[]
self
.
_output_token_ids
:
l
ist
[
int
]
=
[]
self
.
_all_token_ids
:
L
ist
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
_all_token_ids
:
l
ist
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
spec_token_ids
:
L
ist
[
int
]
=
[]
self
.
spec_token_ids
:
l
ist
[
int
]
=
[]
self
.
num_computed_tokens
=
0
self
.
num_computed_tokens
=
0
# Multi-modal related
# Multi-modal related
self
.
mm_positions
=
multi_modal_placeholders
or
[]
self
.
mm_positions
=
multi_modal_placeholders
or
[]
self
.
mm_inputs
=
multi_modal_inputs
or
[]
self
.
mm_inputs
=
multi_modal_inputs
or
[]
self
.
mm_hashes
:
L
ist
[
str
]
=
multi_modal_hashes
or
[]
self
.
mm_hashes
:
l
ist
[
str
]
=
multi_modal_hashes
or
[]
# Sanity check
# Sanity check
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_positions
)
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_positions
)
...
@@ -89,7 +89,7 @@ class Request:
...
@@ -89,7 +89,7 @@ class Request:
EngineCoreEvent
.
new_event
(
EngineCoreEventType
.
SCHEDULED
,
EngineCoreEvent
.
new_event
(
EngineCoreEventType
.
SCHEDULED
,
timestamp
))
timestamp
))
def
take_events
(
self
)
->
Optional
[
L
ist
[
EngineCoreEvent
]]:
def
take_events
(
self
)
->
Optional
[
l
ist
[
EngineCoreEvent
]]:
if
not
self
.
events
:
if
not
self
.
events
:
return
None
return
None
events
,
self
.
events
=
self
.
events
,
[]
events
,
self
.
events
=
self
.
events
,
[]
...
@@ -97,7 +97,7 @@ class Request:
...
@@ -97,7 +97,7 @@ class Request:
def
append_output_token_ids
(
def
append_output_token_ids
(
self
,
self
,
token_ids
:
Union
[
int
,
L
ist
[
int
]],
token_ids
:
Union
[
int
,
l
ist
[
int
]],
)
->
None
:
)
->
None
:
if
isinstance
(
token_ids
,
int
):
if
isinstance
(
token_ids
,
int
):
token_ids
=
[
token_ids
]
token_ids
=
[
token_ids
]
...
...
vllm/v1/sample/metadata.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Optional
import
torch
import
torch
...
@@ -17,7 +17,7 @@ class SamplingMetadata:
...
@@ -17,7 +17,7 @@ class SamplingMetadata:
top_k
:
Optional
[
torch
.
Tensor
]
top_k
:
Optional
[
torch
.
Tensor
]
min_p
:
Optional
[
torch
.
Tensor
]
min_p
:
Optional
[
torch
.
Tensor
]
generators
:
D
ict
[
int
,
torch
.
Generator
]
generators
:
d
ict
[
int
,
torch
.
Generator
]
# None means no logprobs, 0 means sampled token logprobs only
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
Optional
[
int
]
max_num_logprobs
:
Optional
[
int
]
...
@@ -28,12 +28,12 @@ class SamplingMetadata:
...
@@ -28,12 +28,12 @@ class SamplingMetadata:
presence_penalties
:
torch
.
Tensor
presence_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
output_token_ids
:
L
ist
[
L
ist
[
int
]]
output_token_ids
:
l
ist
[
l
ist
[
int
]]
# req_index -> (min_tokens, stop_token_ids)
# req_index -> (min_tokens, stop_token_ids)
min_tokens
:
D
ict
[
int
,
T
uple
[
int
,
S
et
[
int
]]]
min_tokens
:
d
ict
[
int
,
t
uple
[
int
,
s
et
[
int
]]]
logit_bias
:
L
ist
[
Optional
[
D
ict
[
int
,
float
]]]
logit_bias
:
l
ist
[
Optional
[
d
ict
[
int
,
float
]]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
# vocab size).
...
...
vllm/v1/sample/ops/penalties.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Set
,
Tuple
import
torch
import
torch
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.layers.utils
import
apply_penalties
...
@@ -9,13 +7,13 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
...
@@ -9,13 +7,13 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def
apply_min_token_penalties
(
def
apply_min_token_penalties
(
logits
:
torch
.
Tensor
,
output_token_ids
:
L
ist
[
L
ist
[
int
]],
logits
:
torch
.
Tensor
,
output_token_ids
:
l
ist
[
l
ist
[
int
]],
min_tokens
:
D
ict
[
int
,
T
uple
[
int
,
S
et
[
int
]]])
->
None
:
min_tokens
:
d
ict
[
int
,
t
uple
[
int
,
s
et
[
int
]]])
->
None
:
"""
"""
Applies minimum token penalty by setting the logits of the stop tokens
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
to -inf.
"""
"""
min_tokens_logits_to_penalize
:
L
ist
[
T
uple
[
int
,
int
]]
=
[]
min_tokens_logits_to_penalize
:
l
ist
[
t
uple
[
int
,
int
]]
=
[]
for
index
,
(
min_token
,
stop_token_ids
)
in
min_tokens
.
items
():
for
index
,
(
min_token
,
stop_token_ids
)
in
min_tokens
.
items
():
if
len
(
output_token_ids
[
index
])
<
min_token
:
if
len
(
output_token_ids
[
index
])
<
min_token
:
for
stop_token_id
in
stop_token_ids
:
for
stop_token_id
in
stop_token_ids
:
...
@@ -30,7 +28,7 @@ def apply_all_penalties(
...
@@ -30,7 +28,7 @@ def apply_all_penalties(
presence_penalties
:
torch
.
Tensor
,
presence_penalties
:
torch
.
Tensor
,
frequency_penalties
:
torch
.
Tensor
,
frequency_penalties
:
torch
.
Tensor
,
repetition_penalties
:
torch
.
Tensor
,
repetition_penalties
:
torch
.
Tensor
,
output_token_ids
:
L
ist
[
L
ist
[
int
]],
output_token_ids
:
l
ist
[
l
ist
[
int
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Applies presence, frequency and repetition penalties to the logits.
Applies presence, frequency and repetition penalties to the logits.
...
@@ -43,7 +41,7 @@ def apply_all_penalties(
...
@@ -43,7 +41,7 @@ def apply_all_penalties(
repetition_penalties
)
repetition_penalties
)
def
_convert_to_tensors
(
output_token_ids
:
L
ist
[
L
ist
[
int
]],
vocab_size
:
int
,
def
_convert_to_tensors
(
output_token_ids
:
l
ist
[
l
ist
[
int
]],
vocab_size
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
"""
Convert the different list data structures to tensors.
Convert the different list data structures to tensors.
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
Optional
from
typing
import
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -54,7 +54,7 @@ class TopKTopPSampler(nn.Module):
...
@@ -54,7 +54,7 @@ class TopKTopPSampler(nn.Module):
def
forward_native
(
def
forward_native
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
generators
:
D
ict
[
int
,
torch
.
Generator
],
generators
:
d
ict
[
int
,
torch
.
Generator
],
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -66,7 +66,7 @@ class TopKTopPSampler(nn.Module):
...
@@ -66,7 +66,7 @@ class TopKTopPSampler(nn.Module):
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
generators
:
D
ict
[
int
,
torch
.
Generator
],
generators
:
d
ict
[
int
,
torch
.
Generator
],
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -117,7 +117,7 @@ def apply_top_k_top_p(
...
@@ -117,7 +117,7 @@ def apply_top_k_top_p(
def
random_sample
(
def
random_sample
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
generators
:
D
ict
[
int
,
torch
.
Generator
],
generators
:
d
ict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Randomly sample from the probabilities.
"""Randomly sample from the probabilities.
...
@@ -143,7 +143,7 @@ def flashinfer_sample(
...
@@ -143,7 +143,7 @@ def flashinfer_sample(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
generators
:
D
ict
[
int
,
torch
.
Generator
],
generators
:
d
ict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample from the probabilities using FlashInfer.
"""Sample from the probabilities using FlashInfer.
...
...
vllm/v1/sample/rejection_sampler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
...
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
else
:
else
:
self
.
forward_method
=
self
.
forward_native
self
.
forward_method
=
self
.
forward_native
def
forward
(
self
,
draft_token_ids
:
L
ist
[
L
ist
[
int
]],
def
forward
(
self
,
draft_token_ids
:
l
ist
[
l
ist
[
int
]],
target_probs
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
if
not
sampling_metadata
.
all_greedy
:
if
not
sampling_metadata
.
all_greedy
:
...
@@ -66,7 +65,7 @@ class RejectionSampler(nn.Module):
...
@@ -66,7 +65,7 @@ class RejectionSampler(nn.Module):
def
flashinfer_sample
(
def
flashinfer_sample
(
self
,
self
,
draft_token_ids
:
L
ist
[
L
ist
[
int
]],
draft_token_ids
:
l
ist
[
l
ist
[
int
]],
target_probs
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
...
@@ -119,7 +118,7 @@ class RejectionSampler(nn.Module):
...
@@ -119,7 +118,7 @@ class RejectionSampler(nn.Module):
# TODO: The following method can be optimized for better performance.
# TODO: The following method can be optimized for better performance.
def
forward_native
(
def
forward_native
(
self
,
self
,
draft_token_ids
:
L
ist
[
L
ist
[
int
]],
draft_token_ids
:
l
ist
[
l
ist
[
int
]],
target_probs
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
...
...
vllm/v1/stats/common.py
View file @
cf069aa8
...
@@ -4,7 +4,7 @@ import time
...
@@ -4,7 +4,7 @@ import time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
dataclasses
import
field
as
dataclass_field
from
dataclasses
import
field
as
dataclass_field
from
enum
import
IntEnum
from
enum
import
IntEnum
from
typing
import
ClassVar
,
Dict
,
List
,
Optional
,
Set
from
typing
import
ClassVar
,
Optional
import
msgspec
import
msgspec
from
msgspec
import
field
as
msgspec_field
from
msgspec
import
field
as
msgspec_field
...
@@ -78,7 +78,7 @@ class RequestStatsUpdate(
...
@@ -78,7 +78,7 @@ class RequestStatsUpdate(
▼
▼
FINISHED (All could go to FINISHED)
FINISHED (All could go to FINISHED)
"""
"""
_VALID_TRANSITIONS
:
ClassVar
[
D
ict
[
Type
,
S
et
[
Type
]]]
=
{
_VALID_TRANSITIONS
:
ClassVar
[
d
ict
[
Type
,
s
et
[
Type
]]]
=
{
Type
.
ARRIVED
:
{
Type
.
ARRIVED
:
{
Type
.
INPUT_PROCESSED
,
Type
.
INPUT_PROCESSED
,
Type
.
FINISHED
,
Type
.
FINISHED
,
...
@@ -140,7 +140,7 @@ class RequestStatsUpdate(
...
@@ -140,7 +140,7 @@ class RequestStatsUpdate(
finish_reason
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
str
]
=
None
# Non-optional fields for each update type.
# Non-optional fields for each update type.
_REQUIRED_FIELDS
:
ClassVar
[
D
ict
[
Type
,
L
ist
[
str
]]]
=
{
_REQUIRED_FIELDS
:
ClassVar
[
d
ict
[
Type
,
l
ist
[
str
]]]
=
{
Type
.
INPUT_PROCESSED
:
[
"num_prompt_tokens"
,
"sampling_params"
],
Type
.
INPUT_PROCESSED
:
[
"num_prompt_tokens"
,
"sampling_params"
],
Type
.
PREFILLING
:
[
"num_computed_tokens"
,
"num_cached_tokens"
],
Type
.
PREFILLING
:
[
"num_computed_tokens"
,
"num_cached_tokens"
],
Type
.
DETOKENIZED
:
[
"num_new_tokens"
],
Type
.
DETOKENIZED
:
[
"num_new_tokens"
],
...
@@ -218,13 +218,13 @@ class RequestStats:
...
@@ -218,13 +218,13 @@ class RequestStats:
# 2. the request was preempted and resumed. It is equivalent to running
# 2. the request was preempted and resumed. It is equivalent to running
# a prefill of the original prefill tokens + generated output tokens
# a prefill of the original prefill tokens + generated output tokens
# before preemption.
# before preemption.
prefill_start_ts_s_lst
:
L
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
prefill_start_ts_s_lst
:
l
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
# A list of timestamps when a token is decoded by the engine core.
# A list of timestamps when a token is decoded by the engine core.
decoding_ts_s_lst
:
L
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
decoding_ts_s_lst
:
l
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
# A sorted list of timestamps for each output token.
# A sorted list of timestamps for each output token.
output_token_ts_s_lst
:
L
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
output_token_ts_s_lst
:
l
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
# First token's timestamp.
# First token's timestamp.
first_token_ts_s
:
Optional
[
float
]
=
None
first_token_ts_s
:
Optional
[
float
]
=
None
...
@@ -241,7 +241,7 @@ class RequestStats:
...
@@ -241,7 +241,7 @@ class RequestStats:
# metric to measure the impact of preemption other than observation of
# metric to measure the impact of preemption other than observation of
# large P99 TPOT. Ideally we could quantify the impact of preemption by
# large P99 TPOT. Ideally we could quantify the impact of preemption by
# measuring the number of tokens re-computed due to preemption.
# measuring the number of tokens re-computed due to preemption.
preempted_ts_s_lst
:
L
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
preempted_ts_s_lst
:
l
ist
[
float
]
=
dataclass_field
(
default_factory
=
list
)
# Timestamp when the request was finished at the engine core.
# Timestamp when the request was finished at the engine core.
finished_ts_s
:
Optional
[
float
]
=
None
finished_ts_s
:
Optional
[
float
]
=
None
...
@@ -308,7 +308,7 @@ class RequestStats:
...
@@ -308,7 +308,7 @@ class RequestStats:
return
self
.
e2e_latency_s
-
self
.
first_token_latency_s
return
self
.
e2e_latency_s
-
self
.
first_token_latency_s
@
property
@
property
def
output_token_latency_s_lst
(
self
)
->
L
ist
[
float
]:
def
output_token_latency_s_lst
(
self
)
->
l
ist
[
float
]:
if
len
(
self
.
output_token_ts_s_lst
)
==
0
:
if
len
(
self
.
output_token_ts_s_lst
)
==
0
:
return
[]
return
[]
latency_s_lst
=
[]
latency_s_lst
=
[]
...
@@ -442,7 +442,7 @@ class EngineCoreStatsSnapshot(
...
@@ -442,7 +442,7 @@ class EngineCoreStatsSnapshot(
default_factory
=
SchedulerStats
)
default_factory
=
SchedulerStats
)
# Per request stats updates.
# Per request stats updates.
requests_stats_updates
:
L
ist
[
RequestStatsUpdate
]
=
msgspec_field
(
requests_stats_updates
:
l
ist
[
RequestStatsUpdate
]
=
msgspec_field
(
default_factory
=
list
)
default_factory
=
list
)
# Engine core's queue stats.
# Engine core's queue stats.
...
...
vllm/v1/utils.py
View file @
cf069aa8
...
@@ -5,8 +5,8 @@ import os
...
@@ -5,8 +5,8 @@ import os
import
weakref
import
weakref
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Generic
,
List
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Optional
,
TypeVar
,
Optional
,
TypeVar
,
Union
,
overload
)
Union
,
overload
)
import
torch
import
torch
...
@@ -24,7 +24,7 @@ T = TypeVar("T")
...
@@ -24,7 +24,7 @@ T = TypeVar("T")
class
ConstantList
(
Generic
[
T
],
Sequence
):
class
ConstantList
(
Generic
[
T
],
Sequence
):
def
__init__
(
self
,
x
:
L
ist
[
T
])
->
None
:
def
__init__
(
self
,
x
:
l
ist
[
T
])
->
None
:
self
.
_x
=
x
self
.
_x
=
x
def
append
(
self
,
item
):
def
append
(
self
,
item
):
...
@@ -57,10 +57,10 @@ class ConstantList(Generic[T], Sequence):
...
@@ -57,10 +57,10 @@ class ConstantList(Generic[T], Sequence):
...
...
@
overload
@
overload
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
L
ist
[
T
]:
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
l
ist
[
T
]:
...
...
def
__getitem__
(
self
,
item
:
Union
[
int
,
slice
])
->
Union
[
T
,
L
ist
[
T
]]:
def
__getitem__
(
self
,
item
:
Union
[
int
,
slice
])
->
Union
[
T
,
l
ist
[
T
]]:
return
self
.
_x
[
item
]
return
self
.
_x
[
item
]
@
overload
@
overload
...
@@ -71,7 +71,7 @@ class ConstantList(Generic[T], Sequence):
...
@@ -71,7 +71,7 @@ class ConstantList(Generic[T], Sequence):
def
__setitem__
(
self
,
s
:
slice
,
value
:
T
,
/
):
def
__setitem__
(
self
,
s
:
slice
,
value
:
T
,
/
):
...
...
def
__setitem__
(
self
,
item
:
Union
[
int
,
slice
],
value
:
Union
[
T
,
L
ist
[
T
]]):
def
__setitem__
(
self
,
item
:
Union
[
int
,
slice
],
value
:
Union
[
T
,
l
ist
[
T
]]):
raise
Exception
(
"Cannot set item in a constant list"
)
raise
Exception
(
"Cannot set item in a constant list"
)
def
__delitem__
(
self
,
item
):
def
__delitem__
(
self
,
item
):
...
@@ -99,7 +99,7 @@ class BackgroundProcHandle:
...
@@ -99,7 +99,7 @@ class BackgroundProcHandle:
output_path
:
str
,
output_path
:
str
,
process_name
:
str
,
process_name
:
str
,
target_fn
:
Callable
,
target_fn
:
Callable
,
process_kwargs
:
D
ict
[
Any
,
Any
],
process_kwargs
:
d
ict
[
Any
,
Any
],
):
):
context
=
get_mp_context
()
context
=
get_mp_context
()
reader
,
writer
=
context
.
Pipe
(
duplex
=
False
)
reader
,
writer
=
context
.
Pipe
(
duplex
=
False
)
...
@@ -146,9 +146,9 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
...
@@ -146,9 +146,9 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
def
bind_kv_cache
(
def
bind_kv_cache
(
kv_caches
:
D
ict
[
str
,
torch
.
Tensor
],
kv_caches
:
d
ict
[
str
,
torch
.
Tensor
],
forward_context
:
D
ict
[
str
,
"Attention"
],
forward_context
:
d
ict
[
str
,
"Attention"
],
runner_kv_caches
:
L
ist
[
torch
.
Tensor
],
runner_kv_caches
:
l
ist
[
torch
.
Tensor
],
)
->
None
:
)
->
None
:
"""
"""
Bind the allocated KV cache to both ModelRunner and forward context so
Bind the allocated KV cache to both ModelRunner and forward context so
...
...
vllm/v1/worker/block_table.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -40,7 +38,7 @@ class BlockTable:
...
@@ -40,7 +38,7 @@ class BlockTable:
def
append_row
(
def
append_row
(
self
,
self
,
block_ids
:
L
ist
[
int
],
block_ids
:
l
ist
[
int
],
row_idx
:
int
,
row_idx
:
int
,
)
->
None
:
)
->
None
:
if
not
block_ids
:
if
not
block_ids
:
...
@@ -50,7 +48,7 @@ class BlockTable:
...
@@ -50,7 +48,7 @@ class BlockTable:
self
.
num_blocks_per_row
[
row_idx
]
+=
num_blocks
self
.
num_blocks_per_row
[
row_idx
]
+=
num_blocks
self
.
block_table_np
[
row_idx
,
start
:
start
+
num_blocks
]
=
block_ids
self
.
block_table_np
[
row_idx
,
start
:
start
+
num_blocks
]
=
block_ids
def
add_row
(
self
,
block_ids
:
L
ist
[
int
],
row_idx
:
int
)
->
None
:
def
add_row
(
self
,
block_ids
:
l
ist
[
int
],
row_idx
:
int
)
->
None
:
self
.
num_blocks_per_row
[
row_idx
]
=
0
self
.
num_blocks_per_row
[
row_idx
]
=
0
self
.
append_row
(
block_ids
,
row_idx
)
self
.
append_row
(
block_ids
,
row_idx
)
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Datastructures defining an input batch
# Datastructures defining an input batch
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
cast
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -24,16 +24,16 @@ if TYPE_CHECKING:
...
@@ -24,16 +24,16 @@ if TYPE_CHECKING:
class
CachedRequestState
:
class
CachedRequestState
:
req_id
:
str
req_id
:
str
prompt_token_ids
:
L
ist
[
int
]
prompt_token_ids
:
l
ist
[
int
]
prompt
:
Optional
[
str
]
prompt
:
Optional
[
str
]
mm_inputs
:
L
ist
[
MultiModalKwargs
]
mm_inputs
:
l
ist
[
MultiModalKwargs
]
mm_positions
:
L
ist
[
"PlaceholderRange"
]
mm_positions
:
l
ist
[
"PlaceholderRange"
]
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
generator
:
Optional
[
torch
.
Generator
]
generator
:
Optional
[
torch
.
Generator
]
block_ids
:
L
ist
[
int
]
block_ids
:
l
ist
[
int
]
num_computed_tokens
:
int
num_computed_tokens
:
int
output_token_ids
:
L
ist
[
int
]
output_token_ids
:
l
ist
[
int
]
mrope_positions
:
Optional
[
torch
.
Tensor
]
=
None
mrope_positions
:
Optional
[
torch
.
Tensor
]
=
None
mrope_position_delta
:
Optional
[
int
]
=
None
mrope_position_delta
:
Optional
[
int
]
=
None
...
@@ -63,8 +63,8 @@ class InputBatch:
...
@@ -63,8 +63,8 @@ class InputBatch:
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
_req_ids
:
L
ist
[
Optional
[
str
]]
=
[]
self
.
_req_ids
:
l
ist
[
Optional
[
str
]]
=
[]
self
.
req_id_to_index
:
D
ict
[
str
,
int
]
=
{}
self
.
req_id_to_index
:
d
ict
[
str
,
int
]
=
{}
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# Find a way to reduce the CPU memory usage.
...
@@ -106,8 +106,8 @@ class InputBatch:
...
@@ -106,8 +106,8 @@ class InputBatch:
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
temperature_cpu
=
self
.
temperature_cpu_tensor
.
numpy
()
self
.
temperature_cpu
=
self
.
temperature_cpu_tensor
.
numpy
()
self
.
greedy_reqs
:
S
et
[
str
]
=
set
()
self
.
greedy_reqs
:
s
et
[
str
]
=
set
()
self
.
random_reqs
:
S
et
[
str
]
=
set
()
self
.
random_reqs
:
s
et
[
str
]
=
set
()
self
.
top_p
=
torch
.
empty
((
max_num_reqs
,
),
self
.
top_p
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
@@ -117,7 +117,7 @@ class InputBatch:
...
@@ -117,7 +117,7 @@ class InputBatch:
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
top_p_cpu
=
self
.
top_p_cpu_tensor
.
numpy
()
self
.
top_p_cpu
=
self
.
top_p_cpu_tensor
.
numpy
()
self
.
top_p_reqs
:
S
et
[
str
]
=
set
()
self
.
top_p_reqs
:
s
et
[
str
]
=
set
()
self
.
top_k
=
torch
.
empty
((
max_num_reqs
,
),
self
.
top_k
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -127,7 +127,7 @@ class InputBatch:
...
@@ -127,7 +127,7 @@ class InputBatch:
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
top_k_cpu
=
self
.
top_k_cpu_tensor
.
numpy
()
self
.
top_k_cpu
=
self
.
top_k_cpu_tensor
.
numpy
()
self
.
top_k_reqs
:
S
et
[
str
]
=
set
()
self
.
top_k_reqs
:
s
et
[
str
]
=
set
()
self
.
min_p
=
torch
.
empty
((
max_num_reqs
,
),
self
.
min_p
=
torch
.
empty
((
max_num_reqs
,
),
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
@@ -137,7 +137,7 @@ class InputBatch:
...
@@ -137,7 +137,7 @@ class InputBatch:
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
min_p_cpu
=
self
.
min_p_cpu_tensor
.
numpy
()
self
.
min_p_cpu
=
self
.
min_p_cpu_tensor
.
numpy
()
self
.
min_p_reqs
:
S
et
[
str
]
=
set
()
self
.
min_p_reqs
:
s
et
[
str
]
=
set
()
# Frequency penalty related data structures
# Frequency penalty related data structures
self
.
frequency_penalties
=
torch
.
empty
((
max_num_reqs
,
),
self
.
frequency_penalties
=
torch
.
empty
((
max_num_reqs
,
),
...
@@ -150,7 +150,7 @@ class InputBatch:
...
@@ -150,7 +150,7 @@ class InputBatch:
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_reqs
:
S
et
[
str
]
=
set
()
self
.
frequency_penalties_reqs
:
s
et
[
str
]
=
set
()
# Presence penalty related data structures
# Presence penalty related data structures
self
.
presence_penalties
=
torch
.
empty
((
max_num_reqs
,
),
self
.
presence_penalties
=
torch
.
empty
((
max_num_reqs
,
),
...
@@ -162,7 +162,7 @@ class InputBatch:
...
@@ -162,7 +162,7 @@ class InputBatch:
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
presence_penalties_cpu
=
self
.
presence_penalties_cpu_tensor
.
numpy
(
self
.
presence_penalties_cpu
=
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
)
self
.
presence_penalties_reqs
:
S
et
[
str
]
=
set
()
self
.
presence_penalties_reqs
:
s
et
[
str
]
=
set
()
# Repetition penalty related data structures
# Repetition penalty related data structures
self
.
repetition_penalties
=
torch
.
empty
((
max_num_reqs
,
),
self
.
repetition_penalties
=
torch
.
empty
((
max_num_reqs
,
),
...
@@ -175,43 +175,43 @@ class InputBatch:
...
@@ -175,43 +175,43 @@ class InputBatch:
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
S
et
[
str
]
=
set
()
self
.
repetition_penalties_reqs
:
s
et
[
str
]
=
set
()
# req_index -> (min_tokens, stop_token_ids)
# req_index -> (min_tokens, stop_token_ids)
self
.
min_tokens
:
D
ict
[
int
,
T
uple
[
int
,
S
et
[
int
]]]
=
{}
self
.
min_tokens
:
d
ict
[
int
,
t
uple
[
int
,
s
et
[
int
]]]
=
{}
# lora related
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,
),
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
self
.
lora_id_to_request_ids
:
D
ict
[
int
,
S
et
[
str
]]
=
{}
self
.
lora_id_to_request_ids
:
d
ict
[
int
,
s
et
[
str
]]
=
{}
self
.
lora_id_to_lora_request
:
D
ict
[
int
,
LoRARequest
]
=
{}
self
.
lora_id_to_lora_request
:
d
ict
[
int
,
LoRARequest
]
=
{}
# req_index -> generator
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
# generator should not be included in the dictionary.
self
.
generators
:
D
ict
[
int
,
torch
.
Generator
]
=
{}
self
.
generators
:
d
ict
[
int
,
torch
.
Generator
]
=
{}
self
.
num_logprobs
:
D
ict
[
str
,
int
]
=
{}
self
.
num_logprobs
:
d
ict
[
str
,
int
]
=
{}
# NOTE(rob): num_prompt_logprobs only includes reqs
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
D
ict
[
str
,
int
]
=
{}
self
.
num_prompt_logprobs
:
d
ict
[
str
,
int
]
=
{}
self
.
logit_bias
:
L
ist
[
Optional
[
D
ict
[
int
,
self
.
logit_bias
:
l
ist
[
Optional
[
d
ict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
float
]]]
=
[
None
]
*
max_num_reqs
self
.
has_allowed_token_ids
:
S
et
[
str
]
=
set
()
self
.
has_allowed_token_ids
:
s
et
[
str
]
=
set
()
self
.
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
self
.
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
self
.
allowed_token_ids_mask_cpu_tensor
:
Optional
[
torch
.
Tensor
]
=
None
self
.
allowed_token_ids_mask_cpu_tensor
:
Optional
[
torch
.
Tensor
]
=
None
self
.
req_output_token_ids
:
L
ist
[
Optional
[
L
ist
[
int
]]]
=
[]
self
.
req_output_token_ids
:
l
ist
[
Optional
[
l
ist
[
int
]]]
=
[]
# This is updated each time the batch constituents change.
# This is updated each time the batch constituents change.
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
@
property
@
property
def
req_ids
(
self
)
->
L
ist
[
str
]:
def
req_ids
(
self
)
->
l
ist
[
str
]:
# None elements should only be present transiently
# None elements should only be present transiently
# while performing state updates to the batch.
# while performing state updates to the batch.
return
cast
(
L
ist
[
str
],
self
.
_req_ids
)
return
cast
(
l
ist
[
str
],
self
.
_req_ids
)
def
add_request
(
def
add_request
(
self
,
self
,
...
@@ -417,7 +417,7 @@ class InputBatch:
...
@@ -417,7 +417,7 @@ class InputBatch:
self
.
logit_bias
[
i2
],
self
.
logit_bias
[
i1
]
self
.
logit_bias
[
i2
],
self
.
logit_bias
[
i1
]
self
.
block_table
.
swap_row
(
i1
,
i2
)
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
,
empty_req_indices
:
L
ist
[
int
])
->
None
:
def
condense
(
self
,
empty_req_indices
:
l
ist
[
int
])
->
None
:
num_reqs
=
self
.
num_reqs
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
if
num_reqs
==
0
:
# The batched states are empty.
# The batched states are empty.
...
@@ -550,7 +550,7 @@ class InputBatch:
...
@@ -550,7 +550,7 @@ class InputBatch:
frequency_penalties
=
self
.
frequency_penalties
[:
num_reqs
],
frequency_penalties
=
self
.
frequency_penalties
[:
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
],
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
],
output_token_ids
=
cast
(
L
ist
[
L
ist
[
int
]],
self
.
req_output_token_ids
),
output_token_ids
=
cast
(
l
ist
[
l
ist
[
int
]],
self
.
req_output_token_ids
),
min_tokens
=
self
.
min_tokens
,
min_tokens
=
self
.
min_tokens
,
no_penalties
=
self
.
no_penalties
,
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
num_reqs
],
logit_bias
=
self
.
logit_bias
[:
num_reqs
],
...
@@ -577,7 +577,7 @@ class InputBatch:
...
@@ -577,7 +577,7 @@ class InputBatch:
def
make_lora_inputs
(
def
make_lora_inputs
(
self
,
num_scheduled_tokens
:
np
.
ndarray
self
,
num_scheduled_tokens
:
np
.
ndarray
)
->
T
uple
[
T
uple
[
int
,
...],
T
uple
[
int
,
...],
S
et
[
LoRARequest
]]:
)
->
t
uple
[
t
uple
[
int
,
...],
t
uple
[
int
,
...],
s
et
[
LoRARequest
]]:
"""
"""
Given the num_scheduled_tokens for each request in the batch, return
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
datastructures used to activate the current LoRAs.
...
@@ -593,7 +593,7 @@ class InputBatch:
...
@@ -593,7 +593,7 @@ class InputBatch:
prompt_lora_mapping
=
tuple
(
req_lora_mapping
)
prompt_lora_mapping
=
tuple
(
req_lora_mapping
)
token_lora_mapping
=
tuple
(
token_lora_mapping
=
tuple
(
req_lora_mapping
.
repeat
(
num_scheduled_tokens
))
req_lora_mapping
.
repeat
(
num_scheduled_tokens
))
active_lora_requests
:
S
et
[
LoRARequest
]
=
set
(
active_lora_requests
:
s
et
[
LoRARequest
]
=
set
(
self
.
lora_id_to_lora_request
.
values
())
self
.
lora_id_to_lora_request
.
values
())
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cf069aa8
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
gc
import
gc
import
time
import
time
import
weakref
import
weakref
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -135,9 +135,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -135,9 +135,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Lazy initialization
# Lazy initialization
# self.model: nn.Module # Set after load_model
# self.model: nn.Module # Set after load_model
self
.
kv_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
self
.
kv_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
# req_id -> (input_id -> encoder_output)
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
D
ict
[
str
,
D
ict
[
int
,
torch
.
Tensor
]]
=
{}
self
.
encoder_cache
:
d
ict
[
str
,
d
ict
[
int
,
torch
.
Tensor
]]
=
{}
# Set up speculative decoding.
# Set up speculative decoding.
self
.
use_spec_decode
=
False
self
.
use_spec_decode
=
False
...
@@ -158,7 +158,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -158,7 +158,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
# Request states.
# Request states.
self
.
requests
:
D
ict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
d
ict
[
str
,
CachedRequestState
]
=
{}
# Persistent batch.
# Persistent batch.
self
.
input_batch
=
InputBatch
(
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
...
@@ -274,7 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -274,7 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then resubmitted with the same ID. In this case, we treat them as two
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
# and handling the second as a new request.
removed_req_indices
:
L
ist
[
int
]
=
[]
removed_req_indices
:
l
ist
[
int
]
=
[]
for
req_id
in
scheduler_output
.
finished_req_ids
:
for
req_id
in
scheduler_output
.
finished_req_ids
:
req_index
=
self
.
input_batch
.
remove_request
(
req_id
)
req_index
=
self
.
input_batch
.
remove_request
(
req_id
)
if
req_index
is
not
None
:
if
req_index
is
not
None
:
...
@@ -305,7 +305,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -305,7 +305,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
req_index
is
not
None
assert
req_index
is
not
None
removed_req_indices
.
append
(
req_index
)
removed_req_indices
.
append
(
req_index
)
req_ids_to_add
:
L
ist
[
str
]
=
[]
req_ids_to_add
:
l
ist
[
str
]
=
[]
# Add new requests to the cached states.
# Add new requests to the cached states.
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req_data
.
req_id
req_id
=
new_req_data
.
req_id
...
@@ -446,7 +446,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -446,7 +446,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_inputs
(
def
_prepare_inputs
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
T
uple
[
FlashAttentionMetadata
,
torch
.
Tensor
]:
)
->
t
uple
[
FlashAttentionMetadata
,
torch
.
Tensor
]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
...
@@ -774,8 +774,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -774,8 +774,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
return
# Batch the multi-modal inputs.
# Batch the multi-modal inputs.
mm_inputs
:
L
ist
[
MultiModalKwargs
]
=
[]
mm_inputs
:
l
ist
[
MultiModalKwargs
]
=
[]
req_input_ids
:
L
ist
[
T
uple
[
str
,
int
]]
=
[]
req_input_ids
:
l
ist
[
t
uple
[
str
,
int
]]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
for
input_id
in
encoder_input_ids
:
...
@@ -819,8 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -819,8 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_gather_encoder_outputs
(
def
_gather_encoder_outputs
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
L
ist
[
torch
.
Tensor
]:
)
->
l
ist
[
torch
.
Tensor
]:
encoder_outputs
:
L
ist
[
torch
.
Tensor
]
=
[]
encoder_outputs
:
l
ist
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_id
]
...
@@ -1022,10 +1022,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1022,10 +1022,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
generate_draft_token_ids
(
def
generate_draft_token_ids
(
self
,
self
,
sampled_token_ids
:
L
ist
[
L
ist
[
int
]],
sampled_token_ids
:
l
ist
[
l
ist
[
int
]],
)
->
L
ist
[
L
ist
[
int
]]:
)
->
l
ist
[
l
ist
[
int
]]:
# TODO(woosuk): Optimize.
# TODO(woosuk): Optimize.
draft_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
draft_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
for
i
,
sampled_ids
in
enumerate
(
sampled_token_ids
):
for
i
,
sampled_ids
in
enumerate
(
sampled_token_ids
):
num_sampled_ids
=
len
(
sampled_ids
)
num_sampled_ids
=
len
(
sampled_ids
)
if
not
num_sampled_ids
:
if
not
num_sampled_ids
:
...
@@ -1069,12 +1069,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1069,12 +1069,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
D
ict
[
str
,
Optional
[
LogprobsTensors
]]:
)
->
d
ict
[
str
,
Optional
[
LogprobsTensors
]]:
num_prompt_logprobs_dict
=
self
.
input_batch
.
num_prompt_logprobs
num_prompt_logprobs_dict
=
self
.
input_batch
.
num_prompt_logprobs
if
not
num_prompt_logprobs_dict
:
if
not
num_prompt_logprobs_dict
:
return
{}
return
{}
prompt_logprobs_dict
:
D
ict
[
str
,
Optional
[
LogprobsTensors
]]
=
{}
prompt_logprobs_dict
:
d
ict
[
str
,
Optional
[
LogprobsTensors
]]
=
{}
# Since prompt logprobs are a rare feature, prioritize simple,
# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
# maintainable loop over optimal performance.
...
@@ -1365,7 +1365,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1365,7 +1365,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
"supported yet."
)
kv_caches
:
D
ict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
d
ict
[
str
,
torch
.
Tensor
]
=
{}
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
...
...
vllm/v1/worker/gpu_worker.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Optional
,
Set
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -243,7 +243,7 @@ class Worker(WorkerBase):
...
@@ -243,7 +243,7 @@ class Worker(WorkerBase):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_runner
.
remove_lora
(
lora_id
)
return
self
.
model_runner
.
remove_lora
(
lora_id
)
def
list_loras
(
self
)
->
S
et
[
int
]:
def
list_loras
(
self
)
->
s
et
[
int
]:
return
self
.
model_runner
.
list_loras
()
return
self
.
model_runner
.
list_loras
()
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
...
...
vllm/v1/worker/lora_model_runner_mixin.py
View file @
cf069aa8
...
@@ -4,7 +4,6 @@ Define LoRA functionality mixin for model runners.
...
@@ -4,7 +4,6 @@ Define LoRA functionality mixin for model runners.
"""
"""
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Set
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -57,9 +56,9 @@ class LoRAModelRunnerMixin:
...
@@ -57,9 +56,9 @@ class LoRAModelRunnerMixin:
)
)
return
self
.
lora_manager
.
create_lora_manager
(
model
)
return
self
.
lora_manager
.
create_lora_manager
(
model
)
def
_set_active_loras
(
self
,
prompt_lora_mapping
:
T
uple
[
int
,
...],
def
_set_active_loras
(
self
,
prompt_lora_mapping
:
t
uple
[
int
,
...],
token_lora_mapping
:
T
uple
[
int
,
...],
token_lora_mapping
:
t
uple
[
int
,
...],
lora_requests
:
S
et
[
LoRARequest
])
->
None
:
lora_requests
:
s
et
[
LoRARequest
])
->
None
:
if
not
self
.
lora_manager
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
...
@@ -74,10 +73,10 @@ class LoRAModelRunnerMixin:
...
@@ -74,10 +73,10 @@ class LoRAModelRunnerMixin:
def
set_active_loras
(
self
,
input_batch
:
InputBatch
,
def
set_active_loras
(
self
,
input_batch
:
InputBatch
,
num_scheduled_tokens
:
np
.
ndarray
)
->
None
:
num_scheduled_tokens
:
np
.
ndarray
)
->
None
:
prompt_lora_mapping
:
T
uple
[
int
,
...]
# of size input_batch.num_reqs
prompt_lora_mapping
:
t
uple
[
int
,
...]
# of size input_batch.num_reqs
token_lora_mapping
:
T
uple
[
int
,
token_lora_mapping
:
t
uple
[
int
,
...]
# of size np.sum(num_scheduled_tokens)
...]
# of size np.sum(num_scheduled_tokens)
lora_requests
:
S
et
[
LoRARequest
]
lora_requests
:
s
et
[
LoRARequest
]
prompt_lora_mapping
,
token_lora_mapping
,
lora_requests
=
\
prompt_lora_mapping
,
token_lora_mapping
,
lora_requests
=
\
input_batch
.
make_lora_inputs
(
num_scheduled_tokens
)
input_batch
.
make_lora_inputs
(
num_scheduled_tokens
)
return
self
.
_set_active_loras
(
prompt_lora_mapping
,
token_lora_mapping
,
return
self
.
_set_active_loras
(
prompt_lora_mapping
,
token_lora_mapping
,
...
@@ -105,7 +104,7 @@ class LoRAModelRunnerMixin:
...
@@ -105,7 +104,7 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens
)
num_scheduled_tokens
)
# Make dummy lora requests
# Make dummy lora requests
lora_requests
:
S
et
[
LoRARequest
]
=
{
lora_requests
:
s
et
[
LoRARequest
]
=
{
LoRARequest
(
lora_name
=
f
"warmup_
{
lora_id
}
"
,
LoRARequest
(
lora_name
=
f
"warmup_
{
lora_id
}
"
,
lora_int_id
=
lora_id
,
lora_int_id
=
lora_id
,
lora_path
=
"/not/a/real/path"
)
lora_path
=
"/not/a/real/path"
)
...
@@ -143,7 +142,7 @@ class LoRAModelRunnerMixin:
...
@@ -143,7 +142,7 @@ class LoRAModelRunnerMixin:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
pin_adapter
(
lora_id
)
return
self
.
lora_manager
.
pin_adapter
(
lora_id
)
def
list_loras
(
self
)
->
S
et
[
int
]:
def
list_loras
(
self
)
->
s
et
[
int
]:
if
not
self
.
lora_manager
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
list_adapters
()
return
self
.
lora_manager
.
list_adapters
()
\ No newline at end of file
vllm/v1/worker/tpu_model_runner.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
time
import
time
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
cast
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
numpy
as
np
import
numpy
as
np
...
@@ -95,13 +95,13 @@ class TPUModelRunner:
...
@@ -95,13 +95,13 @@ class TPUModelRunner:
)
)
# Request states.
# Request states.
self
.
requests
:
D
ict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
d
ict
[
str
,
CachedRequestState
]
=
{}
# req_id -> (input_id -> encoder_output)
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
D
ict
[
str
,
D
ict
[
int
,
torch
.
Tensor
]]
=
{}
self
.
encoder_cache
:
d
ict
[
str
,
d
ict
[
int
,
torch
.
Tensor
]]
=
{}
# KV caches for forward pass
# KV caches for forward pass
self
.
kv_caches
:
L
ist
[
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
[]
self
.
kv_caches
:
l
ist
[
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
[]
# Cached torch/numpy tensor
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# The pytorch tensor and numpy array share the same buffer.
...
@@ -171,7 +171,7 @@ class TPUModelRunner:
...
@@ -171,7 +171,7 @@ class TPUModelRunner:
# then resubmitted with the same ID. In this case, we treat them as two
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
# and handling the second as a new request.
removed_req_indices
:
L
ist
[
int
]
=
[]
removed_req_indices
:
l
ist
[
int
]
=
[]
for
req_id
in
scheduler_output
.
finished_req_ids
:
for
req_id
in
scheduler_output
.
finished_req_ids
:
req_index
=
self
.
input_batch
.
remove_request
(
req_id
)
req_index
=
self
.
input_batch
.
remove_request
(
req_id
)
if
req_index
is
not
None
:
if
req_index
is
not
None
:
...
@@ -194,7 +194,7 @@ class TPUModelRunner:
...
@@ -194,7 +194,7 @@ class TPUModelRunner:
assert
req_index
is
not
None
assert
req_index
is
not
None
removed_req_indices
.
append
(
req_index
)
removed_req_indices
.
append
(
req_index
)
req_ids_to_add
:
L
ist
[
str
]
=
[]
req_ids_to_add
:
l
ist
[
str
]
=
[]
# Add new requests to the cached states.
# Add new requests to the cached states.
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req_data
.
req_id
req_id
=
new_req_data
.
req_id
...
@@ -453,7 +453,7 @@ class TPUModelRunner:
...
@@ -453,7 +453,7 @@ class TPUModelRunner:
selected_token_ids
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
selected_token_ids
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
# Then, let's update the cache state.
# Then, let's update the cache state.
request_seq_lens
:
L
ist
[
T
uple
[
int
,
CachedRequestState
,
int
]]
=
[]
request_seq_lens
:
l
ist
[
t
uple
[
int
,
CachedRequestState
,
int
]]
=
[]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
...
@@ -473,9 +473,9 @@ class TPUModelRunner:
...
@@ -473,9 +473,9 @@ class TPUModelRunner:
assert
all
(
assert
all
(
req_id
is
not
None
for
req_id
in
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
req_ids
=
cast
(
L
ist
[
str
],
self
.
input_batch
.
req_ids
[:
num_reqs
])
req_ids
=
cast
(
l
ist
[
str
],
self
.
input_batch
.
req_ids
[:
num_reqs
])
prompt_logprobs_dict
:
D
ict
[
str
,
Optional
[
LogprobsTensors
]]
=
{}
prompt_logprobs_dict
:
d
ict
[
str
,
Optional
[
LogprobsTensors
]]
=
{}
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
prompt_logprobs_dict
[
req_id
]
=
None
prompt_logprobs_dict
[
req_id
]
=
None
...
@@ -612,7 +612,7 @@ class TPUModelRunner:
...
@@ -612,7 +612,7 @@ class TPUModelRunner:
"Hybrid models with more than one KV cache type are not "
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
"supported yet."
)
kv_caches
:
D
ict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
d
ict
[
str
,
torch
.
Tensor
]
=
{}
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
...
@@ -649,7 +649,7 @@ class ModelWrapperV1(nn.Module):
...
@@ -649,7 +649,7 @@ class ModelWrapperV1(nn.Module):
self
,
self
,
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
L
ist
[
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
l
ist
[
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
"""Executes the forward pass of the model and samples the next token.
...
@@ -667,7 +667,7 @@ class ModelWrapperV1(nn.Module):
...
@@ -667,7 +667,7 @@ class ModelWrapperV1(nn.Module):
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
# the slot_mapping accordingly.
# kv_caches:
L
ist[
T
uple[torch.Tensor, torch.Tensor]]
# kv_caches:
l
ist[
t
uple[torch.Tensor, torch.Tensor]]
num_kv_heads
,
num_blocks
,
block_size
,
_
=
kv_caches
[
0
][
0
].
shape
num_kv_heads
,
num_blocks
,
block_size
,
_
=
kv_caches
[
0
][
0
].
shape
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
slot_mapping
.
flatten
()
slot_mapping
=
slot_mapping
.
flatten
()
...
...
vllm/v1/worker/tpu_worker.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""A TPU worker class."""
"""A TPU worker class."""
import
os
import
os
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -103,7 +103,7 @@ class TPUWorker:
...
@@ -103,7 +103,7 @@ class TPUWorker:
self
.
model_runner
=
TPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
self
.
model_runner
=
TPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
def
determine_available_memory
(
self
)
->
int
:
def
determine_available_memory
(
self
)
->
int
:
kv_caches
:
D
ict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
d
ict
[
str
,
torch
.
Tensor
]
=
{}
kv_cache_spec
=
self
.
model_runner
.
get_kv_cache_spec
()
kv_cache_spec
=
self
.
model_runner
.
get_kv_cache_spec
()
for
layer_name
,
layer_spec
in
kv_cache_spec
.
items
():
for
layer_name
,
layer_spec
in
kv_cache_spec
.
items
():
if
isinstance
(
layer_spec
,
FullAttentionSpec
):
if
isinstance
(
layer_spec
,
FullAttentionSpec
):
...
@@ -118,7 +118,7 @@ class TPUWorker:
...
@@ -118,7 +118,7 @@ class TPUWorker:
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
runner_kv_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
runner_kv_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
bind_kv_cache
(
bind_kv_cache
(
kv_caches
,
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
...
...
Prev
1
…
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment