Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e60f550b
Unverified
Commit
e60f550b
authored
May 15, 2025
by
Chen Zhang
Committed by
GitHub
May 14, 2025
Browse files
[v1] Support multiple KV cache groups in GPU model runner (#17945)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
f25e0d11
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
482 additions
and
215 deletions
+482
-215
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+68
-3
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+18
-18
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+33
-6
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+41
-16
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+3
-3
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+2
-2
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+21
-13
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+6
-7
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+6
-6
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+10
-6
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+42
-0
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+47
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+7
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+158
-112
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+19
-16
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
e60f550b
...
@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
...
@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
hash_request_tokens
,
hash_request_tokens
,
unify_kv_cache_configs
)
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
KVCacheGroupSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
...
@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads
=
2
,
num_kv_heads
=
2
,
head_size
=
64
,
head_size
=
64
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
use_mla
=
False
):
use_mla
=
False
,
sliding_window
=
None
):
return
FullAttentionSpec
(
block_size
=
block_size
,
return
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
head_size
=
head_size
,
dtype
=
dtype
,
dtype
=
dtype
,
use_mla
=
use_mla
)
use_mla
=
use_mla
,
sliding_window
=
sliding_window
)
def
test_none_hash
():
def
test_none_hash
():
...
@@ -471,6 +474,68 @@ def test_unify_kv_cache_configs():
...
@@ -471,6 +474,68 @@ def test_unify_kv_cache_configs():
unify_kv_cache_configs
(
diff_kv_cache_config
)
unify_kv_cache_configs
(
diff_kv_cache_config
)
def test_merge_kv_cache_spec():
    """Exercise ``merge`` on lists of per-layer KV cache specs.

    Scenarios checked, in order:
      * identical specs merge into a spec carrying the shared field values;
      * specs with different ``num_kv_heads`` are expected to raise
        ``AssertionError``;
      * a full-attention spec mixed with a ``SlidingWindowSpec`` is expected
        to raise ``AssertionError`` regardless of which element drives the
        merge;
      * specs carrying two distinct sliding-window sizes are expected to
        raise ``ValueError``;
      * specs sharing one sliding-window size (optionally mixed with
        ``None``) merge into a spec with that window size.
    """
    # --- Identical specs: merged result keeps the shared values. ---
    identical = [new_kv_cache_spec(num_kv_heads=32) for _ in range(2)]
    merged = identical[0].merge(identical)
    assert merged.block_size == 16
    assert merged.num_kv_heads == 32
    assert merged.head_size == 64
    assert merged.dtype == torch.float32
    assert merged.sliding_window is None

    # --- Mismatching num_kv_heads. ---
    mismatched_heads = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=16),
    ]
    with pytest.raises(AssertionError):
        mismatched_heads[0].merge(mismatched_heads)

    # --- Mixed spec types; try the merge from both elements. ---
    full_spec = new_kv_cache_spec(num_kv_heads=32)
    sliding_spec = SlidingWindowSpec(
        block_size=full_spec.block_size,
        num_kv_heads=full_spec.num_kv_heads,
        head_size=full_spec.head_size,
        dtype=full_spec.dtype,
        use_mla=full_spec.use_mla,
        sliding_window=1,
    )
    mixed_types = [full_spec, sliding_spec]
    for driver in mixed_types:
        with pytest.raises(AssertionError):
            driver.merge(mixed_types)

    # --- Two different sliding-window sizes among the specs. ---
    conflicting_windows = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
    ]
    with pytest.raises(ValueError):
        conflicting_windows[0].merge(conflicting_windows)

    # --- Every spec uses the same sliding-window size. ---
    matching_windows = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
    ]
    merged = matching_windows[0].merge(matching_windows)
    assert merged.sliding_window == 1

    # --- One spec has a window, the other has None. ---
    window_and_none = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
    ]
    merged = window_and_none[0].merge(window_and_none)
    assert merged.sliding_window == 1
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"max_model_len"
,
"want_estimated_max_len"
),
[
(
"model_id"
,
"max_model_len"
,
"want_estimated_max_len"
),
[
(
"Qwen/Qwen1.5-7B"
,
16385
,
16384
),
(
"Qwen/Qwen1.5-7B"
,
16385
,
16384
),
...
...
tests/v1/core/test_prefix_caching.py
View file @
e60f550b
...
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
...
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
# Check full block metadata
# Check full block metadata
parent_block_hash
=
None
parent_block_hash
=
None
...
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
...
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
,
3
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
5
]
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
for
block
in
computed_blocks
.
blocks
:
for
block
in
computed_blocks
.
blocks
:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
...
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
...
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
,
3
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
6
]
assert
blocks
.
get_block_ids
()
==
[
[
6
]
]
# Although we only have 6 free blocks, we have 8 blocks in
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
# the free block queue due to lazy removal.
...
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
...
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
# This block ID order also checks the eviction order.
# This block ID order also checks the eviction order.
assert
blocks
.
get_block_ids
()
==
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
assert
blocks
.
get_block_ids
()
==
[
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
...
@@ -208,7 +208,7 @@ def test_prefill_plp():
...
@@ -208,7 +208,7 @@ def test_prefill_plp():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
# Check full block metadata
# Check full block metadata
...
@@ -233,13 +233,13 @@ def test_prefill_plp():
...
@@ -233,13 +233,13 @@ def test_prefill_plp():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
,
3
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
5
]
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
for
block
in
computed_blocks
.
blocks
:
for
block
in
computed_blocks
.
blocks
:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
...
@@ -277,11 +277,11 @@ def test_prefill_plp():
...
@@ -277,11 +277,11 @@ def test_prefill_plp():
block_ids
=
blocks
.
get_block_ids
()
block_ids
=
blocks
.
get_block_ids
()
# Duplicate cached blocks have different ids but same hashes vs request #0
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
==
req0_block_hashes
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
==
req0_block_hashes
assert
block_ids
!=
[
1
,
2
,
3
,
4
]
assert
block_ids
!=
[
[
1
,
2
,
3
,
4
]
]
# Request #2 block hashes are valid since request #0 hashes are.
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
# Check block reference counts.
for
block_id
in
block_ids
:
for
block_id
in
block_ids
[
0
]
:
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
manager
.
free
(
req2
)
manager
.
free
(
req2
)
...
@@ -307,7 +307,7 @@ def test_decode():
...
@@ -307,7 +307,7 @@ def test_decode():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
# Append slots without allocating a new block.
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
req0
.
num_computed_tokens
=
55
...
@@ -379,12 +379,12 @@ def test_evict():
...
@@ -379,12 +379,12 @@ def test_evict():
# Touch the first 2 blocks.
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
[
1
,
2
]
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
]
]
assert
num_computed_tokens
==
2
*
16
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
10
]
assert
blocks
.
get_block_ids
()
==
[
[
10
]
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
...
@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
...
@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
# Append slots without allocating a new block.
...
@@ -686,7 +686,7 @@ def test_cache_key_salting():
...
@@ -686,7 +686,7 @@ def test_cache_key_salting():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
# Append slots without allocating a new block.
...
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
...
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
[
1
,
2
,
3
,
4
]
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
unique_token_ids
=
[
4
]
*
7
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
...
@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
...
@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
5
]
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
# Failed to reset prefix cache because some blocks are not freed yet.
# Failed to reset prefix cache because some blocks are not freed yet.
assert
not
manager
.
reset_prefix_cache
()
assert
not
manager
.
reset_prefix_cache
()
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
e60f550b
...
@@ -9,9 +9,11 @@ import torch
...
@@ -9,9 +9,11 @@ import torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.
gpu_input_batch
import
(
BlockTable
,
CachedRequestState
,
from
vllm.v1.worker.
block_table
import
BlockTable
,
MultiGroupBlockTable
InputBatch
)
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
VOCAB_SIZE
=
1024
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
NUM_OUTPUT_TOKENS
=
20
...
@@ -22,6 +24,27 @@ CUDA_DEVICES = [
...
@@ -22,6 +24,27 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS
=
64
MAX_NUM_PROMPT_TOKENS
=
64
def get_kv_cache_config() -> KVCacheConfig:
    """Build the small single-group KV cache config used by these tests.

    Returns:
        A ``KVCacheConfig`` with 10 blocks, one 1024-sized tensor registered
        under ``"layer.0"``, and one KV cache group covering that layer with
        a float16 ``FullAttentionSpec`` (block size 1, one KV head, head
        size 16, MLA disabled).
    """
    layer_name = "layer.0"
    layer_tensors = {layer_name: KVCacheTensor(size=1024)}

    attention_spec = FullAttentionSpec(
        block_size=1,
        num_kv_heads=1,
        head_size=16,
        dtype=torch.float16,
        use_mla=False,
    )
    single_group = KVCacheGroupSpec(
        layer_names=[layer_name],
        kv_cache_spec=attention_spec,
    )

    return KVCacheConfig(
        num_blocks=10,
        tensors=layer_tensors,
        kv_cache_groups=[single_group],
    )
def
_compare_objs
(
obj1
,
obj2
):
def
_compare_objs
(
obj1
,
obj2
):
attrs
=
inspect
.
getmembers
(
obj1
,
lambda
a
:
not
(
inspect
.
isroutine
(
a
)))
attrs
=
inspect
.
getmembers
(
obj1
,
lambda
a
:
not
(
inspect
.
isroutine
(
a
)))
attr_names
=
set
([
attr_names
=
set
([
...
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
...
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
elif
isinstance
(
a
,
np
.
ndarray
):
elif
isinstance
(
a
,
np
.
ndarray
):
if
np
.
allclose
(
a
,
b
):
if
np
.
allclose
(
a
,
b
):
is_same
=
True
is_same
=
True
elif
isinstance
(
a
,
MultiGroupBlockTable
):
for
a_i
,
b_i
in
zip
(
a
.
block_tables
,
b
.
block_tables
):
_compare_objs
(
a_i
,
b_i
)
is_same
=
True
elif
isinstance
(
a
,
(
BlockTable
,
SamplingMetadata
)):
elif
isinstance
(
a
,
(
BlockTable
,
SamplingMetadata
)):
_compare_objs
(
a
,
b
)
_compare_objs
(
a
,
b
)
is_same
=
True
# if we make it here must be same
is_same
=
True
# if we make it here must be same
...
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
...
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params
=
_create_sampling_params
(),
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_inputs
=
[],
mm_positions
=
[],
mm_positions
=
[],
block_ids
=
[],
block_ids
=
[
[]
],
generator
=
None
,
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
...
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
:
InputBatch
=
InputBatch
(
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
vocab_size
=
1024
,
kv_cache_config
=
get_kv_cache_config
(),
)
)
reqs
:
list
[
CachedRequestState
]
=
[]
reqs
:
list
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
req_id_reqs
=
{}
...
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
...
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
input_batch
:
InputBatch
=
InputBatch
(
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
vocab_size
=
1024
,
kv_cache_config
=
get_kv_cache_config
(),
)
)
ref_input_batch
:
InputBatch
=
InputBatch
(
ref_input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
vocab_size
=
1024
,
kv_cache_config
=
get_kv_cache_config
(),
)
)
reqs
:
list
[
CachedRequestState
]
=
[]
reqs
:
list
[
CachedRequestState
]
=
[]
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
e60f550b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
weakref
import
pytest
import
pytest
import
torch
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
SchedulerOutput
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
...
@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
"""
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
"""
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
16
,
kv_cache_config
=
KVCacheConfig
(
num_kv_heads
=
1
,
num_blocks
=
10
,
head_size
=
64
,
tensors
=
{
dtype
=
torch
.
float16
,
"layer.0"
:
KVCacheTensor
(
size
=
1024
),
use_mla
=
False
)
},
runner
.
attn_metadata_builder
=
runner
.
attn_backend
.
get_builder_cls
()(
kv_cache_groups
=
[
weakref
.
proxy
(
runner
),
kv_cache_spec
,
runner
.
input_batch
.
block_table
)
KVCacheGroupSpec
(
layer_names
=
[
"layer.0"
],
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
16
,
num_kv_heads
=
runner
.
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
),
head_size
=
runner
.
model_config
.
get_head_size
(),
dtype
=
runner
.
kv_cache_dtype
,
use_mla
=
False
,
))
])
runner
.
kv_cache_config
=
kv_cache_config
runner
.
input_batch
=
InputBatch
(
max_num_reqs
=
runner
.
max_num_reqs
,
max_model_len
=
runner
.
max_model_len
,
max_num_batched_tokens
=
runner
.
max_num_tokens
,
device
=
runner
.
device
,
pin_memory
=
runner
.
pin_memory
,
vocab_size
=
runner
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
)
runner
.
initialize_attn_backend
(
kv_cache_config
)
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -48,10 +70,12 @@ def model_runner():
...
@@ -48,10 +70,12 @@ def model_runner():
swap_space
=
0
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
cache_dtype
=
"auto"
,
)
)
parallel_config
=
ParallelConfig
()
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
parallel_config
=
parallel_config
,
)
)
device
=
"cuda"
device
=
"cuda"
...
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
block_ids
=
[
0
],
block_ids
=
[
[
0
]
],
num_computed_tokens
=
0
,
num_computed_tokens
=
0
,
lora_request
=
None
,
lora_request
=
None
,
))
))
...
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
...
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
block_table
=
model_runner
.
input_batch
.
block_table
block_table
=
model_runner
.
input_batch
.
block_table
[
0
]
req_state
=
model_runner
.
requests
[
req_id
]
req_state
=
model_runner
.
requests
[
req_id
]
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
):
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
[
0
]):
return
False
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
return
(
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
==
return
(
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
==
req_state
.
block_ids
).
all
()
req_state
.
block_ids
[
0
]
).
all
()
def
test_update_states_new_request
(
model_runner
):
def
test_update_states_new_request
(
model_runner
):
...
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
...
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
req_id
=
req_id
,
resumed_from_preemption
=
False
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_token_ids
=
[],
new_block_ids
=
[],
new_block_ids
=
[
[]
],
num_computed_tokens
=
0
,
num_computed_tokens
=
0
,
)
)
...
...
tests/weight_loading/models.txt
View file @
e60f550b
...
@@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
...
@@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
#
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq, TheBloke/Llama-2-7B-GPTQ, main
gptq, TheBloke/Llama-2-7B-GPTQ, main
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
e60f550b
...
@@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
...
@@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
if
new_req
.
req_id
in
self
.
_requests_need_load
:
if
new_req
.
req_id
in
self
.
_requests_need_load
:
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
block_ids
=
new_req
.
block_ids
,
block_ids
=
new_req
.
block_ids
[
0
]
,
block_size
=
self
.
_block_size
,
block_size
=
self
.
_block_size
,
is_store
=
False
)
is_store
=
False
)
total_need_load
+=
1
total_need_load
+=
1
...
@@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
...
@@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# the original prompt tokens.
# the original prompt tokens.
if
not
self
.
_found_match_for_request
(
new_req
):
if
not
self
.
_found_match_for_request
(
new_req
):
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
block_ids
=
new_req
.
block_ids
,
block_ids
=
new_req
.
block_ids
[
0
]
,
block_size
=
self
.
_block_size
,
block_size
=
self
.
_block_size
,
is_store
=
True
)
is_store
=
True
)
...
@@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
...
@@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): For resumed req, new_block_ids is all
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
# of the block_ids for the request.
block_ids
=
cached_req
.
new_block_ids
block_ids
=
cached_req
.
new_block_ids
[
0
]
meta
.
add_request
(
token_ids
=
token_ids
,
meta
.
add_request
(
token_ids
=
token_ids
,
block_ids
=
block_ids
,
block_ids
=
block_ids
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
e60f550b
...
@@ -67,13 +67,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -67,13 +67,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_model_len
=
self
.
runner
.
model_config
.
max_model_len
max_model_len
=
self
.
runner
.
model_config
.
max_model_len
assert
max_model_len
==
32768
,
\
assert
max_model_len
==
32768
,
\
"AITER MLA requires max_model_len=32768"
"AITER MLA requires max_model_len=32768"
assert
self
.
runner
.
block_size
==
1
,
"AITER MLA"
\
assert
self
.
kv_cache_spec
.
block_size
==
1
,
"AITER MLA"
\
"only supports block size 1."
"only supports block size 1."
def
_get_paged_kv_tensors
(
def
_get_paged_kv_tensors
(
self
,
block_table
:
torch
.
Tensor
,
self
,
block_table
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
seq_lens
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
page_size
=
self
.
runner
.
block_size
page_size
=
self
.
kv_cache_spec
.
block_size
block_table_bounds
=
(
seq_lens
+
page_size
-
1
)
//
page_size
block_table_bounds
=
(
seq_lens
+
page_size
-
1
)
//
page_size
mask
=
(
torch
.
arange
(
block_table
.
size
(
1
),
mask
=
(
torch
.
arange
(
block_table
.
size
(
1
),
...
...
vllm/v1/core/kv_cache_manager.py
View file @
e60f550b
...
@@ -32,9 +32,16 @@ class KVCacheBlocks:
...
@@ -32,9 +32,16 @@ class KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
"""Creates a new KVCacheBlocks instance with no blocks."""
return
cls
([])
return
cls
([])
def
get_block_ids
(
self
)
->
list
[
int
]:
def
get_block_ids
(
self
)
->
list
[
list
[
int
]]:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
"""
return
[
block
.
block_id
for
block
in
self
.
blocks
]
Converts the KVCacheBlocks instance to block_ids.
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups (only 1 group now)
* each inner list contains the block_ids of the blocks in that group
"""
return
[[
block
.
block_id
for
block
in
self
.
blocks
]]
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
...
@@ -300,9 +307,9 @@ class KVCacheManager:
...
@@ -300,9 +307,9 @@ class KVCacheManager:
self
,
self
,
request
:
Request
,
request
:
Request
,
num_running_requests
:
int
,
num_running_requests
:
int
,
)
->
int
:
)
->
list
[
int
]
:
"""Calculate the number of common prefix blocks shared by all requests
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state.
in the RUNNING state
for each kv cache group
.
The function determines this by selecting any request and iterating
The function determines this by selecting any request and iterating
through its blocks. A block is considered a common prefix block if its
through its blocks. A block is considered a common prefix block if its
...
@@ -332,11 +339,14 @@ class KVCacheManager:
...
@@ -332,11 +339,14 @@ class KVCacheManager:
requests in the current step.
requests in the current step.
Returns:
Returns:
int: The number of common prefix blocks.
list[int]: The number of common prefix blocks for each kv cache
group.
"""
"""
assert
request
.
status
==
RequestStatus
.
RUNNING
assert
request
.
status
==
RequestStatus
.
RUNNING
return
self
.
single_type_manager
.
get_num_common_prefix_blocks
(
return
[
self
.
single_type_manager
.
get_num_common_prefix_blocks
(
request
.
request_id
,
num_running_requests
)
request
.
request_id
,
num_running_requests
)
]
def
free_block_hashes
(
self
,
request
:
Request
)
->
None
:
def
free_block_hashes
(
self
,
request
:
Request
)
->
None
:
"""Discard the block hashes for the request.
"""Discard the block hashes for the request.
...
@@ -354,10 +364,8 @@ class KVCacheManager:
...
@@ -354,10 +364,8 @@ class KVCacheManager:
"""
"""
return
self
.
block_pool
.
take_events
()
return
self
.
block_pool
.
take_events
()
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
int
]:
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
list
[
int
]
]
:
"""Get the block ids of a request."""
"""Get the block ids of a request."""
assert
request_id
in
self
.
single_type_manager
.
req_to_blocks
assert
request_id
in
self
.
single_type_manager
.
req_to_blocks
return
[
return
KVCacheBlocks
(
self
.
single_type_manager
.
req_to_blocks
[
request_id
]
block
.
block_id
).
get_block_ids
()
for
block
in
self
.
single_type_manager
.
req_to_blocks
[
request_id
]
]
vllm/v1/core/kv_cache_utils.py
View file @
e60f550b
...
@@ -577,14 +577,12 @@ def create_kv_cache_group_specs(
...
@@ -577,14 +577,12 @@ def create_kv_cache_group_specs(
"""
"""
kv_cache_groups
=
[]
kv_cache_groups
=
[]
for
layer_names_one_group
in
grouped_layer_names
:
for
layer_names_one_group
in
grouped_layer_names
:
layer_spec
=
kv_cache_spec
[
layer_names_one_group
[
0
]]
layer_specs
=
[
assert
all
(
kv_cache_spec
[
layer_name
]
for
layer_name
in
layer_names_one_group
kv_cache_spec
[
layer_name
]
==
layer_spec
]
for
layer_name
in
layer_names_one_group
[
1
:]),
(
merged_layer_spec
=
layer_specs
[
0
].
merge
(
layer_specs
)
"All layers in the same KV cache group must share the same "
"KVCacheSpec."
)
kv_cache_groups
.
append
(
kv_cache_groups
.
append
(
KVCacheGroupSpec
(
layer_names_one_group
,
layer_spec
))
KVCacheGroupSpec
(
layer_names_one_group
,
merged_
layer_spec
))
return
kv_cache_groups
return
kv_cache_groups
...
@@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
...
@@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
head_size
=
spec
.
head_size
,
head_size
=
spec
.
head_size
,
dtype
=
spec
.
dtype
,
dtype
=
spec
.
dtype
,
use_mla
=
spec
.
use_mla
,
use_mla
=
spec
.
use_mla
,
sliding_window
=
spec
.
sliding_window
,
)
)
...
...
vllm/v1/core/sched/output.py
View file @
e60f550b
...
@@ -26,7 +26,7 @@ class NewRequestData:
...
@@ -26,7 +26,7 @@ class NewRequestData:
mm_hashes
:
list
[
str
]
mm_hashes
:
list
[
str
]
mm_positions
:
list
[
PlaceholderRange
]
mm_positions
:
list
[
PlaceholderRange
]
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
block_ids
:
list
[
int
]
block_ids
:
list
[
list
[
int
]
]
num_computed_tokens
:
int
num_computed_tokens
:
int
lora_request
:
Optional
[
LoRARequest
]
lora_request
:
Optional
[
LoRARequest
]
...
@@ -34,7 +34,7 @@ class NewRequestData:
...
@@ -34,7 +34,7 @@ class NewRequestData:
def
from_request
(
def
from_request
(
cls
,
cls
,
request
:
Request
,
request
:
Request
,
block_ids
:
list
[
int
],
block_ids
:
list
[
list
[
int
]
]
,
)
->
NewRequestData
:
)
->
NewRequestData
:
return
cls
(
return
cls
(
req_id
=
request
.
request_id
,
req_id
=
request
.
request_id
,
...
@@ -85,7 +85,7 @@ class CachedRequestData:
...
@@ -85,7 +85,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
resumed_from_preemption
:
bool
new_token_ids
:
list
[
int
]
new_token_ids
:
list
[
int
]
new_block_ids
:
list
[
int
]
new_block_ids
:
list
[
list
[
int
]
]
num_computed_tokens
:
int
num_computed_tokens
:
int
@
classmethod
@
classmethod
...
@@ -94,7 +94,7 @@ class CachedRequestData:
...
@@ -94,7 +94,7 @@ class CachedRequestData:
request
:
Request
,
request
:
Request
,
resumed_from_preemption
:
bool
,
resumed_from_preemption
:
bool
,
new_token_ids
:
list
[
int
],
new_token_ids
:
list
[
int
],
new_block_ids
:
list
[
int
],
new_block_ids
:
list
[
list
[
int
]
]
,
)
->
CachedRequestData
:
)
->
CachedRequestData
:
return
cls
(
return
cls
(
req_id
=
request
.
request_id
,
req_id
=
request
.
request_id
,
...
@@ -131,9 +131,9 @@ class SchedulerOutput:
...
@@ -131,9 +131,9 @@ class SchedulerOutput:
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
# to process that the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]]
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]]
# Number of common prefix blocks for all requests.
# Number of common prefix blocks for all requests
in each KV cache group
.
# This can be used for cascade attention.
# This can be used for cascade attention.
num_common_prefix_blocks
:
int
num_common_prefix_blocks
:
list
[
int
]
# Request IDs that are finished in between the previous and the current
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# steps. This is used to notify the workers about the finished requests
...
...
vllm/v1/core/sched/scheduler.py
View file @
e60f550b
...
@@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface):
...
@@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
# uses structured decoding.
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
req_to_new_block_ids
:
dict
[
str
,
list
[
int
]]
=
{}
req_to_new_block_ids
:
dict
[
str
,
list
[
list
[
int
]]
]
=
{}
num_scheduled_tokens
:
dict
[
str
,
int
]
=
{}
num_scheduled_tokens
:
dict
[
str
,
int
]
=
{}
token_budget
=
self
.
max_num_scheduled_tokens
token_budget
=
self
.
max_num_scheduled_tokens
# Encoder-related.
# Encoder-related.
...
@@ -477,7 +477,8 @@ class Scheduler(SchedulerInterface):
...
@@ -477,7 +477,8 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
# This can be potentially used for cascade attention.
num_common_prefix_blocks
=
0
num_common_prefix_blocks
=
[
0
]
*
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
if
self
.
running
:
if
self
.
running
:
any_request
=
self
.
running
[
0
]
any_request
=
self
.
running
[
0
]
num_common_prefix_blocks
=
(
num_common_prefix_blocks
=
(
...
@@ -564,7 +565,7 @@ class Scheduler(SchedulerInterface):
...
@@ -564,7 +565,7 @@ class Scheduler(SchedulerInterface):
request
:
Request
,
request
:
Request
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
new_block_ids
:
list
[
int
],
new_block_ids
:
list
[
list
[
int
]
]
,
resumed_from_preemption
:
bool
,
resumed_from_preemption
:
bool
,
)
->
CachedRequestData
:
)
->
CachedRequestData
:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
...
@@ -939,7 +940,9 @@ class Scheduler(SchedulerInterface):
...
@@ -939,7 +940,9 @@ class Scheduler(SchedulerInterface):
"""
"""
if
self
.
connector
is
None
:
if
self
.
connector
is
None
:
return
False
,
None
return
False
,
None
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
"KV connector only supports one KV cache group now"
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)[
0
]
return
self
.
connector
.
request_finished
(
request
,
block_ids
)
return
self
.
connector
.
request_finished
(
request
,
block_ids
)
def
_update_waiting_for_remote_kv
(
self
,
request
:
Request
)
->
bool
:
def
_update_waiting_for_remote_kv
(
self
,
request
:
Request
)
->
bool
:
...
@@ -956,9 +959,10 @@ class Scheduler(SchedulerInterface):
...
@@ -956,9 +959,10 @@ class Scheduler(SchedulerInterface):
"""
"""
if
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
if
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
return
False
return
False
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
"KV connector only supports one KV cache group now"
# Now that the blocks are ready, actually cache them.
# Now that the blocks are ready, actually cache them.
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
[
0
]
num_computed_tokens
=
len
(
block_ids
)
*
self
.
block_size
num_computed_tokens
=
len
(
block_ids
)
*
self
.
block_size
if
num_computed_tokens
==
request
.
num_tokens
:
if
num_computed_tokens
==
request
.
num_tokens
:
num_computed_tokens
-=
1
num_computed_tokens
-=
1
...
...
vllm/v1/kv_cache_interface.py
View file @
e60f550b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
copy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
import
torch
from
typing_extensions
import
Self
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -53,6 +56,16 @@ class KVCacheSpec:
...
@@ -53,6 +56,16 @@ class KVCacheSpec:
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
classmethod
def
merge
(
cls
,
specs
:
list
[
Self
])
->
Self
:
"""
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
"""
assert
all
(
spec
.
type_id
==
specs
[
0
].
type_id
for
spec
in
specs
[
1
:]),
(
"All layers in the same KV cache group must share the same "
"type_id."
)
return
copy
.
deepcopy
(
specs
[
0
])
@
dataclass
@
dataclass
class
AttentionSpec
(
KVCacheSpec
):
class
AttentionSpec
(
KVCacheSpec
):
...
@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):
...
@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):
@
dataclass
@
dataclass
class
FullAttentionSpec
(
AttentionSpec
):
class
FullAttentionSpec
(
AttentionSpec
):
sliding_window
:
Optional
[
int
]
=
None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
window attention are regarded as full attention in KV cache manager
(blocks are allocated for all tokens), while computed as sliding window
attention in model runner.
In this case, we use FullAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""
@
property
@
property
def
type_id
(
self
)
->
str
:
def
type_id
(
self
)
->
str
:
...
@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
...
@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_model_len
=
vllm_config
.
model_config
.
max_model_len
return
cdiv
(
max_model_len
,
self
.
block_size
)
*
self
.
page_size_bytes
return
cdiv
(
max_model_len
,
self
.
block_size
)
*
self
.
page_size_bytes
@
classmethod
def
merge
(
cls
,
specs
:
list
[
Self
])
->
Self
:
"""
Merge a list of FullAttentionSpec objects into a single
FullAttentionSpec object.
"""
merged_spec
=
super
().
merge
(
specs
)
sliding_window
=
set
(
spec
.
sliding_window
for
spec
in
specs
if
spec
.
sliding_window
is
not
None
)
if
len
(
sliding_window
)
==
0
:
merged_spec
.
sliding_window
=
None
elif
len
(
sliding_window
)
==
1
:
merged_spec
.
sliding_window
=
sliding_window
.
pop
()
else
:
raise
ValueError
(
"All sliding window layers in the same KV cache group "
"must have the same window size."
)
return
merged_spec
@
dataclass
@
dataclass
class
SlidingWindowSpec
(
AttentionSpec
):
class
SlidingWindowSpec
(
AttentionSpec
):
...
...
vllm/v1/worker/block_table.py
View file @
e60f550b
...
@@ -4,6 +4,8 @@ import numpy as np
...
@@ -4,6 +4,8 @@ import numpy as np
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -96,3 +98,48 @@ class BlockTable:
...
@@ -96,3 +98,48 @@ class BlockTable:
def
get_numpy_array
(
self
)
->
np
.
ndarray
:
def
get_numpy_array
(
self
)
->
np
.
ndarray
:
"""Returns the numpy array of the block table."""
"""Returns the numpy array of the block table."""
return
self
.
block_table_np
return
self
.
block_table_np
class
MultiGroupBlockTable
:
"""The BlockTables for each KV cache group."""
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_batched_tokens
:
int
,
pin_memory
:
bool
,
device
:
torch
.
device
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
max_num_blocks_per_req
=
[
cdiv
(
max_model_len
,
g
.
kv_cache_spec
.
block_size
)
for
g
in
kv_cache_config
.
kv_cache_groups
]
self
.
block_tables
=
[
BlockTable
(
max_num_reqs
,
max_num_blocks_per_req
[
i
],
max_num_batched_tokens
,
pin_memory
,
device
)
for
i
in
range
(
len
(
kv_cache_config
.
kv_cache_groups
))
]
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
append_row
(
block_ids
[
i
],
row_idx
)
def
add_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
add_row
(
block_ids
[
i
],
row_idx
)
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
move_row
(
src
,
tgt
)
def
swap_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
swap_row
(
src
,
tgt
)
def
commit
(
self
,
num_reqs
:
int
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
commit
(
num_reqs
)
def
clear
(
self
)
->
None
:
for
block_table
in
self
.
block_tables
:
block_table
.
clear
()
def
__getitem__
(
self
,
idx
:
int
)
->
"BlockTable"
:
"""Returns the BlockTable for the i-th KV cache group."""
return
self
.
block_tables
[
idx
]
vllm/v1/worker/gpu_input_batch.py
View file @
e60f550b
...
@@ -11,10 +11,11 @@ from vllm.lora.request import LoRARequest
...
@@ -11,10 +11,11 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
swap_dict_values
from
vllm.utils
import
swap_dict_values
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
MultiGroup
BlockTable
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -29,7 +30,7 @@ class CachedRequestState:
...
@@ -29,7 +30,7 @@ class CachedRequestState:
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
generator
:
Optional
[
torch
.
Generator
]
generator
:
Optional
[
torch
.
Generator
]
block_ids
:
list
[
int
]
block_ids
:
list
[
list
[
int
]
]
num_computed_tokens
:
int
num_computed_tokens
:
int
output_token_ids
:
list
[
int
]
output_token_ids
:
list
[
int
]
...
@@ -58,15 +59,14 @@ class InputBatch:
...
@@ -58,15 +59,14 @@ class InputBatch:
self
,
self
,
max_num_reqs
:
int
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_model_len
:
int
,
max_num_blocks_per_req
:
int
,
max_num_batched_tokens
:
int
,
max_num_batched_tokens
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
pin_memory
:
bool
,
pin_memory
:
bool
,
vocab_size
:
int
,
vocab_size
:
int
,
kv_cache_config
:
KVCacheConfig
,
):
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
device
=
device
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
...
@@ -99,12 +99,13 @@ class InputBatch:
...
@@ -99,12 +99,13 @@ class InputBatch:
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
# Block table.
# Block table.
self
.
block_table
=
BlockTable
(
self
.
block_table
=
MultiGroup
BlockTable
(
max_num_reqs
=
max_num_reqs
,
max_num_reqs
=
max_num_reqs
,
max_
num_blocks_per_req
=
max_num_blocks_per_req
,
max_
model_len
=
max_model_len
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
device
=
device
,
device
=
device
,
kv_cache_config
=
kv_cache_config
,
)
)
# Sampling-related.
# Sampling-related.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
e60f550b
...
@@ -12,6 +12,8 @@ import torch.distributed
...
@@ -12,6 +12,8 @@ import torch.distributed
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadataBuilder
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
...
@@ -31,8 +33,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
...
@@ -31,8 +33,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
La
yerBlockType
,
LazyLoader
,
cdiv
,
GiB_bytes
,
La
zyLoader
,
cdiv
,
check_use_alibi
,
check_use_alibi
,
is_pin_memory_available
)
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
@@ -49,6 +51,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...
@@ -49,6 +51,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
is_spec_decode_supported
from
vllm.v1.spec_decode.utils
import
is_spec_decode_supported
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
@@ -100,59 +103,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -100,59 +103,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
cache_config
.
cache_dtype
]
# NOTE(woosuk): sliding_window is None for models with interleaved
# attention. Use interleaved_sliding_window instead.
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
interleaved_sliding_window
=
getattr
(
model_config
.
hf_text_config
,
"interleaved_sliding_window"
,
None
)
self
.
window_size
=
(
self
.
sliding_window
or
self
.
interleaved_sliding_window
)
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
block_size
=
cache_config
.
block_size
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_blocks_per_req
=
cdiv
(
self
.
max_model_len
,
self
.
block_size
)
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
# Model-related.
# Model-related.
self
.
num_attn_layers
=
model_config
.
get_num_layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
self
.
num_query_heads
=
model_config
.
get_num_attention_heads
(
self
.
num_query_heads
=
model_config
.
get_num_attention_heads
(
parallel_config
)
parallel_config
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
attention_chunk_size
=
model_config
.
attention_chunk_size
self
.
attention_chunk_size
=
model_config
.
attention_chunk_size
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
self
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
if
self
.
attn_backend
is
None
:
error_msg
=
(
f
"Error with get_att_backend:
{
self
.
head_size
=
}
, "
f
"
{
self
.
dtype
=
}
,
{
self
.
kv_cache_dtype
=
}
,
{
self
.
block_size
=
}
, "
f
"
{
self
.
model_config
.
is_attention_free
=
}
, "
f
"
{
self
.
model_config
.
use_mla
=
}
"
)
logger
.
error
(
error_msg
)
raise
NotImplementedError
(
"Non-Attention backend is not supported by V1 GPUModelRunner."
)
if
self
.
vllm_config
.
compilation_config
.
full_cuda_graph
:
attn_backend_name
=
self
.
attn_backend
.
__name__
flash_attn_version
=
get_flash_attn_version
()
if
attn_backend_name
!=
"FlashAttentionBackend"
or
\
flash_attn_version
!=
3
:
raise
ValueError
(
f
"full_cuda_graph is only supported with "
f
"FA3. Current attention backend is
{
attn_backend_name
}
, "
f
"FlashAttention version is
{
flash_attn_version
}
."
)
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
# Multi-modal data support
# Multi-modal data support
...
@@ -174,8 +135,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -174,8 +135,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# self.model: nn.Module # Set after load_model
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
# Initialize in initialize_kv_cache
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
self
.
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
]
=
[]
self
.
attn_backends
:
list
[
type
[
AttentionBackend
]]
=
[]
# self.kv_cache_config: KVCacheConfig
# self.kv_cache_config: KVCacheConfig
# self.
attn_metadata_builder: type[AttentionMetadataBuilder]
# self.
input_batch: InputBatch # Persistent batch.
# req_id -> (input_id -> encoder_output)
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
...
@@ -200,16 +163,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -200,16 +163,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# Persistent batch.
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
model_config
.
get_vocab_size
(),
)
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
==
CompilationLevel
.
PIECEWISE
...
@@ -304,6 +257,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -304,6 +257,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
)
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
Returns:
True if the batch was reordered, False otherwise.
"""
batch_reordered
=
self
.
attn_metadata_builders
[
0
].
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
# For models with multiple KV cache groups, the groups should agree on
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
for
i
in
range
(
1
,
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
assert
not
self
.
attn_metadata_builders
[
i
].
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
return
batch_reordered
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_update_states
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""Update the cached states and the persistent batch with the scheduler
"""Update the cached states and the persistent batch with the scheduler
output.
output.
...
@@ -440,7 +418,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -440,7 +418,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the block IDs.
# Update the block IDs.
if
not
req_data
.
resumed_from_preemption
:
if
not
req_data
.
resumed_from_preemption
:
# Append the new blocks to the existing block IDs.
# Append the new blocks to the existing block IDs.
req_state
.
block_ids
.
extend
(
req_data
.
new_block_ids
)
for
i
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
req_state
.
block_ids
[
i
].
extend
(
req_data
.
new_block_ids
[
i
])
else
:
else
:
# The request is resumed from preemption.
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
# Replace the existing block IDs with the new ones.
...
@@ -498,11 +477,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -498,11 +477,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
removed_req_indices
:
if
removed_req_indices
:
self
.
input_batch
.
condense
(
removed_req_indices
)
self
.
input_batch
.
condense
(
removed_req_indices
)
# Some attention backends (namely MLA) may want to separate requests
batch_reordered
=
self
.
_may_reorder_batch
(
scheduler_output
)
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
batch_reordered
=
self
.
attn_metadata_builder
.
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
if
batch_changed
or
batch_reordered
:
if
batch_changed
or
batch_reordered
:
self
.
input_batch
.
refresh_sampling_metadata
()
self
.
input_batch
.
refresh_sampling_metadata
()
...
@@ -570,21 +545,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -570,21 +545,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
torch
.
from_numpy
(
token_indices
),
torch
.
from_numpy
(
token_indices
),
out
=
self
.
input_ids_cpu
[:
total_num_scheduled_tokens
])
out
=
self
.
input_ids_cpu
[:
total_num_scheduled_tokens
])
# Calculate the slot mapping.
# Calculate the slot mapping for each KV cache group.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
block_size
=
kv_cache_group_spec
.
kv_cache_spec
.
block_size
block_table
:
BlockTable
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# because M (max_model_len) is not necessarily divisible by block_size.
# here because M (max_model_len) is not necessarily divisible by
block_table_indices
=
(
req_indices
*
self
.
max_num_blocks_per_req
+
# block_size.
positions_np
//
self
.
block_size
)
block_table_indices
=
(
block_table_cpu
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
()
req_indices
*
block_table
.
max_num_blocks_per_req
+
block_numbers
=
block_table_cpu
.
flatten
()[
block_table_indices
].
numpy
()
positions_np
//
block_size
)
block_offsets
=
positions_np
%
self
.
block_size
block_table_cpu
=
block_table
.
get_cpu_tensor
()
np
.
add
(
block_numbers
*
self
.
block_size
,
block_numbers
=
block_table_cpu
.
flatten
(
)[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
block_size
np
.
add
(
block_numbers
*
block_size
,
block_offsets
,
block_offsets
,
out
=
self
.
input_batch
.
block_table
.
out
=
block_table
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
self
.
query_start_loc_np
[
0
]
=
0
...
@@ -626,10 +609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -626,10 +609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata
:
dict
[
str
,
FlashAttentionMetadata
]
=
{}
attn_metadata
:
dict
[
str
,
FlashAttentionMetadata
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attention layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
self
.
kv_cache_config
.
kv_cache_groups
):
...
@@ -638,15 +617,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -638,15 +617,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
cascade_attn_enabled
:
if
self
.
cascade_attn_enabled
:
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
,
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
self
.
attn_metadata_builders
[
kv_cache_group_id
],
)
)
attn_metadata_i
=
self
.
attn_metadata_builder
.
build
(
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
)
common_attn_metadata
=
common_attn_metadata
)
)
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
@@ -684,6 +667,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -684,6 +667,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
self
,
num_scheduled_tokens
:
np
.
ndarray
,
num_scheduled_tokens
:
np
.
ndarray
,
num_common_prefix_blocks
:
int
,
num_common_prefix_blocks
:
int
,
kv_cache_spec
:
KVCacheSpec
,
attn_metadata_builder
:
AttentionMetadataBuilder
,
)
->
int
:
)
->
int
:
"""Compute the length of the common prefix for cascade attention.
"""Compute the length of the common prefix for cascade attention.
...
@@ -702,7 +687,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -702,7 +687,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Returns:
Returns:
int: Length of common prefix in tokens.
int: Length of common prefix in tokens.
"""
"""
common_prefix_len
=
num_common_prefix_blocks
*
self
.
block_size
common_prefix_len
=
num_common_prefix_blocks
*
kv_cache_spec
.
block_size
if
common_prefix_len
==
0
:
if
common_prefix_len
==
0
:
# Common case.
# Common case.
return
0
return
0
...
@@ -751,15 +736,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -751,15 +736,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_prefix_len
,
common_prefix_len
,
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
].
min
())
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
].
min
())
# common_prefix_len should be a multiple of the block size.
# common_prefix_len should be a multiple of the block size.
common_prefix_len
=
(
common_prefix_len
//
self
.
block_size
*
common_prefix_len
=
(
common_prefix_len
//
kv_cache_spec
.
block_size
*
self
.
block_size
)
kv_cache_spec
.
block_size
)
use_cascade
=
self
.
attn_metadata_builder
.
use_cascade_attention
(
use_sliding_window
=
(
isinstance
(
kv_cache_spec
,
SlidingWindowSpec
)
or
(
isinstance
(
kv_cache_spec
,
FullAttentionSpec
)
and
kv_cache_spec
.
sliding_window
is
not
None
))
assert
isinstance
(
kv_cache_spec
,
AttentionSpec
)
use_cascade
=
attn_metadata_builder
.
use_cascade_attention
(
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
query_lens
=
num_scheduled_tokens
,
query_lens
=
num_scheduled_tokens
,
num_query_heads
=
self
.
num_query_heads
,
num_query_heads
=
self
.
num_query_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
use_alibi
=
self
.
use_alibi
,
use_alibi
=
self
.
use_alibi
,
use_sliding_window
=
se
lf
.
window_size
is
not
None
,
use_sliding_window
=
u
se
_sliding_window
,
num_sms
=
self
.
num_sms
,
num_sms
=
self
.
num_sms
,
)
)
return
common_prefix_len
if
use_cascade
else
0
return
common_prefix_len
if
use_cascade
else
0
...
@@ -1577,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1577,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
if
skip_attn
:
if
skip_attn
:
attn_metadata
=
None
attn_metadata
:
Optional
[
dict
[
str
,
FlashAttentionMetadata
]]
=
None
else
:
else
:
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
...
@@ -1585,13 +1574,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1585,13 +1574,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
=
{}
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
attn_metadata_i
=
(
self
.
attn_metadata_builders
[
kv_cache_group_id
].
build
(
num_reqs
=
num_tokens
,
num_reqs
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
max_query_len
=
num_tokens
,
common_prefix_len
=
0
,
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
)
))
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
with
self
.
maybe_dummy_run_with_lora
(
self
.
lora_config
,
with
self
.
maybe_dummy_run_with_lora
(
self
.
lora_config
,
num_scheduled_tokens
):
num_scheduled_tokens
):
...
@@ -1822,6 +1817,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1822,6 +1817,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger
.
info
(
"Graph capturing finished in %.0f secs, took %.2f GiB"
,
logger
.
info
(
"Graph capturing finished in %.0f secs, took %.2f GiB"
,
elapsed_time
,
cuda_graph_size
/
(
1
<<
30
))
elapsed_time
,
cuda_graph_size
/
(
1
<<
30
))
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
    """
    Create one attention backend and one attention metadata builder for
    every KV cache group.

    The results are appended, in group order, to `self.attn_backends` and
    `self.attn_metadata_builders`, so index `i` in either list corresponds
    to `kv_cache_config.kv_cache_groups[i]`.

    Args:
        kv_cache_config: KV cache configuration whose `kv_cache_groups`
            are iterated.

    Raises:
        NotImplementedError: If a group's spec is not an `AttentionSpec`,
            or if `get_attn_backend` returns None for that spec.
        ValueError: If `full_cuda_graph` is enabled and the selected
            backend is not "FlashAttentionBackend" with FlashAttention
            version 3.
    """
    # Both lists must still be empty; this method only ever appends.
    already_initialized = (len(self.attn_backends) != 0
                           or len(self.attn_metadata_builders) != 0)
    assert not already_initialized, \
        "Attention backends are already initialized"

    for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
        # NOTE: keep this local named `kv_cache_spec`; the `{...=}`
        # f-string specifiers below embed the expression text verbatim
        # in the logged message.
        kv_cache_spec = group.kv_cache_spec
        if not isinstance(kv_cache_spec, AttentionSpec):
            raise NotImplementedError(
                "Only AttentionSpec is supported for now.")

        backend = get_attn_backend(
            kv_cache_spec.head_size,
            self.dtype,
            kv_cache_spec.dtype,
            kv_cache_spec.block_size,
            self.model_config.is_attention_free,
            use_mla=kv_cache_spec.use_mla,
        )
        if backend is None:
            # Log every argument passed to the lookup before failing.
            error_msg = (
                f"Error with get_attn_backend: {kv_cache_spec.head_size=}, "
                f"{self.dtype=}, {kv_cache_spec.dtype=}, "
                f"{kv_cache_spec.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{kv_cache_spec.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 "
                "GPUModelRunner.")

        if self.vllm_config.compilation_config.full_cuda_graph:
            backend_name = backend.__name__
            fa_version = get_flash_attn_version()
            is_fa3 = (backend_name == "FlashAttentionBackend"
                      and fa_version == 3)
            if not is_fa3:
                raise ValueError(
                    f"full_cuda_graph is only supported with "
                    f"FA3. Current attention backend is "
                    f"{backend_name}, FlashAttention version is "
                    f"{fa_version}.")

        # Each group gets the block table stored at its own index.
        group_block_table = self.input_batch.block_table[group_id]
        builder_cls = backend.get_builder_cls()
        builder = builder_cls(weakref.proxy(self), kv_cache_spec,
                              group_block_table)

        self.attn_backends.append(backend)
        self.attn_metadata_builders.append(builder)
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
"""
Initialize KV cache based on `kv_cache_config`.
Initialize KV cache based on `kv_cache_config`.
...
@@ -1829,15 +1874,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1829,15 +1874,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
cache size of each layer
"""
"""
if
len
(
kv_cache_config
.
kv_cache_groups
)
>
1
:
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_cache_config
=
kv_cache_config
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
)
self
.
initialize_attn_backend
(
kv_cache_config
)
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
for
i
,
kv_cache_group
in
enumerate
(
kv_cache_config
.
kv_cache_groups
)
:
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
for
layer_name
in
kv_cache_group
.
layer_names
:
for
layer_name
in
kv_cache_group
.
layer_names
:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
...
@@ -1852,7 +1903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1852,7 +1903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the min of all `num_blocks`. Verify it here.
# the min of all `num_blocks`. Verify it here.
assert
num_blocks
>=
kv_cache_config
.
num_blocks
assert
num_blocks
>=
kv_cache_config
.
num_blocks
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
=
self
.
attn_backend
s
[
i
]
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
dtype
=
kv_cache_spec
.
dtype
...
@@ -1872,11 +1923,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1872,11 +1923,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
),
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
,
self
.
input_batch
.
block_table
)
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
"""
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Generates the KVCacheSpec by parsing the kv cache format from each
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
e60f550b
...
@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
# req_id -> (input_id -> encoder_output)
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
# self.input_batch: InputBatch # Persistent batch.
# Request states.
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# Persistent batch.
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
vocab_size
,
)
# Cached torch/numpy tensor
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# The pytorch tensor and numpy array share the same buffer.
...
@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
block_table_cpu
=
torch
.
zeros
(
self
.
block_table_cpu
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
self
.
max_num_blocks_per_req
),
(
self
.
max_num_reqs
,
self
.
max_num_blocks_per_req
),
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
device
=
"cpu"
)
self
.
query_start_loc_cpu
=
torch
.
zeros
(
self
.
max_num_tokens
+
1
,
self
.
query_start_loc_cpu
=
torch
.
zeros
(
self
.
max_num_tokens
+
1
,
...
@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): We use torch.index_select instead of np.take here
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# because torch.index_select is much faster than np.take for large
# tensors.
# tensors.
block_table_cpu
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
()
block_table_cpu
=
self
.
input_batch
.
block_table
[
0
]
.
get_cpu_tensor
()
block_numbers
=
block_table_cpu
.
flatten
()[
block_table_indices
].
numpy
()
block_numbers
=
block_table_cpu
.
flatten
()[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
self
.
block_size
block_offsets
=
positions_np
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
np
.
add
(
block_numbers
*
self
.
block_size
,
block_offsets
,
block_offsets
,
out
=
self
.
input_batch
.
block_table
.
out
=
self
.
input_batch
.
block_table
[
0
]
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
# Prepare the attention metadata.
...
@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
position_ids
=
self
.
positions_cpu
[:
self
.
position_ids
=
self
.
positions_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
padded_total_num_scheduled_tokens
].
to
(
self
.
device
)
self
.
device
)
self
.
input_batch
.
block_table
.
slot_mapping_cpu
[
self
.
input_batch
.
block_table
[
0
]
.
slot_mapping_cpu
[
total_num_scheduled_tokens
:]
=
_PAD_SLOT_ID
total_num_scheduled_tokens
:]
=
_PAD_SLOT_ID
slot_mapping
=
(
slot_mapping
=
(
self
.
input_batch
.
block_table
.
self
.
input_batch
.
block_table
[
0
]
.
slot_mapping_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
slot_mapping_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
))
self
.
device
))
block_tables
=
self
.
block_table_cpu
[:
self
.
max_num_reqs
]
block_tables
=
self
.
block_table_cpu
[:
self
.
max_num_reqs
]
block_tables
[:
num_reqs
,
:
self
.
max_num_blocks_per_req
]
=
(
block_tables
[:
num_reqs
,
:
self
.
max_num_blocks_per_req
]
=
(
self
.
input_batch
.
block_table
.
get_cpu_tensor
()[:
num_reqs
])
self
.
input_batch
.
block_table
[
0
]
.
get_cpu_tensor
()[:
num_reqs
])
block_tables
=
block_tables
.
to
(
self
.
device
)
block_tables
=
block_tables
.
to
(
self
.
device
)
query_start_loc
=
self
.
query_start_loc_cpu
[:
self
.
max_num_reqs
+
1
].
to
(
query_start_loc
=
self
.
query_start_loc_cpu
[:
self
.
max_num_reqs
+
1
].
to
(
self
.
device
)
self
.
device
)
...
@@ -1263,6 +1254,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1263,6 +1254,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
"supported yet."
)
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
kv_cache_config
=
kv_cache_config
,
)
assert
self
.
block_table_cpu
.
dtype
==
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
().
dtype
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment