Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53076d70
Commit
53076d70
authored
Mar 24, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-ori
parents
322a0be6
9c5c81b0
Changes
219
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
618 additions
and
1118 deletions
+618
-1118
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+1
-2
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+5
-8
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+4
-0
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+34
-18
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+1
-1
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+33
-1
vllm/v1/sample/tpu/__init__.py
vllm/v1/sample/tpu/__init__.py
+0
-0
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+159
-0
vllm/v1/sample/tpu/sampler.py
vllm/v1/sample/tpu/sampler.py
+154
-0
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+13
-4
vllm/v1/structured_output/__init__.py
vllm/v1/structured_output/__init__.py
+3
-1
vllm/v1/structured_output/backend_xgrammar.py
vllm/v1/structured_output/backend_xgrammar.py
+5
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+45
-28
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+2
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+156
-76
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+2
-2
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+1
-1
vllm/worker/openvino_model_runner.py
vllm/worker/openvino_model_runner.py
+0
-372
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+0
-600
No files found.
vllm/v1/engine/processor.py
View file @
53076d70
...
@@ -120,7 +120,7 @@ class Processor:
...
@@ -120,7 +120,7 @@ class Processor:
if
not
params
.
guided_decoding
or
not
self
.
decoding_config
:
if
not
params
.
guided_decoding
or
not
self
.
decoding_config
:
return
return
supported_backends
=
[
"xgrammar"
]
supported_backends
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
]
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
if
engine_level_backend
not
in
supported_backends
:
if
engine_level_backend
not
in
supported_backends
:
raise
ValueError
(
f
"Only
{
supported_backends
}
structured output is "
raise
ValueError
(
f
"Only
{
supported_backends
}
structured output is "
...
@@ -173,7 +173,6 @@ class Processor:
...
@@ -173,7 +173,6 @@ class Processor:
# 3. Apply prompt adapter to prompt token ids if one exists.
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs
:
ProcessorInputs
=
self
.
input_preprocessor
.
preprocess
(
processed_inputs
:
ProcessorInputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
self
.
use_hash
,
return_mm_hashes
=
self
.
use_hash
,
...
...
vllm/v1/executor/abstract.py
View file @
53076d70
...
@@ -62,14 +62,11 @@ class Executor(ExecutorBase):
...
@@ -62,14 +62,11 @@ class Executor(ExecutorBase):
args
=
(
kv_cache_configs
,
))
args
=
(
kv_cache_configs
,
))
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
self
.
collective_rpc
(
"compile_or_warm_up_model"
)
def
determine_available_memory
(
self
)
->
int
:
# in bytes
def
determine_available_memory
(
self
)
->
list
[
int
]
:
# in bytes
output
=
self
.
collective_rpc
(
"determine_available_memory"
)
output
=
self
.
collective_rpc
(
"determine_available_memory"
)
# Since we use a shared centralized controller, we take the minimum
return
output
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
return
min
(
output
)
def
get_kv_cache_specs
(
self
)
->
list
[
KVCacheSpec
]:
def
get_kv_cache_specs
(
self
)
->
list
[
dict
[
str
,
KVCacheSpec
]
]
:
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
output
=
self
.
collective_rpc
(
"get_kv_cache_spec"
)
return
output
return
output
...
@@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
...
@@ -95,7 +92,7 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
def
determine_available_memory
(
self
)
->
int
:
# in bytes
def
determine_available_memory
(
self
)
->
list
[
int
]
:
# in bytes
# same as determine_num_available_blocks in v0,
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
# we need to get the min across all ranks.
memory
=
super
().
determine_available_memory
()
memory
=
super
().
determine_available_memory
()
...
@@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
...
@@ -103,4 +100,4 @@ class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
cpu_group
=
get_world_group
().
cpu_group
cpu_group
=
get_world_group
().
cpu_group
memory_tensor
=
torch
.
tensor
([
memory
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
memory_tensor
=
torch
.
tensor
([
memory
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
dist
.
all_reduce
(
memory_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
dist
.
all_reduce
(
memory_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
return
memory_tensor
.
item
()
return
[
memory_tensor
.
item
()
]
vllm/v1/executor/multiproc_executor.py
View file @
53076d70
...
@@ -5,6 +5,7 @@ import pickle
...
@@ -5,6 +5,7 @@ import pickle
import
signal
import
signal
import
sys
import
sys
import
time
import
time
import
traceback
import
weakref
import
weakref
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
...
@@ -370,6 +371,9 @@ class WorkerProc:
...
@@ -370,6 +371,9 @@ class WorkerProc:
func
=
partial
(
cloudpickle
.
loads
(
method
),
self
.
worker
)
func
=
partial
(
cloudpickle
.
loads
(
method
),
self
.
worker
)
output
=
func
(
*
args
,
**
kwargs
)
output
=
func
(
*
args
,
**
kwargs
)
except
Exception
as
e
:
except
Exception
as
e
:
# Notes have been introduced in python 3.11
if
hasattr
(
e
,
"add_note"
):
e
.
add_note
(
traceback
.
format_exc
())
self
.
worker_response_mq
.
enqueue
(
self
.
worker_response_mq
.
enqueue
(
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
e
))
(
WorkerProc
.
ResponseStatus
.
FAILURE
,
e
))
logger
.
exception
(
"WorkerProc hit an exception: %s"
,
exc_info
=
e
)
logger
.
exception
(
"WorkerProc hit an exception: %s"
,
exc_info
=
e
)
...
...
vllm/v1/kv_cache_interface.py
View file @
53076d70
...
@@ -11,7 +11,7 @@ logger = init_logger(__name__)
...
@@ -11,7 +11,7 @@ logger = init_logger(__name__)
@
dataclass
@
dataclass
class
KVCacheSpec
Base
:
class
KVCacheSpec
:
"""
"""
A base class for specifying the KV cache format of one layer.
A base class for specifying the KV cache format of one layer.
"""
"""
...
@@ -55,7 +55,7 @@ class KVCacheSpecBase:
...
@@ -55,7 +55,7 @@ class KVCacheSpecBase:
@
dataclass
@
dataclass
class
FullAttentionSpec
(
KVCacheSpec
Base
):
class
FullAttentionSpec
(
KVCacheSpec
):
num_kv_heads
:
int
num_kv_heads
:
int
head_size
:
int
head_size
:
int
dtype
:
torch
.
dtype
dtype
:
torch
.
dtype
...
@@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase):
...
@@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase):
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
KVCacheSpec
=
dict
[
str
,
KVCacheSpecBase
]
@
dataclass
@
dataclass
class
KVCacheTensor
:
class
KVCacheTensor
:
"""
"""
...
@@ -89,6 +86,18 @@ class KVCacheTensor:
...
@@ -89,6 +86,18 @@ class KVCacheTensor:
size
:
int
# The size of KV cache Tensor in bytes
size
:
int
# The size of KV cache Tensor in bytes
@dataclass
class KVCacheGroupSpec:
    """
    A set of model layers that use one common KV cache block table.

    The KV cache manager treats all layers of a group as a single layer.
    """
    # Names of the model layers that belong to this group.
    layer_names: list[str]
    # KV cache spec shared by the layers of this group.
    kv_cache_spec: KVCacheSpec
@
dataclass
@
dataclass
class
KVCacheConfig
:
class
KVCacheConfig
:
"""
"""
...
@@ -99,17 +108,24 @@ class KVCacheConfig:
...
@@ -99,17 +108,24 @@ class KVCacheConfig:
"""layer_name -> how to initialize KV cache for that layer"""
"""layer_name -> how to initialize KV cache for that layer"""
tensors
:
dict
[
str
,
KVCacheTensor
]
tensors
:
dict
[
str
,
KVCacheTensor
]
"""
"""
A list of kv-cache groups. Each group includes a set of layers with
The kv cache groups of the model.
the same kv-cache spec, and the total page_size of layers inside a group
The layers in the models are repeated with some patterns, e.g., a model
is same across all groups (as the KVCacheManager only supports allocating
with 10 full attention layers and 20 sliding window attention layers can be
pages of the same size). For example:
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
1. A model only uses full attention: one group with all layers in the model.
The KVCacheManager allocates different block tables for each of the 3 layers
2. (not implemented yet) A model with the same number of full attention
in the pattern, and repeats each of them 10 times to generate the
layers and sliding window attention layers: two groups, one for full
block_table for the 30 layers in the model.
attention layers and one for sliding window attention layers.
Therefore, we can group the layers in the model into 3 groups, each of which
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
contains 10 layers in the model.
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers.
2. (WIP) A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 groups, each of which represents 10 layers in the model.
"""
"""
groups
:
list
[
list
[
str
]]
kv_cache_groups
:
list
[
KVCacheGroupSpec
]
"""the KVCacheSpec of the model"""
kv_cache_spec
:
KVCacheSpec
vllm/v1/metrics/stats.py
View file @
53076d70
...
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
...
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.output_processor
import
RequestState
from
vllm.v1.
engine.
output_processor
import
RequestState
@
dataclass
@
dataclass
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
53076d70
...
@@ -65,6 +65,15 @@ class TopKTopPSampler(nn.Module):
...
@@ -65,6 +65,15 @@ class TopKTopPSampler(nn.Module):
"native implementation of top-p & top-k sampling. For the "
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer."
)
"best performance, please install FlashInfer."
)
self
.
forward
=
self
.
forward_native
self
.
forward
=
self
.
forward_native
elif
current_platform
.
is_tpu
():
if
envs
.
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
logger
.
warning
(
"TPU-specific optimization for top-k & top-p sampling are "
"disabled, falling back to PyTorch-native implementation "
"which could be very slow."
)
self
.
forward
=
self
.
forward_native
else
:
self
.
forward
=
self
.
forward_tpu
else
:
else
:
self
.
forward
=
self
.
forward_native
self
.
forward
=
self
.
forward_native
...
@@ -96,6 +105,29 @@ class TopKTopPSampler(nn.Module):
...
@@ -96,6 +105,29 @@ class TopKTopPSampler(nn.Module):
return
random_sample
(
probs
,
generators
)
return
random_sample
(
probs
,
generators
)
return
flashinfer_sample
(
probs
,
k
,
p
,
generators
)
return
flashinfer_sample
(
probs
,
k
,
p
,
generators
)
def forward_tpu(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sample token ids from `logits` on the TPU code path.

    When only `k` is given, every logit outside the top-k entries of each
    row is overwritten in place with -inf before sampling. In every other
    case (including when `p` is given) the logits are left untouched; the
    top-p path is still a placeholder.

    Args:
        logits: Logits to sample from; modified in place on the top-k path.
        generators: Passed through to `random_sample`.
        k: Top-k argument forwarded to `torch.topk`, or None.
        p: Top-p argument, or None. Currently only used to decide whether
            the top-k-only fast path applies.

    Returns:
        The result of `random_sample` on the float32 softmax of the logits.
    """
    # If only top-k is specified, use pytorch's builtin topk op. This leads
    # to significant speed up on TPU compared to using apply_top_k_top_p.
    top_k_only = k is not None and p is None
    if top_k_only:
        _, kept_indices = torch.topk(logits, k, dim=-1)

        # Start with "drop everything", then clear the flag for the kept
        # positions so that only non-top-k logits are masked out.
        drop_mask = torch.ones_like(logits, dtype=torch.bool)
        drop_mask.scatter_(-1, kept_indices, False)
        logits.masked_fill_(drop_mask, float('-inf'))
    # TODO Placeholder for TPU optimized topp kernel
    # logits = apply_top_k_top_p(logits, k, p)

    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)
def
apply_top_k_top_p
(
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
...
@@ -112,7 +144,7 @@ def apply_top_k_top_p(
...
@@ -112,7 +144,7 @@ def apply_top_k_top_p(
if
k
is
not
None
:
if
k
is
not
None
:
# Apply top-k.
# Apply top-k.
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
top_k_mask
=
logits_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# shape: B
# Get all the top_k values.
# Get all the top_k values.
top_k_mask
=
logits_sort
.
gather
(
1
,
top_k_mask
.
unsqueeze
(
dim
=
1
))
top_k_mask
=
logits_sort
.
gather
(
1
,
top_k_mask
.
unsqueeze
(
dim
=
1
))
top_k_mask
=
logits_sort
<
top_k_mask
top_k_mask
=
logits_sort
<
top_k_mask
...
...
vllm/v1/sample/tpu/__init__.py
0 → 100644
View file @
53076d70
vllm/v1/sample/tpu/metadata.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
import
torch
import
torch_xla.core.xla_model
as
xm
from
vllm.v1.sample.metadata
import
SamplingMetadata
@dataclass
class TPUSupportedSamplingMetadata:
    # An xla-friendlier counterpart of SamplingMetadata for TPU: arguments
    # are meant to be traceable and optionals are avoided, so that None
    # values do not trigger graph recompilation.
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False
    # Greedy sampling flag for compiling single xla graph.
    do_argmax: torch.Tensor = None

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int, torch.Generator] = field(default_factory=dict)

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    indices_do_sample: torch.Tensor = None

    def __post_init__(self):
        """Fill in tensor defaults that depend on `temperature`."""
        reference = self.temperature
        if self.indices_do_sample is None:
            # One int32 zero per entry of `temperature`, on the same device.
            self.indices_do_sample = torch.zeros(reference.shape[0],
                                                 device=reference.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            # Scalar boolean False on the same device as `temperature`.
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=reference.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: SamplingMetadata,
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values in input `metadata`. Do that only for non-None values so that
        recompilation is not triggered for optional values (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
            do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        padded_size = len(padded_do_sample_indices)
        greedy_flag = torch.tensor(metadata.all_greedy,
                                   dtype=torch.bool,
                                   device=device)
        new_metadata = cls.get_default_sampling_params(
            padded_size,
            device,
            indices_do_sample=padded_do_sample_indices,
            do_argmax=greedy_flag)
        supported_params = (
            TPUSupportedSamplingMetadata._get_default_params_values())
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        for p_name in supported_params:
            source_val = getattr(metadata, p_name)
            target_val = getattr(new_metadata, p_name)
            if isinstance(source_val, torch.Tensor):
                # Only the first `num_do_sample` slots carry real values; the
                # remainder keeps the default fill.
                target_val[:num_do_sample] = source_val
            setattr(new_metadata, p_name, target_val)

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        """Build an instance whose supported params hold default values.

        Each supported parameter becomes a 1-D tensor of length `num_samples`
        on `device`, filled with the value and dtype given by
        `_get_default_params_values`.
        """
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        disable_values = (
            TPUSupportedSamplingMetadata._get_default_params_values())
        init_kwargs = {
            p_name: torch.full((num_samples, ),
                               default_val,
                               dtype=dtype,
                               device=device)
            for p_name, (default_val, dtype) in disable_values.items()
        }
        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
        """Check the input metadata and return it unchanged."""
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
        """Map each supported param name to its (default value, dtype)."""
        return {
            # Since #13587 greedy sampling requires branching off which leads
            # to separate graphs. We set temp to noop and handle argmax here.
            "temperature": (1.0, torch.float32),
            "min_p": (0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        }
\ No newline at end of file
vllm/v1/sample/tpu/sampler.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
"""Sampler layer implementing TPU supported operations."""
import
torch
import
torch.nn
as
nn
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
_SAMPLING_EPS
=
1e-5
class Sampler(nn.Module):
    """Sampler module built from operations used on the TPU path."""

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per row of `logits`.

        Returns a SamplerOutput whose `sampled_token_ids` is an int32 tensor
        with a trailing dimension of size 1 and whose `logprobs_tensors` is
        None.
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).

        # Sampling is done on float32 logits.
        fp32_logits = logits.to(torch.float32)
        # Sample the next token, then shrink the result to int32.
        token_ids = self.sample(fp32_logits, sampling_metadata).to(torch.int32)

        # The sampled tokens are expanded to 2D tensor with shape
        # [num_requests, 1], where each row represents one generated
        # token per request.
        return SamplerOutput(
            sampled_token_ids=token_ids.unsqueeze(-1),
            logprobs_tensors=None,
        )

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide `logits` in place by the per-row temperature `temp`."""
        # In-place division avoids allocating a new tensor.
        per_row_temp = temp.unsqueeze(dim=1)
        return logits.div_(per_row_temp)

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the flattened argmax over the last dimension."""
        best = logits.argmax(dim=-1)
        return best.view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """Pick greedy or random samples per row based on temperature.

        Both the greedy and the random candidates are always computed; rows
        whose temperature is below `_SAMPLING_EPS` take the greedy one.
        """
        # Greedy candidates must be taken before the in-place temperature
        # scaling below touches `logits`.
        greedy_sampled = self.greedy_sample(logits)

        temperature = sampling_metadata.temperature
        assert temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, temperature)

        # Apply min_p.
        min_p = sampling_metadata.min_p
        if min_p is not None:
            logits = self.apply_min_p(logits, min_p)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        use_greedy = sampling_metadata.temperature < _SAMPLING_EPS
        return torch.where(use_greedy, greedy_sampled, random_sampled)

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        # Find the topK values.
        top_values, top_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Logprob of the prompt or sampled token itself.
        token_column = token_ids.unsqueeze(-1)
        chosen_logprobs = logprobs.gather(-1, token_column)

        # Rank of the actual token: number of entries with a logprob at
        # least as large as the token's own.
        token_ranks = (logprobs >= chosen_logprobs).sum(-1)

        # Put the chosen token first, followed by the top-k entries.
        # Use int32 to reduce the tensor size.
        all_indices = torch.cat((token_column, top_indices),
                                dim=1).to(torch.int32)
        all_logprobs = torch.cat((chosen_logprobs, top_values), dim=1)

        return LogprobsTensors(all_indices, all_logprobs, token_ranks)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.

        Entries whose probability is below `min_p` times the row's maximum
        probability are set to -inf in place; `logits` is returned.
        """
        # Convert logits to probability distribution
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Largest probability of each row, kept as a column for broadcasting
        row_max = torch.amax(probs, dim=-1, keepdim=True)
        # Per-row threshold: min_p scaled by the row maximum
        threshold = min_p.unsqueeze(1) * row_max
        # Tokens at or above the threshold survive
        keep_mask = probs >= threshold
        # Apply mask using boolean indexing (xla friendly)
        logits.masked_fill_(~keep_mask, -float("inf"))
        return logits
vllm/v1/spec_decode/ngram_proposer.py
View file @
53076d70
...
@@ -10,7 +10,8 @@ class NgramProposer:
...
@@ -10,7 +10,8 @@ class NgramProposer:
def
propose
(
def
propose
(
self
,
self
,
context_token_ids
:
np
.
ndarray
,
context_token_ids
:
np
.
ndarray
,
n
:
int
,
min_n
:
int
,
max_n
:
int
,
k
:
int
,
k
:
int
,
)
->
Optional
[
np
.
ndarray
]:
)
->
Optional
[
np
.
ndarray
]:
"""Proposes the next sequence of tokens based on n-gram pattern
"""Proposes the next sequence of tokens based on n-gram pattern
...
@@ -21,7 +22,8 @@ class NgramProposer:
...
@@ -21,7 +22,8 @@ class NgramProposer:
Args:
Args:
context_token_ids: Numpy array of token IDs representing the
context_token_ids: Numpy array of token IDs representing the
context sequence.
context sequence.
n: Length of the n-gram to match.
min_n: Minimum length of the n-gram to match.
max_n: Maximum length of the n-gram to match.
k: Number of tokens follow the match. If there are less
k: Number of tokens follow the match. If there are less
than k tokens follow the match, we will return
than k tokens follow the match, we will return
the maximum amount of tokens until the end.
the maximum amount of tokens until the end.
...
@@ -32,14 +34,21 @@ class NgramProposer:
...
@@ -32,14 +34,21 @@ class NgramProposer:
None: If no matching n-gram pattern is found.
None: If no matching n-gram pattern is found.
Example:
Example:
If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
k = 4:
- The last 3 (= max_n) tokens [4,2,3] cannot find a match.
- The last 2 tokens [2,3] will be matched against the previous
- The last 2 tokens [2,3] will be matched against the previous
4 tokens [1,2,3,4].
4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that
- Finding a match of [2,3] would return the tokens that
followed that pattern. Here we will return [4,2,3] because
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
we only have three tokens after the match.
"""
"""
return
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
# TODO(woosuk): Optimize this.
for
n
in
range
(
max_n
,
min_n
-
1
,
-
1
):
result
=
_find_subarray_kmp
(
context_token_ids
,
n
,
k
)
if
result
is
not
None
:
return
result
return
None
@
jit
(
nopython
=
True
)
@
jit
(
nopython
=
True
)
...
...
vllm/v1/structured_output/__init__.py
View file @
53076d70
...
@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
...
@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.structured_output.backend_types
import
(
StructuredOutputBackend
,
from
vllm.v1.structured_output.backend_types
import
(
StructuredOutputBackend
,
StructuredOutputGrammar
)
StructuredOutputGrammar
)
from
vllm.v1.structured_output.backend_xgrammar
import
XgrammarBackend
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
numpy
as
np
import
numpy
as
np
...
@@ -47,6 +46,9 @@ class StructuredOutputManager:
...
@@ -47,6 +46,9 @@ class StructuredOutputManager:
if
self
.
backend
is
None
:
if
self
.
backend
is
None
:
backend_name
=
request
.
sampling_params
.
guided_decoding
.
backend_name
backend_name
=
request
.
sampling_params
.
guided_decoding
.
backend_name
if
backend_name
==
"xgrammar"
:
if
backend_name
==
"xgrammar"
:
from
vllm.v1.structured_output.backend_xgrammar
import
(
XgrammarBackend
)
self
.
backend
=
XgrammarBackend
(
self
.
vllm_config
)
self
.
backend
=
XgrammarBackend
(
self
.
vllm_config
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
...
vllm/v1/structured_output/backend_xgrammar.py
View file @
53076d70
...
@@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend):
...
@@ -26,6 +26,9 @@ class XgrammarBackend(StructuredOutputBackend):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
disable_any_whitespace
=
(
"disable-any-whitespace"
in
vllm_config
.
decoding_config
.
guided_decoding_backend
)
tokenizer_group
=
init_tokenizer_from_configs
(
tokenizer_group
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
...
@@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend):
...
@@ -74,8 +77,8 @@ class XgrammarBackend(StructuredOutputBackend):
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
if
request_type
==
StructuredOutputOptions
.
JSON
:
if
request_type
==
StructuredOutputOptions
.
JSON
:
ctx
=
self
.
compiler
.
compile_json_schema
(
grammar_spec
,
ctx
=
self
.
compiler
.
compile_json_schema
(
any_whitespace
=
Fals
e
)
grammar_spec
,
any_whitespace
=
not
self
.
disable_any_whitespac
e
)
elif
request_type
==
StructuredOutputOptions
.
JSON_OBJECT
:
elif
request_type
==
StructuredOutputOptions
.
JSON_OBJECT
:
ctx
=
self
.
compiler
.
compile_builtin_json_grammar
()
ctx
=
self
.
compiler
.
compile_builtin_json_grammar
()
elif
request_type
==
StructuredOutputOptions
.
GRAMMAR
:
elif
request_type
==
StructuredOutputOptions
.
GRAMMAR
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
53076d70
...
@@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
...
@@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
import
xgrammar
as
xgr
from
vllm.v1.core.sched
uler_
output
import
SchedulerOutput
from
vllm.v1.core.sched
.
output
import
SchedulerOutput
else
:
else
:
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
...
@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
))
weakref
.
proxy
(
self
))
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
# Multi-modal data support
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
input_registry
=
INPUT_REGISTRY
...
@@ -150,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -150,8 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
use_spec_decode
=
False
self
.
use_spec_decode
=
False
if
self
.
speculative_config
:
if
self
.
speculative_config
:
self
.
use_spec_decode
=
True
self
.
use_spec_decode
=
True
# TODO: find a better way to check if we are using ngram.
assert
self
.
speculative_config
.
method
==
"ngram"
,
\
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
"Currently, only ngram spec decode is supported in V1."
"Currently, only ngram spec decode is supported in V1."
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
drafter
=
NgramProposer
()
self
.
drafter
=
NgramProposer
()
...
@@ -159,7 +159,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -159,7 +159,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# This usually takes less than 1 second.
# This usually takes less than 1 second.
self
.
drafter
.
propose
(
self
.
drafter
.
propose
(
np
.
zeros
(
1024
,
dtype
=
np
.
int32
),
np
.
zeros
(
1024
,
dtype
=
np
.
int32
),
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
prompt_lookup_min
,
self
.
speculative_config
.
prompt_lookup_max
,
self
.
speculative_config
.
num_speculative_tokens
,
self
.
speculative_config
.
num_speculative_tokens
,
)
)
self
.
rejection_sampler
=
RejectionSampler
()
self
.
rejection_sampler
=
RejectionSampler
()
...
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
non_blocking
=
True
)
# Prepare for cascade attention if needed.
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
0
num_scheduled_tokens
,
if
self
.
cascade_attn_enabled
:
scheduler_output
.
num_common_prefix_blocks
,
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
)
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
,
)
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
...
@@ -1151,7 +1155,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1151,7 +1155,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_ids
drafter_output
=
self
.
drafter
.
propose
(
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
prompt_lookup_min
,
self
.
speculative_config
.
prompt_lookup_max
,
self
.
speculative_config
.
num_speculative_tokens
,
self
.
speculative_config
.
num_speculative_tokens
,
)
)
if
drafter_output
is
None
or
len
(
drafter_output
)
==
0
:
if
drafter_output
is
None
or
len
(
drafter_output
)
==
0
:
...
@@ -1506,34 +1511,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1506,34 +1511,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
cache size of each layer
"""
"""
if
len
(
kv_cache_config
.
groups
)
>
1
:
if
len
(
kv_cache_config
.
kv_cache_
groups
)
>
1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
"supported yet."
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
assert
tensor_config
.
size
%
layer_spec
.
page_size_bytes
==
0
for
layer_name
in
kv_cache_group
.
layer_names
:
num_blocks
=
tensor_config
.
size
//
layer_spec
.
page_size_bytes
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
if
isinstance
(
layer_spec
,
FullAttentionSpec
):
assert
tensor_config
.
size
%
kv_cache_spec
.
page_size_bytes
==
0
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
tensor_config
.
size
//
kv_cache_spec
.
page_size_bytes
num_blocks
,
layer_spec
.
block_size
,
layer_spec
.
num_kv_heads
,
# `num_blocks` is the number of blocks the model runner can use.
layer_spec
.
head_size
)
# `kv_cache_config.num_blocks` is the number of blocks that
dtype
=
layer_spec
.
dtype
# KVCacheManager may allocate.
kv_caches
[
layer_name
]
=
torch
.
zeros
(
kv_cache_shape
,
# Since different GPUs may have different number of layers and
dtype
=
dtype
,
# different memory capacities, `num_blocks` can be different on
device
=
self
.
device
)
# different GPUs, and `kv_cache_config.num_blocks` is set to
else
:
# the min of all `num_blocks`. Verify it here.
raise
NotImplementedError
assert
num_blocks
>=
kv_cache_config
.
num_blocks
if
isinstance
(
kv_cache_spec
,
FullAttentionSpec
):
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
kv_caches
[
layer_name
]
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
else
:
# TODO: add new branches when introducing more types of
# KV cache specs.
raise
ValueError
(
"Unknown KV cache spec type."
)
bind_kv_cache
(
bind_kv_cache
(
kv_caches
,
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
self
.
kv_caches
)
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
"""
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Attention module in the static forward context.
...
@@ -1545,7 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1545,7 +1562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
KVCacheSpec
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
if
isinstance
(
attn_module
,
FusedMoE
):
if
isinstance
(
attn_module
,
FusedMoE
):
continue
continue
...
@@ -1558,7 +1575,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1558,7 +1575,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_size
=
block_size
,
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
dtype
=
self
.
kv_cache_
dtype
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
):
AttentionType
.
ENCODER_ONLY
):
...
...
vllm/v1/worker/gpu_worker.py
View file @
53076d70
...
@@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase
...
@@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched
uler_
output
import
SchedulerOutput
from
vllm.v1.core.sched
.
output
import
SchedulerOutput
class
Worker
(
WorkerBase
):
class
Worker
(
WorkerBase
):
...
@@ -185,7 +185,7 @@ class Worker(WorkerBase):
...
@@ -185,7 +185,7 @@ class Worker(WorkerBase):
return
int
(
available_kv_cache_memory
)
return
int
(
available_kv_cache_memory
)
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
return
self
.
model_runner
.
get_kv_cache_spec
()
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
53076d70
...
@@ -11,6 +11,7 @@ import torch.nn as nn
...
@@ -11,6 +11,7 @@ import torch.nn as nn
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -23,18 +24,21 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
...
@@ -23,18 +24,21 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
from
vllm.v1.attention.backends.pallas
import
(
NUM_KV_PAGES_PER_BLOCK
,
PallasAttentionBackend
,
PallasMetadata
)
PallasMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
ModelRunnerOutput
,
SamplerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
from
vllm.v1.sample.tpu.sampler
import
Sampler
as
TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched
uler
import
SchedulerOutput
from
vllm.v1.core.sched
.output
import
SchedulerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -42,6 +46,8 @@ logger = init_logger(__name__)
...
@@ -42,6 +46,8 @@ logger = init_logger(__name__)
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID
=
1_000_000_000
_PAD_SLOT_ID
=
1_000_000_000
INVALID_TOKEN_ID
=
-
1
INVALID_TOKEN_ID
=
-
1
# Smallest output size
MIN_NUM_SEQS
=
8
class
TPUModelRunner
:
class
TPUModelRunner
:
...
@@ -68,6 +74,10 @@ class TPUModelRunner:
...
@@ -68,6 +74,10 @@ class TPUModelRunner:
scheduler_config
=
self
.
scheduler_config
scheduler_config
=
self
.
scheduler_config
parallel_config
=
self
.
parallel_config
parallel_config
=
self
.
parallel_config
self
.
device
=
device
self
.
device
=
device
self
.
check_recompilation
=
envs
.
VLLM_XLA_CHECK_RECOMPILATION
if
self
.
check_recompilation
:
self
.
num_xla_graphs
=
xr
.
get_num_cached_compilation_graph
()
self
.
enforce_eager
=
model_config
.
enforce_eager
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
dtype
=
self
.
model_config
.
dtype
...
@@ -138,8 +148,10 @@ class TPUModelRunner:
...
@@ -138,8 +148,10 @@ class TPUModelRunner:
device
=
"cpu"
)
device
=
"cpu"
)
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
padded_max_num_blocks_per_req
=
_get_padded_number
(
self
.
max_num_blocks_per_req
,
NUM_KV_PAGES_PER_BLOCK
)
self
.
block_table_cpu
=
torch
.
zeros
(
self
.
block_table_cpu
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
max_num_blocks_per_req
),
(
self
.
max_num_tokens
,
padded_
max_num_blocks_per_req
),
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
device
=
"cpu"
)
device
=
"cpu"
)
...
@@ -267,6 +279,9 @@ class TPUModelRunner:
...
@@ -267,6 +279,9 @@ class TPUModelRunner:
req_data
.
num_computed_tokens
)
req_data
.
num_computed_tokens
)
self
.
input_batch
.
block_table
.
append_row
(
req_data
.
new_block_ids
,
self
.
input_batch
.
block_table
.
append_row
(
req_data
.
new_block_ids
,
req_index
)
req_index
)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed
=
len
(
removed_req_indices
)
>
0
or
len
(
req_ids_to_add
)
>
0
# Add the new or resumed requests to the persistent batch.
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
# The smaller empty indices are filled first.
...
@@ -284,13 +299,17 @@ class TPUModelRunner:
...
@@ -284,13 +299,17 @@ class TPUModelRunner:
# Condense the batched states if there are empty indices.
# Condense the batched states if there are empty indices.
if
removed_req_indices
:
if
removed_req_indices
:
self
.
input_batch
.
condense
(
removed_req_indices
)
self
.
input_batch
.
condense
(
removed_req_indices
)
# TODO This slices tensors to copy to device, triggering recompilation.
if
batch_changed
:
self
.
input_batch
.
refresh_sampling_metadata
()
return
len
(
unscheduled_req_ids
)
>
0
or
len
(
req_ids_to_add
)
>
0
return
len
(
unscheduled_req_ids
)
>
0
or
len
(
req_ids_to_add
)
>
0
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
assert
self
.
model
is
not
None
assert
self
.
model
is
not
None
return
self
.
model
return
self
.
model
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
"""
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Attention module in the static forward context.
...
@@ -301,7 +320,7 @@ class TPUModelRunner:
...
@@ -301,7 +320,7 @@ class TPUModelRunner:
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
forward_ctx
=
self
.
vllm_config
.
compilation_config
.
static_forward_context
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
KVCacheSpec
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
# TODO: Support other attention modules, e.g., sliding window,
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
# cross-attention, MLA.
...
@@ -447,6 +466,8 @@ class TPUModelRunner:
...
@@ -447,6 +466,8 @@ class TPUModelRunner:
# TODO: Support prompt logprobs.
# TODO: Support prompt logprobs.
padded_num_reqs
=
_get_padded_num_reqs_with_upper_limit
(
padded_num_reqs
=
_get_padded_num_reqs_with_upper_limit
(
num_reqs
,
self
.
max_num_reqs
)
num_reqs
,
self
.
max_num_reqs
)
# Indices at which we sample (positions of last token in the sequence).
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices
=
self
.
query_start_loc_cpu
[
1
:
padded_num_reqs
+
1
]
-
1
logits_indices
=
self
.
query_start_loc_cpu
[
1
:
padded_num_reqs
+
1
]
-
1
logits_indices
=
logits_indices
.
to
(
self
.
device
)
logits_indices
=
logits_indices
.
to
(
self
.
device
)
return
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
...
@@ -576,7 +597,14 @@ class TPUModelRunner:
...
@@ -576,7 +597,14 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
input_ids
=
self
.
input_ids
inputs_embeds
=
None
inputs_embeds
=
None
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
num_reqs
=
self
.
input_batch
.
num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# mismatch in pre-processing, it will trigger a small recompilation
# of the code thus far. Forward graph remains untouched.
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_sampling_metadata
(
sampling_metadata
,
logits_indices
,
num_reqs
,
self
.
device
)
# Run the decoder
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
...
@@ -585,12 +613,13 @@ class TPUModelRunner:
...
@@ -585,12 +613,13 @@ class TPUModelRunner:
kv_caches
=
self
.
kv_caches
,
kv_caches
=
self
.
kv_caches
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
num_reqs
=
self
.
input_batch
.
num_reqs
selected_token_ids
=
self
.
model
.
sample_from_hidden
(
selected_token_ids
=
self
.
model
.
compute_logits
(
hidden_states
,
hidden_states
,
tpu_sampling_metadata
)
logits_indices
,
None
)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
# Then, let's update the cache state.
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens
:
list
[
tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
request_seq_lens
:
list
[
tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
assert
req_id
is
not
None
...
@@ -607,7 +636,6 @@ class TPUModelRunner:
...
@@ -607,7 +636,6 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details
# This relies on cuda-specific torch-internal impl details
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# num_reqs entries should be non-None
assert
all
(
assert
all
(
req_id
is
not
None
for
req_id
in
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
...
@@ -620,6 +648,7 @@ class TPUModelRunner:
...
@@ -620,6 +648,7 @@ class TPUModelRunner:
max_gen_len
=
selected_token_ids
.
shape
[
-
1
]
max_gen_len
=
selected_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
selected_token_ids
.
tolist
()
valid_sampled_token_ids
=
selected_token_ids
.
tolist
()
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
token_id
=
valid_sampled_token_ids
[
i
][
0
]
token_id
=
valid_sampled_token_ids
[
i
][
0
]
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
...
@@ -647,6 +676,12 @@ class TPUModelRunner:
...
@@ -647,6 +676,12 @@ class TPUModelRunner:
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
)
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up.
if
self
.
check_recompilation
and
not
self
.
enforce_eager
:
curr_cached_graph
=
xr
.
get_num_cached_compilation_graph
()
assert
self
.
num_xla_graphs
==
curr_cached_graph
,
(
"Recompilation after warm up is detected."
)
return
model_runner_output
return
model_runner_output
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
...
@@ -676,11 +711,8 @@ class TPUModelRunner:
...
@@ -676,11 +711,8 @@ class TPUModelRunner:
fullgraph
=
True
,
fullgraph
=
True
,
dynamic
=
False
)
dynamic
=
False
)
def
_dummy_run
(
@
torch
.
no_grad
()
self
,
def
_dummy_run
(
self
,
kv_caches
,
num_tokens
:
int
)
->
None
:
kv_caches
,
num_tokens
:
int
,
)
->
None
:
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
input_ids
=
None
input_ids
=
None
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
...
@@ -729,32 +761,10 @@ class TPUModelRunner:
...
@@ -729,32 +761,10 @@ class TPUModelRunner:
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
assert
self
.
model
is
not
None
self
.
model
(
input_ids
=
input_ids
,
hidden_states
=
self
.
model
(
positions
=
position_ids
,
input_ids
=
input_ids
,
kv_caches
=
kv_caches
,
positions
=
position_ids
,
inputs_embeds
=
inputs_embeds
)
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
,
)
num_reqs
=
_get_padded_num_reqs_with_upper_limit
(
64
,
self
.
max_num_reqs
)
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
# compilation cache size of token_bucket_num multiplied by
# req_bucket_num. This is acceptable, given the graph's relatively
# small size.
while
True
:
logits_indices
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
torch
.
_dynamo
.
mark_dynamic
(
hidden_states
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
logits_indices
,
0
)
self
.
model
.
compute_logits
(
hidden_states
,
logits_indices
,
None
)
if
num_reqs
>=
self
.
max_num_reqs
:
break
num_reqs
=
_get_padded_num_reqs_with_upper_limit
(
num_reqs
+
1
,
self
.
max_num_reqs
)
def
capture_model
(
self
)
->
None
:
def
capture_model
(
self
)
->
None
:
"""Compile the model."""
"""Compile the model."""
...
@@ -764,16 +774,62 @@ class TPUModelRunner:
...
@@ -764,16 +774,62 @@ class TPUModelRunner:
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
num_tokens
=
16
num_tokens
=
16
while
True
:
while
True
:
self
.
_dummy_run
(
self
.
kv_caches
,
num_tokens
)
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
self
.
_dummy_run
(
self
.
kv_caches
,
num_tokens
)
xm
.
mark_step
()
xm
.
mark_step
()
xm
.
wait_device_ops
()
if
num_tokens
>=
self
.
max_num_tokens
:
if
num_tokens
>=
self
.
max_num_tokens
:
break
break
num_tokens
*=
2
num_tokens
*=
2
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compiling sampling with different input shapes."
)
start
=
time
.
perf_counter
()
num_tokens
=
16
hsize
=
self
.
model_config
.
get_hidden_size
()
device
=
self
.
device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
while
True
:
num_reqs_to_sample
=
MIN_NUM_SEQS
dummy_hidden
=
torch
.
randn
((
num_tokens
,
hsize
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
while
True
:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta
=
self
.
input_batch
.
sampling_metadata
indices
=
torch
.
zeros
(
num_reqs_to_sample
,
dtype
=
torch
.
int32
,
device
=
device
,
)
sampling_meta
=
TPUSupportedSamplingMetadata
.
\
from_sampling_metadata
(
meta
,
indices
,
num_reqs_to_sample
,
device
)
logger
.
info
(
" -- num_tokens: %d, num_seqs: %d"
,
num_tokens
,
num_reqs_to_sample
)
self
.
model
.
sample_from_hidden
(
dummy_hidden
,
sampling_meta
)
xm
.
mark_step
()
if
num_reqs_to_sample
>=
self
.
max_num_reqs
:
break
num_reqs_to_sample
*=
2
if
num_tokens
>=
self
.
max_num_tokens
:
break
num_tokens
*=
2
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
# Record the number cached XLA graph after warming up, this will be
# used for checking there is no additional graph compilation during
# runtime execution.
if
self
.
check_recompilation
:
total_cached_graphs
=
xr
.
get_num_cached_compilation_graph
()
num_compiled_graphs
=
total_cached_graphs
-
self
.
num_xla_graphs
logger
.
info
(
"Compiled %d XLA graphs."
,
num_compiled_graphs
)
self
.
num_xla_graphs
+=
num_compiled_graphs
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
"""
Initialize KV cache based on `kv_cache_config`.
Initialize KV cache based on `kv_cache_config`.
...
@@ -781,31 +837,33 @@ class TPUModelRunner:
...
@@ -781,31 +837,33 @@ class TPUModelRunner:
kv_cache_config: Configuration for the KV cache, including the KV
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
cache size of each layer
"""
"""
if
len
(
kv_cache_config
.
groups
)
>
1
:
if
len
(
kv_cache_config
.
kv_cache_
groups
)
>
1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
"supported yet."
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
layer_name
,
layer_spec
in
kv_cache_config
.
kv_cache_spec
.
items
():
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
assert
tensor_config
.
size
%
layer_spec
.
page_size_bytes
==
0
for
layer_name
in
kv_cache_group
.
layer_names
:
num_blocks
=
tensor_config
.
size
//
layer_spec
.
page_size_bytes
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
if
isinstance
(
layer_spec
,
FullAttentionSpec
):
assert
tensor_config
.
size
%
kv_cache_spec
.
page_size_bytes
==
0
kv_cache_shape
=
PallasAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
tensor_config
.
size
//
kv_cache_spec
.
page_size_bytes
num_blocks
,
layer_spec
.
block_size
,
layer_spec
.
num_kv_heads
,
if
isinstance
(
kv_cache_spec
,
FullAttentionSpec
):
layer_spec
.
head_size
)
kv_cache_shape
=
PallasAttentionBackend
.
get_kv_cache_shape
(
dtype
=
layer_spec
.
dtype
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
tpu_k_cache
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
dtype
=
dtype
,
device
=
self
.
device
)
tpu_k_cache
=
torch
.
zeros
(
kv_cache_shape
,
tpu_v_cache
=
torch
.
zeros_like
(
tpu_k_cache
)
dtype
=
dtype
,
device
=
self
.
device
)
kv_caches
[
layer_name
]
=
(
tpu_k_cache
,
tpu_v_cache
)
tpu_v_cache
=
torch
.
zeros_like
(
tpu_k_cache
)
else
:
raise
NotImplementedError
kv_caches
[
layer_name
]
=
(
tpu_k_cache
,
tpu_v_cache
)
else
:
raise
NotImplementedError
bind_kv_cache
(
bind_kv_cache
(
kv_caches
,
kv_caches
,
...
@@ -818,6 +876,13 @@ class ModelWrapperV1(nn.Module):
...
@@ -818,6 +876,13 @@ class ModelWrapperV1(nn.Module):
def
__init__
(
self
,
model
:
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
super
().
__init__
()
self
.
model
=
model
self
.
model
=
model
self
.
sampler
=
TPUSampler
()
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
)
->
SamplerOutput
:
sampler_out
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
sampler_out
def
forward
(
def
forward
(
self
,
self
,
...
@@ -826,7 +891,7 @@ class ModelWrapperV1(nn.Module):
...
@@ -826,7 +891,7 @@ class ModelWrapperV1(nn.Module):
kv_caches
:
list
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
list
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model
and samples the next token
.
"""Executes the forward pass of the model.
Args:
Args:
input_ids: The input token IDs of shape [num_tokens].
input_ids: The input token IDs of shape [num_tokens].
...
@@ -837,7 +902,6 @@ class ModelWrapperV1(nn.Module):
...
@@ -837,7 +902,6 @@ class ModelWrapperV1(nn.Module):
hidden_size]. It is used for multimodal models.
hidden_size]. It is used for multimodal models.
"""
"""
assert
self
.
model
is
not
None
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
...
@@ -846,17 +910,33 @@ class ModelWrapperV1(nn.Module):
...
@@ -846,17 +910,33 @@ class ModelWrapperV1(nn.Module):
return
hidden_states
return
hidden_states
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
sample_from_hidden
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
logits_indices
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
,
sampling_metadata
,
)
->
torch
.
Tensor
:
)
->
Optional
[
torch
.
Tensor
]:
"""
hidden_states
=
hidden_states
[
logits_indices
]
Sample with xla-friendly function. This function is to be traced
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
separately from `forward` for lighter compilation overhead.
selected_token_ids
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
"""
return
selected_token_ids
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states
=
\
hidden_states
[
sampling_metadata
.
indices_do_sample
]
logits
=
self
.
compute_logits
(
sample_hidden_states
)
# Greedy sampling can't be run without branching the graph on Sampler.
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
# NOTE do_argmax is a scalar, this is just an optimized if/else.
out_tokens
=
torch
.
where
(
sampling_metadata
.
do_argmax
,
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
),
self
.
sample
(
logits
,
sampling_metadata
)
\
.
sampled_token_ids
)
return
out_tokens
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
return
logits
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
...
@@ -876,5 +956,5 @@ def _get_padded_token_len(x: int) -> int:
...
@@ -876,5 +956,5 @@ def _get_padded_token_len(x: int) -> int:
def
_get_padded_num_reqs_with_upper_limit
(
x
,
upper_limit
)
->
int
:
def
_get_padded_num_reqs_with_upper_limit
(
x
,
upper_limit
)
->
int
:
res
=
64
if
x
<=
64
else
1
<<
(
x
-
1
).
bit_length
()
res
=
MIN_NUM_SEQS
if
x
<=
MIN_NUM_SEQS
else
1
<<
(
x
-
1
).
bit_length
()
return
min
(
res
,
upper_limit
)
return
min
(
res
,
upper_limit
)
vllm/v1/worker/tpu_worker.py
View file @
53076d70
...
@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched
uler
import
SchedulerOutput
from
vllm.v1.core.sched
.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
)
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
...
@@ -189,7 +189,7 @@ class TPUWorker:
...
@@ -189,7 +189,7 @@ class TPUWorker:
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
return
self
.
model_runner
.
get_model
()
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
return
self
.
model_runner
.
get_kv_cache_spec
()
return
self
.
model_runner
.
get_kv_cache_spec
()
def
initialize_from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_from_config
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/worker_base.py
View file @
53076d70
...
@@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0):
...
@@ -51,7 +51,7 @@ class WorkerBase(WorkerBaseV0):
self
.
device
:
Optional
[
torch
.
device
]
=
None
self
.
device
:
Optional
[
torch
.
device
]
=
None
self
.
model_runner
:
Optional
[
nn
.
Module
]
=
None
self
.
model_runner
:
Optional
[
nn
.
Module
]
=
None
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]
:
"""Get specifications for KV cache implementation."""
"""Get specifications for KV cache implementation."""
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/worker/openvino_model_runner.py
deleted
100644 → 0
View file @
322a0be6
# SPDX-License-Identifier: Apache-2.0
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Tuple
import
openvino
as
ov
import
torch
from
torch
import
nn
from
vllm.attention
import
get_attn_backend
from
vllm.attention.backends.openvino
import
OpenVINOAttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.openvino
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
,
MultiModalPlaceholderMap
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner_base
import
ModelRunnerBase
logger
=
init_logger
(
__name__
)
class ModelInput(NamedTuple):
    """Immutable bundle of the inputs produced for one model step."""

    # Token ids of every scheduled sequence, concatenated.
    input_tokens: torch.Tensor
    # Position of each entry of ``input_tokens`` within its own sequence.
    input_positions: torch.Tensor
    # Attention metadata for the step; ``None`` for the empty input.
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    # Per-sequence total length considered for this step.
    seq_lens: List[int]
    # Per-sequence number of new (query) tokens for this step.
    query_lens: List[int]
    # Batched multi-modal keyword arguments forwarded to the model.
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        """Return an input bundle that contains no sequences.

        Args:
            device: Device on which the zero-length tensors are allocated.

        Returns:
            An instance of ``cls`` with empty tensors, empty length lists,
            no attention metadata and no multi-modal arguments.
        """
        # Build through ``cls`` so the alternate constructor honours
        # subclasses instead of always producing the base type.
        return cls(
            input_tokens=torch.empty(0, device=device),
            input_positions=torch.empty(0, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            multi_modal_kwargs={},
        )
class OpenVINOModelRunner(ModelRunnerBase):
    """Prepares batched inputs and runs the model on the OpenVINO backend."""

    def __init__(
        self,
        ov_core: ov.Core,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Store configuration and select the attention backend.

        Args:
            ov_core: OpenVINO core object later passed to the model loader.
            vllm_config: Engine configuration forwarded to the base class.
            kv_cache_dtype: KV cache data type handed to the attention
                backend selection and to the model loader.
            is_driver_worker: Whether the owning worker is the driver.
            *args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        self.ov_core = ov_core
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = self.model_config.get_sliding_window()
        self.block_size = self.cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model

    def load_model(self) -> None:
        """Load the model through the OpenVINO model loader."""
        self.model = get_model(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.kv_cache_dtype,
            ov_core=self.ov_core,
        )

    def get_model(self) -> nn.Module:
        """Return the loaded model (valid after ``load_model``)."""
        return self.model

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        Raises:
            RuntimeError: If chunked prefill is enabled while a sequence
                group carries computed (prefix-cached) blocks.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []

        # initialize beginning of prefix sums
        subsequence_begins.append(0)
        block_indices_begins.append(0)

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        # Resolve the flag once, tolerating a missing scheduler config in
        # every place it is consulted (the original guarded only the first).
        chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                # Look the block table up only when one exists; without
                # block tables the sequence contributes no block indices
                # instead of failing on ``None[seq_id]``.
                block_tables = seq_group_metadata.block_tables
                block_table = (block_tables[seq_id]
                               if block_tables is not None else [])
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif chunked_prefill_enabled or not is_prompt:
                    # chunked prefill or decode
                    if (block_tables is not None
                            and self.sliding_window is not None):
                        # chunked prefill doesn't support sliding window.
                        assert not chunked_prefill_enabled
                        sliding_window_blocks = (self.sliding_window //
                                                 self.block_size)
                        block_table = block_table[-sliding_window_blocks:]
                # Otherwise: prompt phase w/o prefix_caching, chunked_prefill;
                # the full block table is used unchanged.

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                positions_range = range(computed_len, seq_len)
                input_positions.extend(positions_range)

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

                if seq_group_metadata.multi_modal_data:
                    # NOTE: mm_data only includes the subset of multi-modal
                    # items that intersect with the current prefill positions.
                    mm_data, placeholder_maps = MultiModalPlaceholderMap \
                        .from_seq_group(seq_group_metadata, positions_range)

                    if self.mm_registry.has_processor(self.model_config):
                        mm_kwargs = mm_data
                    else:
                        mm_kwargs = self.multi_modal_input_mapper(
                            mm_data,
                            seq_group_metadata.mm_processor_kwargs,
                        )

                    multi_modal_kwargs_list.append(mm_kwargs)

                    for modality, placeholder_map in placeholder_maps.items():
                        multi_modal_placeholder_maps[modality].extend(
                            placeholder_map, )

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
        subsequence_begins_tensor = torch.tensor(
            subsequence_begins, dtype=torch.int32,
            device=self.device)  # type: ignore
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)  # type: ignore
        block_indices_begins_tensor = torch.tensor(
            block_indices_begins, dtype=torch.int32,
            device=self.device)  # type: ignore

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(
            max_context_len, dtype=torch.int32,
            device=self.device)  # type: ignore

        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
        )

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        """Build model inputs plus the matching sampling metadata.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, multi_modal_kwargs)``.
        """
        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample the next tokens.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for the step.
            kv_caches: Per-layer key/value cache tensors passed to the model.

        Returns:
            The output of the model's ``sample`` call.
        """
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            input_tokens,
            "positions":
            input_positions,
            "kv_caches":
            kv_caches,
            **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        with set_forward_context(attn_metadata, self.vllm_config, 0):
            hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    def prepare_model_input(self, *args, **kwargs):
        """Not supported by this runner; always raises."""
        raise NotImplementedError

    def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
        """Not supported by this runner; always raises."""
        raise NotImplementedError
vllm/worker/openvino_worker.py
deleted
100644 → 0
View file @
322a0be6
# SPDX-License-Identifier: Apache-2.0
"""An OpenVINO worker class."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
openvino
as
ov
import
torch
import
torch.distributed
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
VllmConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.utils
import
bind_kv_cache
from
vllm.worker.openvino_model_runner
import
OpenVINOModelRunner
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
logger
=
init_logger
(
__name__
)
class OpenVINOCacheEngine:
    """Manages the KV cache for OpenVINO backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
        ov_core: ov.Core,
        ov_device: str,
    ) -> None:
        """Derive cache geometry from the configs and allocate both caches.

        Args:
            cache_config: Supplies block size, cache dtype and block counts.
            model_config: Supplies head size, layer count and KV head count.
            parallel_config: Passed to the per-rank layer/head queries.
            device_config: Must have ``device_type == "openvino"``.
            ov_core: OpenVINO core used to obtain a remote context.
            ov_device: OpenVINO device name the KV cache is created for.
        """
        assert device_config.device_type == "openvino"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        if device_config.device.type == "cpu" and \
                cache_config.cache_dtype == ov.Type.u8:
            # Scale, zero point and quantized data will be stored together.
            # The layout for per token per head:
            # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
            # so, we have to extend head_size by 8, which is sizeof(float)
            # for scale and sizeof(float) for zeropoint
            self.head_size += 8
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
        # for OpenVINO backend with a CPU target device, because we want
        # to reuse KV cache management in the scheduler.
        self.num_device_blocks = cache_config.num_gpu_blocks
        self.num_swap_blocks = cache_config.num_cpu_blocks

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Initialize the cache.
        self.kv_cache: List[Tuple[ov.Tensor,
                                  ov.Tensor]] = self._allocate_kv_cache(
                                      self.num_device_blocks, ov_core,
                                      ov_device)

        # Initialize the swap.
        self.swap_cache: List[Tuple[ov.Tensor,
                                    ov.Tensor]] = self._allocate_swap_cache(
                                        self.num_swap_blocks, ov_device)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        ov_core: ov.Core,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates KV cache.

        Returns one ``(key_blocks, value_blocks)`` pair per layer. On the
        OpenVINO CPU platform plain ``ov.Tensor`` objects are created;
        otherwise tensors come from the device's default remote context and
        the key shape has its last two dimensions swapped.
        """
        k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads,
            self.head_size)[1:]
        kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []

        if current_platform.is_openvino_cpu():
            for _ in range(self.num_layers):
                key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                       k_block_shape)
                value_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                         v_block_shape)
                kv_cache.append((key_blocks, value_blocks))
        else:
            # Update key_cache shape:
            k_block_shape = (v_block_shape[0], v_block_shape[1],
                             v_block_shape[3], v_block_shape[2])

            remote_context = ov_core.get_default_context(ov_device)

            for _ in range(self.num_layers):
                key_blocks = \
                    remote_context.create_tensor(self.cache_config.cache_dtype,
                                                 ov.Shape(k_block_shape),
                                                 {})

                value_blocks = \
                    remote_context.create_tensor(self.cache_config.cache_dtype,
                                                 ov.Shape(v_block_shape),
                                                 {})

                kv_cache.append((key_blocks, value_blocks))

        return kv_cache

    def _allocate_swap_cache(
        self,
        num_blocks: int,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates swap cache.

        Returns an empty list when ``num_blocks`` is 0; otherwise one
        ``(key_blocks, value_blocks)`` pair of ``ov.Tensor`` per layer.
        """
        k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads,
            self.head_size)[1:]
        swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []

        if num_blocks == 0:
            return swap_cache

        assert not current_platform.is_openvino_cpu(), \
            "CPU device isn't supposed to have swap cache"

        # Update key_cache shape:
        k_block_shape = (v_block_shape[0], v_block_shape[1], v_block_shape[3],
                         v_block_shape[2])

        for _ in range(self.num_layers):
            key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                   k_block_shape)
            value_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                     v_block_shape)
            swap_cache.append((key_blocks, value_blocks))

        return swap_cache

    def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Copy the listed blocks from the swap cache into the KV cache.

        An empty ``src_to_dst`` is a no-op, so the call is safe even when no
        swap cache was allocated (zero swap blocks).
        """
        if not src_to_dst:
            # Nothing requested: avoid indexing a possibly empty swap cache.
            return
        for i in range(self.num_layers):
            for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                              self.kv_cache[i]):
                self.attn_backend.swap_blocks(swap_tensor, kv_tensor,
                                              src_to_dst)

    def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
        """Copy the listed blocks from the KV cache into the swap cache.

        An empty ``src_to_dst`` is a no-op, so the call is safe even when no
        swap cache was allocated (zero swap blocks).
        """
        if not src_to_dst:
            # Nothing requested: avoid indexing a possibly empty swap cache.
            return
        for i in range(self.num_layers):
            for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                              self.kv_cache[i]):
                self.attn_backend.swap_blocks(kv_tensor, swap_tensor,
                                              src_to_dst)

    def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None:
        """Copy blocks within the KV cache; skipped for an empty mapping."""
        if (len(src_to_dsts) > 0):
            self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: ov.Type,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size of one cache block across all layers.

        Computed as ``cache_dtype.size * num_layers * 2 * block_size *
        num_kv_heads * head_size`` (key block plus value block per layer).
        """
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        # NOTE(review): unlike __init__, this extension is applied for u8
        # regardless of the device type - confirm that is intended.
        if cache_dtype == ov.Type.u8:
            # Scale, zero point and quantized data will be stored together.
            # The layout for per token per head:
            # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
            # so, we have to extend head_size by 8, which is sizeof(float)
            # for scale and sizeof(float) for zeropoint
            head_size += 8

        key_cache_block = block_size * num_kv_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = cache_dtype.size
        return dtype_size * total
class
OpenVINOWorker
(
LoRANotSupportedWorkerBase
):
"""A worker class that executes the model on OpenVINO backend.
Each worker is associated with a single OpenVINO device. The worker is
responsible for maintaining the KV cache and executing the model on the
OpenVINO backend.
"""
def __init__(
    self,
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    is_driver_worker: bool = False,
) -> None:
    """Set up rank bookkeeping and create the OpenVINO model runner.

    Args:
        vllm_config: Engine configuration forwarded to ``WorkerBase``.
        local_rank: Stored on the worker as ``local_rank``.
        rank: Stored on the worker and on ``parallel_config.rank``.
        distributed_init_method: Stored for later distributed setup.
        is_driver_worker: Whether this worker is the driver; the driver
            must have rank 0.
    """
    WorkerBase.__init__(self, vllm_config)
    self.ov_core = ov.Core()
    self.parallel_config.rank = rank
    self.local_rank = local_rank
    self.rank = rank
    self.distributed_init_method = distributed_init_method
    self.is_driver_worker = is_driver_worker

    if self.is_driver_worker:
        assert self.rank == 0, "The driver worker must have rank 0."

    if self.model_config.trust_remote_code:
        # note: lazy import to avoid importing torch before initializing
        from vllm.utils import init_cached_hf_modules
        init_cached_hf_modules()

    runner_options = dict(
        vllm_config=self.vllm_config,
        kv_cache_dtype=self.vllm_config.cache_config.cache_dtype,
        is_driver_worker=is_driver_worker,
    )
    self.model_runner = OpenVINOModelRunner(self.ov_core, **runner_options)

    # Declared here, assigned later: the cache engine and its KV cache
    # are created by initialize_cache.
    self.cache_engine: OpenVINOCacheEngine
    self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]
def init_device(self) -> None:
    """Initialize the distributed environment, then seed the RNGs."""
    self.init_distributed_environment()
    # Set random seed.
    seed = self.model_config.seed
    set_random_seed(seed)
def load_model(self):
    """Ask the model runner to load the model."""
    runner = self.model_runner
    runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of blocks available for the KV cache.

    This determines how many KV blocks can fit into the configured
    KV cache space.

    Returns:
        ``(num_device_blocks, num_swap_blocks)``.

    Raises:
        RuntimeError: If the profiling run fails on a non-CPU device.
    """
    # For OpenVINO backend, in case of CPU device, the block number will be
    # calculated based on the openvino_kvcache_space_bytes.
    block_bytes = self.get_cache_block_size_bytes()
    configured_bytes = self.cache_config.openvino_kvcache_space_bytes

    if current_platform.is_openvino_cpu():
        # CPU target: no swap blocks.
        return int(configured_bytes // block_bytes), 0

    if configured_bytes > 0:
        logger.info("KV_CACHE size was explicitly configured via "
                    "VLLM_OPENVINO_KVCACHE_SPACE environment "
                    "variable, ignoring profiling run.")
        kv_cache_size = configured_bytes
    else:
        try:
            kv_cache_size = self.profile_run()
        except Exception as err:
            raise RuntimeError(
                "The error occurred during profile run. This might be "
                "due to insufficient GPU memory. Consider decreasing "
                "`max_model_len` to limit the maximum simultaneously "
                "processed tokens.") from err

    device_blocks = int(kv_cache_size // block_bytes)
    swap_blocks = int(self.cache_config.swap_space_bytes // block_bytes)
    return device_blocks, swap_blocks
def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Initialize the KV cache. Swappable CPU memory is only
    supported on GPU.

    For CPU, we use the num_gpu_blocks to
    determine how many non-swappable CPU blocks to allocate.
    """
    device_blocks, swap_blocks = num_gpu_blocks, num_cpu_blocks

    if current_platform.is_openvino_cpu():
        assert (swap_blocks == 0
                ), f"{type(self)} does not support swappable cache for CPU"

    self._validate_num_blocks(device_blocks)

    # Record the final block counts on the shared cache config.
    cache_cfg = self.cache_config
    cache_cfg.num_gpu_blocks = device_blocks
    cache_cfg.num_cpu_blocks = swap_blocks

    # Initialize the cache.
    self._init_cache_engine()
def
_validate_num_blocks
(
self
,
num_blocks
:
int
)
->
None
:
"""Raise errors if the num_blocks is invalid."""
if
num_blocks
<=
0
:
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
"initializing the engine."
)
max_seq_len
=
self
.
cache_config
.
block_size
*
num_blocks
if
self
.
model_config
.
max_model_len
>
max_seq_len
:
raise
ValueError
(
f
"The model's max seq len (
{
self
.
model_config
.
max_model_len
}
) "
"is larger than the maximum number of tokens that can be "
f
"stored in KV cache (
{
max_seq_len
}
). Try increasing "
"`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` "
"when initializing the engine."
)
def _init_cache_engine(self) -> None:
    """Create the cache engine, bind its KV cache and warm it up on CPU."""
    engine = OpenVINOCacheEngine(
        self.cache_config,
        self.model_config,
        self.parallel_config,
        self.device_config,
        self.ov_core,
        envs.VLLM_OPENVINO_DEVICE,
    )
    self.cache_engine = engine
    self.kv_cache = engine.kv_cache
    bind_kv_cache(self.compilation_config.static_forward_context,
                  [self.kv_cache])
    self.model_runner.block_size = engine.block_size

    assert self.kv_cache is not None

    # Populate the cache to warmup the memory
    if not current_platform.is_openvino_cpu():
        return
    for key_cache, value_cache in self.kv_cache:
        key_cache.data[:] = 0
        value_cache.data[:] = 0
def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
    """Forward a swap-in request to the cache engine."""
    engine = self.cache_engine
    engine.swap_in(src_to_dst)
def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
    """Forward a swap-out request to the cache engine."""
    engine = self.cache_engine
    engine.swap_out(src_to_dst)
def cache_copy(
    self,
    blocks_to_copy: List[Tuple[int, int]],
) -> None:
    """Forward a block-copy request to the cache engine."""
    engine = self.cache_engine
    engine.copy(blocks_to_copy)  # type: ignore
def get_model(self) -> nn.Module:
    """Return the model held by the model runner."""
    runner = self.model_runner
    return runner.get_model()
@torch.inference_mode()
def execute_model(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
    """Apply pending cache operations and run one model step.

    The driver worker broadcasts the number of sequence groups and the
    block operations; other workers receive them from rank 0.

    Args:
        execute_model_req: The request for this step. Must be provided on
            the driver worker; may be ``None`` elsewhere.

    Returns:
        An empty list when there are no sequence groups, otherwise a
        single-element list holding the model runner's output.
    """
    if execute_model_req is None:
        seq_group_metadata_list = None
    else:
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

    if self.is_driver_worker:
        assert seq_group_metadata_list is not None
        num_seq_groups: int = len(seq_group_metadata_list)
        assert execute_model_req is not None
        blocks_to_copy = execute_model_req.blocks_to_copy
        blocks_to_swap_in = execute_model_req.blocks_to_swap_in
        blocks_to_swap_out = execute_model_req.blocks_to_swap_out
        data: Dict[str, Any] = {
            "num_seq_groups": num_seq_groups,
            "blocks_to_copy": blocks_to_copy,
            "blocks_to_swap_in": blocks_to_swap_in,
            "blocks_to_swap_out": blocks_to_swap_out,
        }
        broadcast_tensor_dict(data, src=0)
    else:
        data = broadcast_tensor_dict(src=0)
        num_seq_groups = data["num_seq_groups"]
        blocks_to_copy = data["blocks_to_copy"]
        blocks_to_swap_in = data["blocks_to_swap_in"]
        blocks_to_swap_out = data["blocks_to_swap_out"]

    if current_platform.is_openvino_cpu():
        # Check the values resolved above rather than the request object:
        # on non-driver workers the request may be None.
        assert len(blocks_to_swap_in) == 0
        assert len(blocks_to_swap_out) == 0
    else:
        self.cache_swap_in(blocks_to_swap_in)
        self.cache_swap_out(blocks_to_swap_out)

    self.cache_copy(blocks_to_copy)

    # If there is no input, we don't need to execute the model.
    if num_seq_groups == 0:
        return []

    output = self.model_runner.execute_model(seq_group_metadata_list,
                                             self.kv_cache)

    # OpenVINO worker only supports single-step execution.
    return [output]
def init_distributed_environment(self) -> None:
    """Initialize the distributed environment."""
    parallel_cfg = self.parallel_config

    # Calls the module-level helper of the same name, using the "gloo"
    # backend.
    init_distributed_environment(
        world_size=parallel_cfg.world_size,
        rank=self.rank,
        distributed_init_method=self.distributed_init_method,
        backend="gloo",
    )

    # A small all_reduce for warmup.
    warmup = torch.zeros(1).cpu()
    torch.distributed.all_reduce(warmup)

    ensure_model_parallel_initialized(
        parallel_cfg.tensor_parallel_size,
        parallel_cfg.pipeline_parallel_size,
    )
def get_cache_block_size_bytes(self) -> int:
    """Return the size in bytes of a single KV cache block."""
    cache_cfg = self.cache_config
    return OpenVINOCacheEngine.get_cache_block_size(
        cache_cfg.block_size,
        cache_cfg.cache_dtype,
        self.model_config,
        self.parallel_config,
    )
def profile_run(self) -> int:
    """Profile device memory usage and return the KV-cache memory budget.

    Runs one forward pass with dummy inputs against a temporary
    single-block KV cache, reads the memory statistics that OpenVINO
    reports for the configured device, and computes
    ``total_device_memory * gpu_memory_utilization - used_device_mem``,
    where ``used_device_mem`` is the measured usage scaled by a 10%
    safety margin.

    Must not be called when the OpenVINO device is CPU (asserted below).

    Returns:
        The number of bytes left over for the KV cache.
        NOTE(review): the annotation says ``int`` but the value is a
        float, because the measured usage is multiplied by ``1.1``.

    Raises:
        RuntimeError: If the (margin-adjusted) memory needed to run the
            model is greater than or equal to the total device memory.
    """
    ov_device = envs.VLLM_OPENVINO_DEVICE

    assert not current_platform.is_openvino_cpu(), \
        "CPU device isn't supposed to use profile run."

    # Imported lazily so the OpenVINO property modules are only needed
    # when profiling actually runs.
    import openvino.properties.device as device
    import openvino.properties.intel_gpu as intel_gpu

    ov_core = self.ov_core
    cache_config = self.cache_config
    model_config = self.model_config
    parallel_config = self.parallel_config
    device_config = self.device_config
    input_registry = INPUT_REGISTRY
    mm_registry = MULTIMODAL_REGISTRY
    mm_registry.init_mm_limits_per_prompt(model_config)

    # Execute a forward pass with dummy inputs to profile the memory usage
    # of the model.
    def model_profile_run():
        top_k = model_config.get_vocab_size() - 1
        sampling_params = SamplingParams(top_p=0.99, top_k=top_k)

        max_num_batched_tokens = \
            self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Temporary cache config holding a single device block, used only
        # for the duration of the profiling pass.
        tmp_cache_config = CacheConfig(cache_config.block_size,
                                       cache_config.gpu_memory_utilization,
                                       cache_config.swap_space_bytes,
                                       "auto")
        tmp_cache_config.num_gpu_blocks = 1
        tmp_cache_config.num_cpu_blocks = 0
        tmp_cache_config.cache_dtype = cache_config.cache_dtype

        profiling_cache_engine = OpenVINOCacheEngine(
            tmp_cache_config, model_config, parallel_config,
            device_config, ov_core, ov_device)

        # Profile memory usage with max_num_sequences sequences and the
        # total number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra
            # token (the bool comparison adds 0 or 1).
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            block_size = cache_config.block_size
            # Ceiling division: blocks needed to hold seq_len tokens.
            seq_num_blocks = (seq_len + block_size - 1) // block_size

            dummy_data = input_registry \
                .dummy_data_for_profiling(model_config,
                                          seq_len,
                                          mm_registry)

            # Every entry points at block 0, the only block that exists
            # in the temporary cache.
            block_tables = [[0] * seq_num_blocks] * max_num_seqs
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=block_tables,
                lora_request=None,
                multi_modal_data=dummy_data.multi_modal_data)
            seqs.append(seq)

        self.model_runner.block_size = tmp_cache_config.block_size

        bind_kv_cache(self.compilation_config.static_forward_context,
                      profiling_cache_engine.kv_cache)
        # Run the model with the dummy inputs.
        self.model_runner.execute_model(seqs,
                                        profiling_cache_engine.kv_cache)

        # Explicitly revert bind_kv_cache and delete temporary KV cache
        # manager to free KV cache when real inputs will be passed to OV
        bind_kv_cache(self.compilation_config.static_forward_context, [[
            torch.tensor([])
            for _ in range(len(profiling_cache_engine.kv_cache))
        ]])

        del profiling_cache_engine

    logger.info(
        "Start profiling run with dummy inputs to evaluate "
        "memory usage for %s. It might take a while.", ov_device)
    model_profile_run()

    gpu_device_type = ov_core.get_property(ov_device, device.type)
    memory_statistics = \
        ov_core.get_property(ov_device, intel_gpu.memory_statistics)
    memory_utilization = cache_config.gpu_memory_utilization

    if gpu_device_type == device.Type.INTEGRATED and \
            memory_utilization >= 0.9:
        logger.warning(
            "iGPU is used with high gpu_memory_utilization=%f "
            "value. This may cause low performance due to "
            "occupying the majority of available system "
            "memory. Please consider decreasing "
            "gpu_memory_utilization or explicitly setting "
            "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment "
            "variable.", memory_utilization)

    # sum up all used device memory
    device_memory_types = ["cl_mem", "usm_device"]
    used_device_mem = \
        sum(memory_statistics.get(key, 0) for key in device_memory_types)

    # Integrated devices additionally have their "usm_host" usage counted.
    if gpu_device_type == device.Type.INTEGRATED:
        used_device_mem += memory_statistics.get("usm_host", 0)

    # there could be unaccounted extra memory reserved by kernels, kept
    # in memory pools, etc
    # therefore, add a threshold to account for this
    used_memory_threshold = 1.1
    used_device_mem *= used_memory_threshold

    total_device_memory = \
        ov_core.get_property(ov_device, intel_gpu.device_total_mem_size)

    def format_memory_size(size) -> str:
        """Render a byte count as a two-decimal string in B/KB/MB/GB."""
        units = ["B", "KB", "MB", "GB"]
        unit_index = 0

        while size > 1024 and unit_index < len(units) - 1:
            size /= 1024
            unit_index += 1

        return f"{size:.2f} {units[unit_index]}"

    # format_memory_size already returns a str; no extra format() needed.
    total_device_memory_str = format_memory_size(total_device_memory)
    used_device_memory_str = format_memory_size(used_device_mem)

    logger.info(
        "Total %s memory: %s. "
        "Amount of memory required to run the model with "
        "max_num_batched_tokens=%d: %s.", ov_device,
        total_device_memory_str,
        self.scheduler_config.max_num_batched_tokens,
        used_device_memory_str)

    if used_device_mem >= total_device_memory:
        # Both fragments that contain placeholders carry the f prefix so
        # the message shows the actual sizes.
        raise RuntimeError(
            f"The required memory size {used_device_memory_str} for model "
            "is higher than the total available device "
            f"memory {total_device_memory_str}. Please consider to "
            "decrease `max_num_batched_tokens` or increase "
            "`gpu_memory_utilization`")

    return total_device_memory * memory_utilization - used_device_mem
Prev
1
…
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment