Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e60f550b
Unverified
Commit
e60f550b
authored
May 15, 2025
by
Chen Zhang
Committed by
GitHub
May 14, 2025
Browse files
[v1] Support multiple KV cache groups in GPU model runner (#17945)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
f25e0d11
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
482 additions
and
215 deletions
+482
-215
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+68
-3
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+18
-18
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+33
-6
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+41
-16
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+3
-3
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+2
-2
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+21
-13
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+6
-7
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+6
-6
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+10
-6
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+42
-0
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+47
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+7
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+158
-112
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+19
-16
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
e60f550b
...
...
@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
hash_request_tokens
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
KVCacheGroupSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
...
...
@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads
=
2
,
head_size
=
64
,
dtype
=
torch
.
float32
,
use_mla
=
False
):
use_mla
=
False
,
sliding_window
=
None
):
return
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
dtype
=
dtype
,
use_mla
=
use_mla
)
use_mla
=
use_mla
,
sliding_window
=
sliding_window
)
def
test_none_hash
():
...
...
@@ -471,6 +474,68 @@ def test_unify_kv_cache_configs():
unify_kv_cache_configs
(
diff_kv_cache_config
)
def
test_merge_kv_cache_spec
():
same_layer_specs
=
[
new_kv_cache_spec
(
num_kv_heads
=
32
),
new_kv_cache_spec
(
num_kv_heads
=
32
),
]
merged_layer_spec
=
same_layer_specs
[
0
].
merge
(
same_layer_specs
)
assert
merged_layer_spec
.
block_size
==
16
assert
merged_layer_spec
.
num_kv_heads
==
32
assert
merged_layer_spec
.
head_size
==
64
assert
merged_layer_spec
.
dtype
==
torch
.
float32
assert
merged_layer_spec
.
sliding_window
is
None
different_layer_specs
=
[
new_kv_cache_spec
(
num_kv_heads
=
32
),
new_kv_cache_spec
(
num_kv_heads
=
16
),
]
with
pytest
.
raises
(
AssertionError
):
different_layer_specs
[
0
].
merge
(
different_layer_specs
)
full_spec
=
new_kv_cache_spec
(
num_kv_heads
=
32
)
different_type_layer_specs
=
[
full_spec
,
SlidingWindowSpec
(
block_size
=
full_spec
.
block_size
,
num_kv_heads
=
full_spec
.
num_kv_heads
,
head_size
=
full_spec
.
head_size
,
dtype
=
full_spec
.
dtype
,
use_mla
=
full_spec
.
use_mla
,
sliding_window
=
1
,
),
]
with
pytest
.
raises
(
AssertionError
):
different_type_layer_specs
[
0
].
merge
(
different_type_layer_specs
)
with
pytest
.
raises
(
AssertionError
):
different_type_layer_specs
[
1
].
merge
(
different_type_layer_specs
)
different_sliding_window_layer_specs
=
[
new_kv_cache_spec
(
num_kv_heads
=
32
),
new_kv_cache_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
new_kv_cache_spec
(
num_kv_heads
=
32
,
sliding_window
=
2
),
]
with
pytest
.
raises
(
ValueError
):
different_sliding_window_layer_specs
[
0
].
merge
(
different_sliding_window_layer_specs
)
same_sliding_window_layer_specs
=
[
new_kv_cache_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
new_kv_cache_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
]
merged_layer_spec
=
same_sliding_window_layer_specs
[
0
].
merge
(
same_sliding_window_layer_specs
)
assert
merged_layer_spec
.
sliding_window
==
1
same_sliding_window_layer_spec_with_none
=
[
new_kv_cache_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
new_kv_cache_spec
(
num_kv_heads
=
32
,
sliding_window
=
None
),
]
merged_layer_spec
=
same_sliding_window_layer_spec_with_none
[
0
].
merge
(
same_sliding_window_layer_spec_with_none
)
assert
merged_layer_spec
.
sliding_window
==
1
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"max_model_len"
,
"want_estimated_max_len"
),
[
(
"Qwen/Qwen1.5-7B"
,
16385
,
16384
),
...
...
tests/v1/core/test_prefix_caching.py
View file @
e60f550b
...
...
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
# Check full block metadata
parent_block_hash
=
None
...
...
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
,
3
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
5
]
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
for
block
in
computed_blocks
.
blocks
:
assert
block
.
ref_cnt
==
2
...
...
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
,
3
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
6
]
assert
blocks
.
get_block_ids
()
==
[
[
6
]
]
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
...
...
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
# This block ID order also checks the eviction order.
assert
blocks
.
get_block_ids
()
==
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
assert
blocks
.
get_block_ids
()
==
[
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
...
...
@@ -208,7 +208,7 @@ def test_prefill_plp():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
# Check full block metadata
...
...
@@ -233,13 +233,13 @@ def test_prefill_plp():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
,
3
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
5
]
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
for
block
in
computed_blocks
.
blocks
:
assert
block
.
ref_cnt
==
2
...
...
@@ -277,11 +277,11 @@ def test_prefill_plp():
block_ids
=
blocks
.
get_block_ids
()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
==
req0_block_hashes
assert
block_ids
!=
[
1
,
2
,
3
,
4
]
assert
block_ids
!=
[
[
1
,
2
,
3
,
4
]
]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
for
block_id
in
block_ids
:
for
block_id
in
block_ids
[
0
]
:
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
manager
.
free
(
req2
)
...
...
@@ -307,7 +307,7 @@ def test_decode():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
...
...
@@ -379,12 +379,12 @@ def test_evict():
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
]
]
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
10
]
assert
blocks
.
get_block_ids
()
==
[
[
10
]
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
...
...
@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
...
...
@@ -686,7 +686,7 @@ def test_cache_key_salting():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
...
...
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
...
...
@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
5
]
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
# Failed to reset prefix cache because some blocks are not freed yet.
assert
not
manager
.
reset_prefix_cache
()
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
e60f550b
...
...
@@ -9,9 +9,11 @@ import torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.
gpu_input_batch
import
(
BlockTable
,
CachedRequestState
,
InputBatch
)
from
vllm.v1.worker.
block_table
import
BlockTable
,
MultiGroupBlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
...
...
@@ -22,6 +24,27 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS
=
64
def
get_kv_cache_config
()
->
KVCacheConfig
:
return
KVCacheConfig
(
num_blocks
=
10
,
tensors
=
{
"layer.0"
:
KVCacheTensor
(
size
=
1024
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
(
layer_names
=
[
"layer.0"
],
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
1
,
num_kv_heads
=
1
,
head_size
=
16
,
dtype
=
torch
.
float16
,
use_mla
=
False
,
),
),
],
)
def
_compare_objs
(
obj1
,
obj2
):
attrs
=
inspect
.
getmembers
(
obj1
,
lambda
a
:
not
(
inspect
.
isroutine
(
a
)))
attr_names
=
set
([
...
...
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
elif
isinstance
(
a
,
np
.
ndarray
):
if
np
.
allclose
(
a
,
b
):
is_same
=
True
elif
isinstance
(
a
,
MultiGroupBlockTable
):
for
a_i
,
b_i
in
zip
(
a
.
block_tables
,
b
.
block_tables
):
_compare_objs
(
a_i
,
b_i
)
is_same
=
True
elif
isinstance
(
a
,
(
BlockTable
,
SamplingMetadata
)):
_compare_objs
(
a
,
b
)
is_same
=
True
# if we make it here must be same
...
...
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
block_ids
=
[],
block_ids
=
[
[]
],
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
...
...
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
kv_cache_config
=
get_kv_cache_config
(),
)
reqs
:
list
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
...
...
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
kv_cache_config
=
get_kv_cache_config
(),
)
ref_input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
kv_cache_config
=
get_kv_cache_config
(),
)
reqs
:
list
[
CachedRequestState
]
=
[]
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
e60f550b
# SPDX-License-Identifier: Apache-2.0
import
weakref
import
pytest
import
torch
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
...
@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
64
,
dtype
=
torch
.
float16
,
use_mla
=
False
)
runner
.
attn_metadata_builder
=
runner
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
runner
),
kv_cache_spec
,
runner
.
input_batch
.
block_table
)
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
10
,
tensors
=
{
"layer.0"
:
KVCacheTensor
(
size
=
1024
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
(
layer_names
=
[
"layer.0"
],
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
16
,
num_kv_heads
=
runner
.
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
),
head_size
=
runner
.
model_config
.
get_head_size
(),
dtype
=
runner
.
kv_cache_dtype
,
use_mla
=
False
,
))
])
runner
.
kv_cache_config
=
kv_cache_config
runner
.
input_batch
=
InputBatch
(
max_num_reqs
=
runner
.
max_num_reqs
,
max_model_len
=
runner
.
max_model_len
,
max_num_batched_tokens
=
runner
.
max_num_tokens
,
device
=
runner
.
device
,
pin_memory
=
runner
.
pin_memory
,
vocab_size
=
runner
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
)
runner
.
initialize_attn_backend
(
kv_cache_config
)
@
pytest
.
fixture
...
...
@@ -48,10 +70,12 @@ def model_runner():
swap_space
=
0
,
cache_dtype
=
"auto"
,
)
parallel_config
=
ParallelConfig
()
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
parallel_config
=
parallel_config
,
)
device
=
"cuda"
...
...
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
0
],
block_ids
=
[
[
0
]
],
num_computed_tokens
=
0
,
lora_request
=
None
,
))
...
...
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
block_table
=
model_runner
.
input_batch
.
block_table
block_table
=
model_runner
.
input_batch
.
block_table
[
0
]
req_state
=
model_runner
.
requests
[
req_id
]
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
):
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
[
0
]):
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
return
(
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
==
req_state
.
block_ids
).
all
()
req_state
.
block_ids
[
0
]
).
all
()
def
test_update_states_new_request
(
model_runner
):
...
...
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[],
new_block_ids
=
[
[]
],
num_computed_tokens
=
0
,
)
...
...
tests/weight_loading/models.txt
View file @
e60f550b
...
...
@@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
#
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq, TheBloke/Llama-2-7B-GPTQ, main
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
e60f550b
...
...
@@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
if
new_req
.
req_id
in
self
.
_requests_need_load
:
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
block_ids
=
new_req
.
block_ids
,
block_ids
=
new_req
.
block_ids
[
0
]
,
block_size
=
self
.
_block_size
,
is_store
=
False
)
total_need_load
+=
1
...
...
@@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# the original prompt tokens.
if
not
self
.
_found_match_for_request
(
new_req
):
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
block_ids
=
new_req
.
block_ids
,
block_ids
=
new_req
.
block_ids
[
0
]
,
block_size
=
self
.
_block_size
,
is_store
=
True
)
...
...
@@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids
=
cached_req
.
new_block_ids
block_ids
=
cached_req
.
new_block_ids
[
0
]
meta
.
add_request
(
token_ids
=
token_ids
,
block_ids
=
block_ids
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
e60f550b
...
...
@@ -67,13 +67,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_model_len
=
self
.
runner
.
model_config
.
max_model_len
assert
max_model_len
==
32768
,
\
"AITER MLA requires max_model_len=32768"
assert
self
.
runner
.
block_size
==
1
,
"AITER MLA"
\
assert
self
.
kv_cache_spec
.
block_size
==
1
,
"AITER MLA"
\
"only supports block size 1."
def
_get_paged_kv_tensors
(
self
,
block_table
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
page_size
=
self
.
runner
.
block_size
page_size
=
self
.
kv_cache_spec
.
block_size
block_table_bounds
=
(
seq_lens
+
page_size
-
1
)
//
page_size
mask
=
(
torch
.
arange
(
block_table
.
size
(
1
),
...
...
vllm/v1/core/kv_cache_manager.py
View file @
e60f550b
...
...
@@ -32,9 +32,16 @@ class KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return
cls
([])
def
get_block_ids
(
self
)
->
list
[
int
]:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return
[
block
.
block_id
for
block
in
self
.
blocks
]
def
get_block_ids
(
self
)
->
list
[
list
[
int
]]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups (only 1 group now)
* each inner list contains the block_ids of the blocks in that group
"""
return
[[
block
.
block_id
for
block
in
self
.
blocks
]]
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
...
...
@@ -300,9 +307,9 @@ class KVCacheManager:
self
,
request
:
Request
,
num_running_requests
:
int
,
)
->
int
:
)
->
list
[
int
]
:
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state.
in the RUNNING state
for each kv cache group
.
The function determines this by selecting any request and iterating
through its blocks. A block is considered a common prefix block if its
...
...
@@ -332,11 +339,14 @@ class KVCacheManager:
requests in the current step.
Returns:
int: The number of common prefix blocks.
list[int]: The number of common prefix blocks for each kv cache
group.
"""
assert
request
.
status
==
RequestStatus
.
RUNNING
return
self
.
single_type_manager
.
get_num_common_prefix_blocks
(
return
[
self
.
single_type_manager
.
get_num_common_prefix_blocks
(
request
.
request_id
,
num_running_requests
)
]
def
free_block_hashes
(
self
,
request
:
Request
)
->
None
:
"""Discard the block hashes for the request.
...
...
@@ -354,10 +364,8 @@ class KVCacheManager:
"""
return
self
.
block_pool
.
take_events
()
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
int
]:
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
list
[
int
]
]
:
"""Get the block ids of a request."""
assert
request_id
in
self
.
single_type_manager
.
req_to_blocks
return
[
block
.
block_id
for
block
in
self
.
single_type_manager
.
req_to_blocks
[
request_id
]
]
return
KVCacheBlocks
(
self
.
single_type_manager
.
req_to_blocks
[
request_id
]
).
get_block_ids
()
vllm/v1/core/kv_cache_utils.py
View file @
e60f550b
...
...
@@ -577,14 +577,12 @@ def create_kv_cache_group_specs(
"""
kv_cache_groups
=
[]
for
layer_names_one_group
in
grouped_layer_names
:
layer_spec
=
kv_cache_spec
[
layer_names_one_group
[
0
]]
assert
all
(
kv_cache_spec
[
layer_name
]
==
layer_spec
for
layer_name
in
layer_names_one_group
[
1
:]),
(
"All layers in the same KV cache group must share the same "
"KVCacheSpec."
)
layer_specs
=
[
kv_cache_spec
[
layer_name
]
for
layer_name
in
layer_names_one_group
]
merged_layer_spec
=
layer_specs
[
0
].
merge
(
layer_specs
)
kv_cache_groups
.
append
(
KVCacheGroupSpec
(
layer_names_one_group
,
layer_spec
))
KVCacheGroupSpec
(
layer_names_one_group
,
merged_
layer_spec
))
return
kv_cache_groups
...
...
@@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
head_size
=
spec
.
head_size
,
dtype
=
spec
.
dtype
,
use_mla
=
spec
.
use_mla
,
sliding_window
=
spec
.
sliding_window
,
)
...
...
vllm/v1/core/sched/output.py
View file @
e60f550b
...
...
@@ -26,7 +26,7 @@ class NewRequestData:
mm_hashes
:
list
[
str
]
mm_positions
:
list
[
PlaceholderRange
]
sampling_params
:
SamplingParams
block_ids
:
list
[
int
]
block_ids
:
list
[
list
[
int
]
]
num_computed_tokens
:
int
lora_request
:
Optional
[
LoRARequest
]
...
...
@@ -34,7 +34,7 @@ class NewRequestData:
def
from_request
(
cls
,
request
:
Request
,
block_ids
:
list
[
int
],
block_ids
:
list
[
list
[
int
]
]
,
)
->
NewRequestData
:
return
cls
(
req_id
=
request
.
request_id
,
...
...
@@ -85,7 +85,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
new_token_ids
:
list
[
int
]
new_block_ids
:
list
[
int
]
new_block_ids
:
list
[
list
[
int
]
]
num_computed_tokens
:
int
@
classmethod
...
...
@@ -94,7 +94,7 @@ class CachedRequestData:
request
:
Request
,
resumed_from_preemption
:
bool
,
new_token_ids
:
list
[
int
],
new_block_ids
:
list
[
int
],
new_block_ids
:
list
[
list
[
int
]
]
,
)
->
CachedRequestData
:
return
cls
(
req_id
=
request
.
request_id
,
...
...
@@ -131,9 +131,9 @@ class SchedulerOutput:
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]]
# Number of common prefix blocks for all requests.
# Number of common prefix blocks for all requests
in each KV cache group
.
# This can be used for cascade attention.
num_common_prefix_blocks
:
int
num_common_prefix_blocks
:
list
[
int
]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
...
...
vllm/v1/core/sched/scheduler.py
View file @
e60f550b
...
...
@@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
req_to_new_block_ids
:
dict
[
str
,
list
[
int
]]
=
{}
req_to_new_block_ids
:
dict
[
str
,
list
[
list
[
int
]]
]
=
{}
num_scheduled_tokens
:
dict
[
str
,
int
]
=
{}
token_budget
=
self
.
max_num_scheduled_tokens
# Encoder-related.
...
...
@@ -477,7 +477,8 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks
=
0
num_common_prefix_blocks
=
[
0
]
*
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
if
self
.
running
:
any_request
=
self
.
running
[
0
]
num_common_prefix_blocks
=
(
...
...
@@ -564,7 +565,7 @@ class Scheduler(SchedulerInterface):
request
:
Request
,
num_scheduled_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
new_block_ids
:
list
[
int
],
new_block_ids
:
list
[
list
[
int
]
]
,
resumed_from_preemption
:
bool
,
)
->
CachedRequestData
:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
...
...
@@ -939,7 +940,9 @@ class Scheduler(SchedulerInterface):
"""
if
self
.
connector
is
None
:
return
False
,
None
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
"KV connector only supports one KV cache group now"
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)[
0
]
return
self
.
connector
.
request_finished
(
request
,
block_ids
)
def
_update_waiting_for_remote_kv
(
self
,
request
:
Request
)
->
bool
:
...
...
@@ -956,9 +959,10 @@ class Scheduler(SchedulerInterface):
"""
if
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
return
False
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
"KV connector only supports one KV cache group now"
# Now that the blocks are ready, actually cache them.
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
[
0
]
num_computed_tokens
=
len
(
block_ids
)
*
self
.
block_size
if
num_computed_tokens
==
request
.
num_tokens
:
num_computed_tokens
-=
1
...
...
vllm/v1/kv_cache_interface.py
View file @
e60f550b
# SPDX-License-Identifier: Apache-2.0
import
copy
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
from
typing_extensions
import
Self
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
...
...
@@ -53,6 +56,16 @@ class KVCacheSpec:
"""
raise
NotImplementedError
@
classmethod
def
merge
(
cls
,
specs
:
list
[
Self
])
->
Self
:
"""
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
"""
assert
all
(
spec
.
type_id
==
specs
[
0
].
type_id
for
spec
in
specs
[
1
:]),
(
"All layers in the same KV cache group must share the same "
"type_id."
)
return
copy
.
deepcopy
(
specs
[
0
])
@
dataclass
class
AttentionSpec
(
KVCacheSpec
):
...
...
@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):
@
dataclass
class
FullAttentionSpec
(
AttentionSpec
):
sliding_window
:
Optional
[
int
]
=
None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
window attention are regarded as full attention in KV cache manager
(blocks are allocated for all tokens), while computed as sliding window
attention in model runner.
In this case, we use FullAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""
@
property
def
type_id
(
self
)
->
str
:
...
...
@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
max_model_len
=
vllm_config
.
model_config
.
max_model_len
return
cdiv
(
max_model_len
,
self
.
block_size
)
*
self
.
page_size_bytes
@
classmethod
def
merge
(
cls
,
specs
:
list
[
Self
])
->
Self
:
"""
Merge a list of FullAttentionSpec objects into a single
FullAttentionSpec object.
"""
merged_spec
=
super
().
merge
(
specs
)
sliding_window
=
set
(
spec
.
sliding_window
for
spec
in
specs
if
spec
.
sliding_window
is
not
None
)
if
len
(
sliding_window
)
==
0
:
merged_spec
.
sliding_window
=
None
elif
len
(
sliding_window
)
==
1
:
merged_spec
.
sliding_window
=
sliding_window
.
pop
()
else
:
raise
ValueError
(
"All sliding window layers in the same KV cache group "
"must have the same window size."
)
return
merged_spec
@
dataclass
class
SlidingWindowSpec
(
AttentionSpec
):
...
...
vllm/v1/worker/block_table.py
View file @
e60f550b
...
...
@@ -4,6 +4,8 @@ import numpy as np
import
torch
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
logger
=
init_logger
(
__name__
)
...
...
@@ -96,3 +98,48 @@ class BlockTable:
def
get_numpy_array
(
self
)
->
np
.
ndarray
:
"""Returns the numpy array of the block table."""
return
self
.
block_table_np
class
MultiGroupBlockTable
:
"""The BlockTables for each KV cache group."""
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
pin_memory
:
bool
,
device
:
torch
.
device
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
max_num_blocks_per_req
=
[
cdiv
(
max_model_len
,
g
.
kv_cache_spec
.
block_size
)
for
g
in
kv_cache_config
.
kv_cache_groups
]
self
.
block_tables
=
[
BlockTable
(
max_num_reqs
,
max_num_blocks_per_req
[
i
],
max_num_batched_tokens
,
pin_memory
,
device
)
for
i
in
range
(
len
(
kv_cache_config
.
kv_cache_groups
))
]
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
append_row
(
block_ids
[
i
],
row_idx
)
def
add_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
add_row
(
block_ids
[
i
],
row_idx
)
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
move_row
(
src
,
tgt
)
def
swap_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
swap_row
(
src
,
tgt
)
def
commit
(
self
,
num_reqs
:
int
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
commit
(
num_reqs
)
def
clear
(
self
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
clear
()
def
__getitem__
(
self
,
idx
:
int
)
->
"BlockTable"
:
"""Returns the BlockTable for the i-th KV cache group."""
return
self
.
block_tables
[
idx
]
vllm/v1/worker/gpu_input_batch.py
View file @
e60f550b
...
...
@@ -11,10 +11,11 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
swap_dict_values
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
MultiGroup
BlockTable
_SAMPLING_EPS
=
1e-5
...
...
@@ -29,7 +30,7 @@ class CachedRequestState:
sampling_params
:
SamplingParams
generator
:
Optional
[
torch
.
Generator
]
block_ids
:
list
[
int
]
block_ids
:
list
[
list
[
int
]
]
num_computed_tokens
:
int
output_token_ids
:
list
[
int
]
...
...
@@ -58,15 +59,14 @@ class InputBatch:
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_blocks_per_req
:
int
,
max_num_batched_tokens
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
vocab_size
:
int
,
kv_cache_config
:
KVCacheConfig
,
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
device
=
device
self
.
pin_memory
=
pin_memory
...
...
@@ -99,12 +99,13 @@ class InputBatch:
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
# Block table.
self
.
block_table
=
BlockTable
(
self
.
block_table
=
MultiGroup
BlockTable
(
max_num_reqs
=
max_num_reqs
,
max_
num_blocks_per_req
=
max_num_blocks_per_req
,
max_
model_len
=
max_model_len
,
max_num_batched_tokens
=
max_num_batched_tokens
,
pin_memory
=
pin_memory
,
device
=
device
,
kv_cache_config
=
kv_cache_config
,
)
# Sampling-related.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
e60f550b
...
...
@@ -12,6 +12,8 @@ import torch.distributed
import
torch.nn
as
nn
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadataBuilder
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
...
...
@@ -31,8 +33,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
La
yerBlockType
,
LazyLoader
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
GiB_bytes
,
La
zyLoader
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
...
@@ -49,6 +51,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
is_spec_decode_supported
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
...
@@ -100,59 +103,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
# NOTE(woosuk): sliding_window is None for models with interleaved
# attention. Use interleaved_sliding_window instead.
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
interleaved_sliding_window
=
getattr
(
model_config
.
hf_text_config
,
"interleaved_sliding_window"
,
None
)
self
.
window_size
=
(
self
.
sliding_window
or
self
.
interleaved_sliding_window
)
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
block_size
=
cache_config
.
block_size
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_blocks_per_req
=
cdiv
(
self
.
max_model_len
,
self
.
block_size
)
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
# Model-related.
self
.
num_attn_layers
=
model_config
.
get_num_layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
self
.
num_query_heads
=
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
attention_chunk_size
=
model_config
.
attention_chunk_size
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
self
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
if
self
.
attn_backend
is
None
:
error_msg
=
(
f
"Error with get_att_backend:
{
self
.
head_size
=
}
, "
f
"
{
self
.
dtype
=
}
,
{
self
.
kv_cache_dtype
=
}
,
{
self
.
block_size
=
}
, "
f
"
{
self
.
model_config
.
is_attention_free
=
}
, "
f
"
{
self
.
model_config
.
use_mla
=
}
"
)
logger
.
error
(
error_msg
)
raise
NotImplementedError
(
"Non-Attention backend is not supported by V1 GPUModelRunner."
)
if
self
.
vllm_config
.
compilation_config
.
full_cuda_graph
:
attn_backend_name
=
self
.
attn_backend
.
__name__
flash_attn_version
=
get_flash_attn_version
()
if
attn_backend_name
!=
"FlashAttentionBackend"
or
\
flash_attn_version
!=
3
:
raise
ValueError
(
f
"full_cuda_graph is only supported with "
f
"FA3. Current attention backend is
{
attn_backend_name
}
, "
f
"FlashAttention version is
{
flash_attn_version
}
."
)
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
# Multi-modal data support
...
...
@@ -174,8 +135,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
self
.
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
]
=
[]
self
.
attn_backends
:
list
[
type
[
AttentionBackend
]]
=
[]
# self.kv_cache_config: KVCacheConfig
# self.
attn_metadata_builder: type[AttentionMetadataBuilder]
# self.
input_batch: InputBatch # Persistent batch.
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
...
...
@@ -200,16 +163,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# Persistent batch.
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
model_config
.
get_vocab_size
(),
)
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
...
...
@@ -304,6 +257,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory
=
self
.
pin_memory
)
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
Returns:
True if the batch was reordered, False otherwise.
"""
batch_reordered
=
self
.
attn_metadata_builders
[
0
].
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
# For models with multiple KV cache groups, the groups should agree on
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
for
i
in
range
(
1
,
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
assert
not
self
.
attn_metadata_builders
[
i
].
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
return
batch_reordered
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""Update the cached states and the persistent batch with the scheduler
output.
...
...
@@ -440,7 +418,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the block IDs.
if
not
req_data
.
resumed_from_preemption
:
# Append the new blocks to the existing block IDs.
req_state
.
block_ids
.
extend
(
req_data
.
new_block_ids
)
for
i
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
req_state
.
block_ids
[
i
].
extend
(
req_data
.
new_block_ids
[
i
])
else
:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
...
...
@@ -498,11 +477,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
removed_req_indices
:
self
.
input_batch
.
condense
(
removed_req_indices
)
# Some attention backends (namely MLA) may want to separate requests
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
batch_reordered
=
self
.
attn_metadata_builder
.
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
batch_reordered
=
self
.
_may_reorder_batch
(
scheduler_output
)
if
batch_changed
or
batch_reordered
:
self
.
input_batch
.
refresh_sampling_metadata
()
...
...
@@ -570,21 +545,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
torch
.
from_numpy
(
token_indices
),
out
=
self
.
input_ids_cpu
[:
total_num_scheduled_tokens
])
# Calculate the slot mapping.
# Calculate the slot mapping for each KV cache group.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
block_size
=
kv_cache_group_spec
.
kv_cache_spec
.
block_size
block_table
:
BlockTable
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
block_table_indices
=
(
req_indices
*
self
.
max_num_blocks_per_req
+
positions_np
//
self
.
block_size
)
block_table_cpu
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
()[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices
=
(
req_indices
*
block_table
.
max_num_blocks_per_req
+
positions_np
//
block_size
)
block_table_cpu
=
block_table
.
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
(
)[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
block_size
np
.
add
(
block_numbers
*
block_size
,
block_offsets
,
out
=
self
.
input_batch
.
block_table
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
out
=
block_table
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
...
...
@@ -626,10 +609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata
:
dict
[
str
,
FlashAttentionMetadata
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attetnion layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
...
...
@@ -638,15 +617,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
cascade_attn_enabled
:
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
,
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
self
.
attn_metadata_builders
[
kv_cache_group_id
],
)
attn_metadata_i
=
self
.
attn_metadata_builder
.
build
(
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
)
common_attn_metadata
=
common_attn_metadata
)
)
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
...
@@ -684,6 +667,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
num_scheduled_tokens
:
np
.
ndarray
,
num_common_prefix_blocks
:
int
,
kv_cache_spec
:
KVCacheSpec
,
attn_metadata_builder
:
AttentionMetadataBuilder
,
)
->
int
:
"""Compute the length of the common prefix for cascade attention.
...
...
@@ -702,7 +687,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len
=
num_common_prefix_blocks
*
self
.
block_size
common_prefix_len
=
num_common_prefix_blocks
*
kv_cache_spec
.
block_size
if
common_prefix_len
==
0
:
# Common case.
return
0
...
...
@@ -751,15 +736,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_prefix_len
,
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
].
min
())
# common_prefix_len should be a multiple of the block size.
common_prefix_len
=
(
common_prefix_len
//
self
.
block_size
*
self
.
block_size
)
use_cascade
=
self
.
attn_metadata_builder
.
use_cascade_attention
(
common_prefix_len
=
(
common_prefix_len
//
kv_cache_spec
.
block_size
*
kv_cache_spec
.
block_size
)
use_sliding_window
=
(
isinstance
(
kv_cache_spec
,
SlidingWindowSpec
)
or
(
isinstance
(
kv_cache_spec
,
FullAttentionSpec
)
and
kv_cache_spec
.
sliding_window
is
not
None
))
assert
isinstance
(
kv_cache_spec
,
AttentionSpec
)
use_cascade
=
attn_metadata_builder
.
use_cascade_attention
(
common_prefix_len
=
common_prefix_len
,
query_lens
=
num_scheduled_tokens
,
num_query_heads
=
self
.
num_query_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
use_alibi
=
self
.
use_alibi
,
use_sliding_window
=
se
lf
.
window_size
is
not
None
,
use_sliding_window
=
u
se
_sliding_window
,
num_sms
=
self
.
num_sms
,
)
return
common_prefix_len
if
use_cascade
else
0
...
...
@@ -1577,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype
=
np
.
int32
)
if
skip_attn
:
attn_metadata
=
None
attn_metadata
:
Optional
[
dict
[
str
,
FlashAttentionMetadata
]]
=
None
else
:
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
...
...
@@ -1585,13 +1574,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
=
{}
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
num_reqs
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
)
))
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
with
self
.
maybe_dummy_run_with_lora
(
self
.
lora_config
,
num_scheduled_tokens
):
...
...
@@ -1822,6 +1817,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger
.
info
(
"Graph capturing finished in %.0f secs, took %.2f GiB"
,
elapsed_time
,
cuda_graph_size
/
(
1
<<
30
))
def
initialize_attn_backend
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize the attention backends and attention metadata builders.
"""
assert
len
(
self
.
attn_backends
)
==
0
and
len
(
self
.
attn_metadata_builders
)
==
0
,
"Attention backends are already initialized"
for
i
,
kv_cache_group_spec
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
if
not
isinstance
(
kv_cache_spec
,
AttentionSpec
):
raise
NotImplementedError
(
"Only AttentionSpec is supported for now."
)
attn_backend_i
=
get_attn_backend
(
kv_cache_spec
.
head_size
,
self
.
dtype
,
kv_cache_spec
.
dtype
,
kv_cache_spec
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
kv_cache_spec
.
use_mla
,
)
if
attn_backend_i
is
None
:
error_msg
=
(
f
"Error with get_attn_backend:
{
kv_cache_spec
.
head_size
=
}
, "
f
"
{
self
.
dtype
=
}
,
{
kv_cache_spec
.
dtype
=
}
, "
f
"
{
kv_cache_spec
.
block_size
=
}
, "
f
"
{
self
.
model_config
.
is_attention_free
=
}
, "
f
"
{
kv_cache_spec
.
use_mla
=
}
"
)
logger
.
error
(
error_msg
)
raise
NotImplementedError
(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner."
)
if
self
.
vllm_config
.
compilation_config
.
full_cuda_graph
:
attn_backend_name
=
attn_backend_i
.
__name__
flash_attn_version
=
get_flash_attn_version
()
if
attn_backend_name
!=
"FlashAttentionBackend"
or
\
flash_attn_version
!=
3
:
raise
ValueError
(
f
"full_cuda_graph is only supported with "
f
"FA3. Current attention backend is "
f
"
{
attn_backend_name
}
, FlashAttention version is "
f
"
{
flash_attn_version
}
."
)
block_table_i
=
self
.
input_batch
.
block_table
[
i
]
attn_metadata_builder_i
=
attn_backend_i
.
get_builder_cls
()(
weakref
.
proxy
(
self
),
kv_cache_spec
,
block_table_i
)
self
.
attn_backends
.
append
(
attn_backend_i
)
self
.
attn_metadata_builders
.
append
(
attn_metadata_builder_i
)
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize KV cache based on `kv_cache_config`.
...
...
@@ -1829,15 +1874,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if
len
(
kv_cache_config
.
kv_cache_groups
)
>
1
:
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
self
.
kv_cache_config
=
kv_cache_config
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
)
self
.
initialize_attn_backend
(
kv_cache_config
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
for
i
,
kv_cache_group
in
enumerate
(
kv_cache_config
.
kv_cache_groups
)
:
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
for
layer_name
in
kv_cache_group
.
layer_names
:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
...
...
@@ -1852,7 +1903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the min of all `num_blocks`. Verify it here.
assert
num_blocks
>=
kv_cache_config
.
num_blocks
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
=
self
.
attn_backend
s
[
i
]
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
...
...
@@ -1872,11 +1923,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
),
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
,
self
.
input_batch
.
block_table
)
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
e60f550b
...
...
@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
# self.input_batch: InputBatch # Persistent batch.
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# Persistent batch.
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
vocab_size
,
)
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
...
...
@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
block_table_cpu
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
self
.
max_num_blocks_per_req
),
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
self
.
query_start_loc_cpu
=
torch
.
zeros
(
self
.
max_num_tokens
+
1
,
...
...
@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_table_cpu
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
()
block_table_cpu
=
self
.
input_batch
.
block_table
[
0
]
.
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
()[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
block_offsets
,
out
=
self
.
input_batch
.
block_table
.
out
=
self
.
input_batch
.
block_table
[
0
]
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
...
...
@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
position_ids
=
self
.
positions_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
)
self
.
input_batch
.
block_table
.
slot_mapping_cpu
[
self
.
input_batch
.
block_table
[
0
]
.
slot_mapping_cpu
[
total_num_scheduled_tokens
:]
=
_PAD_SLOT_ID
slot_mapping
=
(
self
.
input_batch
.
block_table
.
self
.
input_batch
.
block_table
[
0
]
.
slot_mapping_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
))
block_tables
=
self
.
block_table_cpu
[:
self
.
max_num_reqs
]
block_tables
[:
num_reqs
,
:
self
.
max_num_blocks_per_req
]
=
(
self
.
input_batch
.
block_table
.
get_cpu_tensor
()[:
num_reqs
])
self
.
input_batch
.
block_table
[
0
]
.
get_cpu_tensor
()[:
num_reqs
])
block_tables
=
block_tables
.
to
(
self
.
device
)
query_start_loc
=
self
.
query_start_loc_cpu
[:
self
.
max_num_reqs
+
1
].
to
(
self
.
device
)
...
...
@@ -1263,6 +1254,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
)
assert
self
.
block_table_cpu
.
dtype
==
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
().
dtype
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment