Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6cac54f4
Unverified
Commit
6cac54f4
authored
Jun 04, 2025
by
Chen Zhang
Committed by
GitHub
Jun 03, 2025
Browse files
[v1] Re-init input batch for multiple kv cache groups (#18654)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
6865fe00
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
61 additions
and
46 deletions
+61
-46
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+3
-26
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+3
-1
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+2
-1
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+9
-9
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+40
-6
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+4
-3
No files found.
tests/v1/worker/test_gpu_input_batch.py
View file @
6cac54f4
...
...
@@ -10,8 +10,6 @@ import torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.block_table
import
BlockTable
,
MultiGroupBlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
...
@@ -25,27 +23,6 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS
=
64
def get_kv_cache_config() -> KVCacheConfig:
    """Build the fixed KV cache configuration used by these tests.

    The configuration has 10 blocks, a single tensor of size 1024 registered
    under the layer name "layer.0", and one KV cache group that covers that
    same layer with a full-attention spec (block size 1, one KV head, head
    size 16, float16, MLA disabled).

    Returns:
        The constructed `KVCacheConfig`.
    """
    layer_name = "layer.0"

    # Spec shared by every layer in the (single) KV cache group.
    attention_spec = FullAttentionSpec(
        block_size=1,
        num_kv_heads=1,
        head_size=16,
        dtype=torch.float16,
        use_mla=False,
    )
    single_group = KVCacheGroupSpec(
        layer_names=[layer_name],
        kv_cache_spec=attention_spec,
    )

    return KVCacheConfig(
        num_blocks=10,
        tensors={layer_name: KVCacheTensor(size=1024)},
        kv_cache_groups=[single_group],
    )
def
_compare_objs
(
obj1
,
obj2
):
attrs
=
inspect
.
getmembers
(
obj1
,
lambda
a
:
not
(
inspect
.
isroutine
(
a
)))
attr_names
=
set
([
...
...
@@ -252,7 +229,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
block_size
=
1
,
block_size
s
=
[
1
]
,
)
reqs
:
list
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
...
...
@@ -342,7 +319,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
block_size
=
1
,
block_size
s
=
[
1
]
,
)
ref_input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
...
...
@@ -351,7 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
block_size
=
1
,
block_size
s
=
[
1
]
,
)
reqs
:
list
[
CachedRequestState
]
=
[]
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
6cac54f4
...
...
@@ -54,7 +54,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
device
=
runner
.
device
,
pin_memory
=
runner
.
pin_memory
,
vocab_size
=
runner
.
model_config
.
get_vocab_size
(),
block_size
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
,
block_sizes
=
[
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
],
)
runner
.
initialize_attn_backend
(
kv_cache_config
)
...
...
vllm/v1/worker/block_table.py
View file @
6cac54f4
...
...
@@ -105,10 +105,11 @@ class MultiGroupBlockTable:
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
pin_memory
:
bool
,
device
:
torch
.
device
,
block_size
:
int
)
->
None
:
device
:
torch
.
device
,
block_size
s
:
list
[
int
]
)
->
None
:
self
.
block_tables
=
[
BlockTable
(
max_num_reqs
,
cdiv
(
max_model_len
,
block_size
),
max_num_batched_tokens
,
pin_memory
,
device
)
for
block_size
in
block_sizes
]
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
6cac54f4
...
...
@@ -56,14 +56,14 @@ class CachedRequestState:
class
InputBatch
:
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
vocab_size
:
int
,
block_size
:
int
,
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
vocab_size
:
int
,
block_size
s
:
list
[
int
],
# The block_size of each kv cache group
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
...
...
@@ -105,7 +105,7 @@ class InputBatch:
max_num_batched_tokens
=
max_num_batched_tokens
,
pin_memory
=
pin_memory
,
device
=
device
,
block_size
=
block_size
,
block_size
s
=
block_size
s
,
)
# Sampling-related.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6cac54f4
...
...
@@ -143,7 +143,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
]
=
[]
self
.
attn_backends
:
list
[
type
[
AttentionBackend
]]
=
[]
# self.kv_cache_config: KVCacheConfig
# self.input_batch: InputBatch # Persistent batch.
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
...
...
@@ -173,6 +172,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
# https://github.com/vllm-project/vllm/pull/18298, for reasons that are
# not yet understood, we have to initialize the input batch before
# `load_model`; otherwise quantization + weight offloading will fail. As
# a temporary solution, we initialize the input batch here, and
# re-initialize it in `initialize_kv_cache` if the block_sizes here are
# different from the block_sizes in the kv cache config.
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
...
...
@@ -180,7 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
block_size
=
self
.
cache_config
.
block_size
,
block_size
s
=
[
self
.
cache_config
.
block_size
]
,
)
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
...
...
@@ -2040,6 +2048,35 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
attn_backends
.
append
(
attn_backend_i
)
self
.
attn_metadata_builders
.
append
(
attn_metadata_builder_i
)
def may_reinitialize_input_batch(self,
                                 kv_cache_config: KVCacheConfig) -> None:
    """
    Rebuild `self.input_batch` when the per-group block sizes in
    `kv_cache_config` are not exactly `[self.cache_config.block_size]`.

    The block sizes are collected in order from each entry of
    `kv_cache_config.kv_cache_groups`. If that list equals
    `[self.cache_config.block_size]`, nothing is changed. Otherwise a new
    `InputBatch` is constructed with the collected block sizes and assigned
    to `self.input_batch`. This usually happens when there are multiple KV
    cache groups.

    Args:
        kv_cache_config: The KV cache configuration.

    Raises:
        AssertionError: If re-initialization is needed while
            `self.cache_config.cpu_offload_gb` is non-zero.
    """
    group_block_sizes = [
        group.kv_cache_spec.block_size
        for group in kv_cache_config.kv_cache_groups
    ]

    # Guard clause: the block sizes match the single default block size,
    # so `self.input_batch` is left untouched.
    if group_block_sizes == [self.cache_config.block_size]:
        return

    # Replacing the input batch is not allowed together with CPU weight
    # offloading; see the linked PR for the background.
    assert self.cache_config.cpu_offload_gb == 0, (
        "Cannot re-initialize the input batch when CPU weight "
        "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
        "for more details.")

    self.input_batch = InputBatch(
        max_num_reqs=self.max_num_reqs,
        max_model_len=self.max_model_len,
        max_num_batched_tokens=self.max_num_tokens,
        device=self.device,
        pin_memory=self.pin_memory,
        vocab_size=self.model_config.get_vocab_size(),
        block_sizes=group_block_sizes,
    )
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize KV cache based on `kv_cache_config`.
...
...
@@ -2047,11 +2084,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if
len
(
kv_cache_config
.
kv_cache_groups
)
>
1
:
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
self
.
kv_cache_config
=
kv_cache_config
self
.
may_reinitialize_input_batch
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
6cac54f4
...
...
@@ -200,7 +200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
block_size
=
self
.
block_size
,
block_size
s
=
[
self
.
block_size
]
,
)
# Cached torch/numpy tensor
...
...
@@ -1358,8 +1358,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
block_size
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
,
block_sizes
=
[
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
],
)
# Verify dtype compatibility between block_table_cpu and input_batch
assert
self
.
block_table_cpu
.
dtype
==
self
.
input_batch
.
block_table
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment