Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7cc302dd
Unverified
Commit
7cc302dd
authored
Mar 27, 2026
by
Or Ozeri
Committed by
GitHub
Mar 27, 2026
Browse files
[kv_offload+HMA][7/N]: Support register_kv_caches for hybrid models (#37853)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
999dfc16
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1477 additions
and
759 deletions
+1477
-759
tests/v1/kv_connector/unit/offloading_connector/__init__.py
tests/v1/kv_connector/unit/offloading_connector/__init__.py
+0
-0
tests/v1/kv_connector/unit/offloading_connector/conftest.py
tests/v1/kv_connector/unit/offloading_connector/conftest.py
+7
-0
tests/v1/kv_connector/unit/offloading_connector/test_metrics.py
...v1/kv_connector/unit/offloading_connector/test_metrics.py
+151
-0
tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py
.../kv_connector/unit/offloading_connector/test_scheduler.py
+341
-0
tests/v1/kv_connector/unit/offloading_connector/test_worker.py
.../v1/kv_connector/unit/offloading_connector/test_worker.py
+504
-0
tests/v1/kv_connector/unit/offloading_connector/utils.py
tests/v1/kv_connector/unit/offloading_connector/utils.py
+15
-487
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+3
-0
tests/v1/kv_offload/test_cpu_gpu.py
tests/v1/kv_offload/test_cpu_gpu.py
+88
-127
vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
...tributed/kv_transfer/kv_connector/v1/offloading/worker.py
+227
-15
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+2
-1
vllm/v1/kv_offload/cpu/spec.py
vllm/v1/kv_offload/cpu/spec.py
+4
-14
vllm/v1/kv_offload/spec.py
vllm/v1/kv_offload/spec.py
+53
-6
vllm/v1/kv_offload/worker/cpu_gpu.py
vllm/v1/kv_offload/worker/cpu_gpu.py
+82
-109
No files found.
tests/v1/kv_connector/unit/offloading_connector/__init__.py
0 → 100644
View file @
7cc302dd
tests/v1/kv_connector/unit/offloading_connector/conftest.py
0 → 100644
View file @
7cc302dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
tests.v1.kv_connector.unit.offloading_connector.utils
import
(
request_runner
,
)
# Re-export the name imported above so it counts as a public name of this
# conftest module (presumably a pytest fixture shared by the tests in this
# directory -- verify against utils.py).
__all__ = ["request_runner"]
tests/v1/kv_connector/unit/offloading_connector/test_metrics.py
0 → 100644
View file @
7cc302dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics
import
(
OffloadingConnectorStats
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector
import
(
OffloadingConnector
,
)
def test_build_kv_connector_stats_with_none():
    """Building connector stats from ``None`` must yield an empty stats object."""
    built = OffloadingConnector.build_kv_connector_stats(data=None)

    # A real OffloadingConnectorStats instance is expected, not None.
    assert built is not None
    assert isinstance(built, OffloadingConnectorStats)

    # The instance must not carry any recorded operations.
    assert len(built.data) == 0
    assert built.is_empty()
def test_build_kv_connector_stats_with_empty_dict():
    """Building connector stats from an empty dict must yield an empty stats object."""
    built = OffloadingConnector.build_kv_connector_stats(data={})

    # A real OffloadingConnectorStats instance is expected, not None.
    assert built is not None
    assert isinstance(built, OffloadingConnectorStats)

    # The instance must not carry any recorded operations.
    assert len(built.data) == 0
    assert built.is_empty()
def test_build_kv_connector_stats_reconstructs_offload_stats():
    """Serialized per-direction operation records must survive reconstruction
    through ``OffloadingConnector.build_kv_connector_stats`` unchanged."""
    payload = {
        "CPU_to_GPU": [
            {"op_size": 16, "op_time": 1.0},
            {"op_size": 8, "op_time": 0.5},
        ],
        "GPU_to_CPU": [
            {"op_size": 1, "op_time": 0.1},
            {"op_size": 2, "op_time": 0.2},
        ],
    }

    rebuilt = OffloadingConnector.build_kv_connector_stats(data=payload)
    assert isinstance(rebuilt, OffloadingConnectorStats)

    # Expected records are spelled out as fresh literals (not the payload
    # object itself) so the comparison does not depend on object identity.
    expected_by_direction = {
        "CPU_to_GPU": [
            {"op_size": 16, "op_time": 1.0},
            {"op_size": 8, "op_time": 0.5},
        ],
        "GPU_to_CPU": [
            {"op_size": 1, "op_time": 0.1},
            {"op_size": 2, "op_time": 0.2},
        ],
    }
    for direction, expected_ops in expected_by_direction.items():
        assert rebuilt.data[direction] == expected_ops
def test_aggregate_same_connector():
    """Aggregating two stats objects appends the second one's records to the
    first one's, per transfer direction, and returns the first object."""
    base = OffloadingConnectorStats(
        data={
            "CPU_to_GPU": [
                {"op_size": 16, "op_time": 1.0},
                {"op_size": 8, "op_time": 0.5},
            ],
            "GPU_to_CPU": [
                {"op_size": 1, "op_time": 0.1},
                {"op_size": 2, "op_time": 0.2},
            ],
        }
    )
    extra = OffloadingConnectorStats(
        data={
            "CPU_to_GPU": [
                {"op_size": 3, "op_time": 0.2},
                {"op_size": 7, "op_time": 0.9},
            ],
            "GPU_to_CPU": [{"op_size": 16, "op_time": 2}],
        }
    )

    merged = base.aggregate(extra)

    # aggregate() should hand back the receiver itself
    assert merged is base

    # Records of ``base`` come first, followed by the records of ``extra``.
    expected_by_direction = {
        "CPU_to_GPU": [
            {"op_size": 16, "op_time": 1.0},
            {"op_size": 8, "op_time": 0.5},
            {"op_size": 3, "op_time": 0.2},
            {"op_size": 7, "op_time": 0.9},
        ],
        "GPU_to_CPU": [
            {"op_size": 1, "op_time": 0.1},
            {"op_size": 2, "op_time": 0.2},
            {"op_size": 16, "op_time": 2},
        ],
    }
    for direction, expected_ops in expected_by_direction.items():
        assert merged.data[direction] == expected_ops
def test_reduce():
    """Check that reduce() collapses per-operation records into totals.

    For each transfer direction the reduced dict must expose a
    ``<direction>_total_bytes`` and a ``<direction>_total_time`` entry.
    Byte totals are integers and are compared exactly; time totals are sums
    of floats and are compared with a tolerance, because the exact result of
    a float sum depends on summation order and rounding (e.g. 0.1 + 0.2 + 2
    lies on a rounding tie next to 2.3).
    """
    import math

    stats = OffloadingConnectorStats(
        data={
            "CPU_to_GPU": [
                {"op_size": 16, "op_time": 1.0},
                {"op_size": 8, "op_time": 0.5},
                {"op_size": 3, "op_time": 0.2},
                {"op_size": 7, "op_time": 0.9},
            ],
            "GPU_to_CPU": [
                {"op_size": 1, "op_time": 0.1},
                {"op_size": 2, "op_time": 0.2},
                {"op_size": 16, "op_time": 2},
            ],
        }
    )

    reduced = stats.reduce()

    assert isinstance(reduced, dict)
    # Check that the stats were reduced (should have aggregated values)
    assert "CPU_to_GPU_total_bytes" in reduced
    assert "CPU_to_GPU_total_time" in reduced
    assert "GPU_to_CPU_total_bytes" in reduced
    assert "GPU_to_CPU_total_time" in reduced

    # 16 + 8 + 3 + 7
    assert reduced["CPU_to_GPU_total_bytes"] == 34
    # 1.0 + 0.5 + 0.2 + 0.9, compared with a tolerance (float sum)
    assert math.isclose(reduced["CPU_to_GPU_total_time"], 2.6)
    # 0.1 + 0.2 + 2, compared with a tolerance (float sum)
    assert math.isclose(reduced["GPU_to_CPU_total_time"], 2.3)
    # 1 + 2 + 16
    assert reduced["GPU_to_CPU_total_bytes"] == 19
def test_reset():
    """reset() must drop every recorded operation, leaving empty stats."""
    recorded = {
        "CPU_to_GPU": [
            {"op_size": 3, "op_time": 0.2},
            {"op_size": 7, "op_time": 0.9},
        ],
        "GPU_to_CPU": [{"op_size": 16, "op_time": 2}],
    }
    stats = OffloadingConnectorStats(data=recorded)

    # Sanity check: there is something to reset.
    assert not stats.is_empty()

    stats.reset()

    # After reset, stats should be empty
    assert stats.is_empty()
    assert len(stats.data) == 0
tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py
0 → 100644
View file @
7cc302dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
import
pytest
from
tests.v1.kv_connector.unit.offloading_connector.utils
import
(
generate_store_output
,
)
from
tests.v1.kv_connector.unit.utils
import
EOS_TOKEN_ID
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
from
vllm.v1.core.kv_cache_utils
import
BlockHash
from
vllm.v1.kv_offload.abstract
import
OffloadingEvent
from
vllm.v1.request
import
RequestStatus
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_offloading_connector(request_runner, async_scheduling: bool):
    """Drive one runner through store, lookup/load and event scenarios.

    The test is a single long, order-dependent scenario: each step configures
    ``runner.manager`` (``prepare_store`` side effect, ``lookup`` return
    value), advances the request with ``runner.run`` and checks which GPU
    block indexes were stored or loaded. One offloaded block spans
    ``block_size_factor`` (12 // 4 = 3) GPU blocks.
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100
    # number of GPU blocks that make up a single offloaded block
    block_size_factor = offloaded_block_size // gpu_block_size

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
        async_scheduling=async_scheduling,
    )

    # 3 blocks, store just the middle block (skip first and last)
    # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
    runner.new_request(token_ids=[0] * offloaded_block_size * 3)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
    )
    runner.run(decoded_tokens=[0])

    # add block missing 1 token -> no offload
    # (GPU blocks 3, 4, 5 are the middle offloaded block selected above)
    runner.run(
        decoded_tokens=[0] * (offloaded_block_size - 1),
        expected_stored_gpu_block_indexes=(3, 4, 5),
    )
    runner.manager.prepare_store.assert_not_called()

    # +1 token -> single block, fail prepare_store
    runner.manager.prepare_store.side_effect = lambda block_hashes: None
    runner.run(decoded_tokens=[0])
    runner.manager.prepare_store.assert_called()

    # 1 more block (+ token for async scheduling)
    # now set block_hashes_to_store = []
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[0] * (offloaded_block_size + 1))

    # 1 more block (+ token for kicking off offloading)
    # now check touch was called with all 6 blocks
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0] * (offloaded_block_size + 1),
        expected_stored_gpu_block_indexes=(15, 16, 17),
    )
    runner.manager.touch.assert_called()
    # first positional argument of the most recent touch() call
    block_hashes1 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes1) == 6

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # create a new request differing only on the last token
    runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
    runner.run(decoded_tokens=[0])
    runner.manager.touch.assert_called()
    block_hashes2 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes2) == 6

    # verify hashes are the same, except for the last block
    assert block_hashes1[:5] == block_hashes2[:5]
    assert block_hashes1[5] != block_hashes2[5]

    # terminate request
    # (all 6 offloaded blocks, i.e. 6 * block_size_factor GPU blocks, are stored)
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
    )

    # full_block_tokens - num_computed_tokens < offloaded_block_size
    runner.new_request(
        token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
    )
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_not_called()

    # single block lookup with no hits
    runner.new_request(token_ids=[1] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_called()
    assert len(list(runner.manager.lookup.call_args.args[0])) == 1

    # single block lookup with a hit
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
    )

    # single block lookup with a hit in a middle block
    runner.new_request(
        token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
    )
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
    )

    # test take_events
    def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
        # Build a BlockHash from the encoded decimal string of each integer.
        return [BlockHash(str(i).encode()) for i in int_hashes]

    def take_events() -> Iterable[OffloadingEvent]:
        # One "stored" event (removed=False) followed by one "removed" event.
        yield OffloadingEvent(
            block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False
        )
        yield OffloadingEvent(
            block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True
        )

    runner.manager.take_events.side_effect = take_events
    events = list(runner.scheduler_connector.take_events())
    assert len(events) == 2

    # the removed=False event is expected to surface as a BlockStored event
    event = events[0]
    assert isinstance(event, BlockStored)
    assert event.block_hashes == to_hashes([1, 2, 3])
    assert event.block_size == 16
    assert event.medium == "A"
    assert event.token_ids == []
    assert event.parent_block_hash is None
    assert event.lora_id is None
    assert event.lora_name is None

    # the removed=True event is expected to surface as a BlockRemoved event
    event = events[1]
    assert isinstance(event, BlockRemoved)
    assert event.block_hashes == to_hashes([4, 5, 6])
    assert event.medium == "B"
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_request_preemption(request_runner, async_scheduling: bool):
    """Check store flushing on preemption and re-loading after it.

    The request stores blocks without completing the transfers, is then
    forced into preemption by zeroing the free-block counter, and is
    finally resumed after the counter is restored and the GPU prefix cache
    is reset.
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
        async_scheduling=async_scheduling,
    )

    free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
    # free-block count before any request is added; restored further below
    num_free_blocks_empty = free_block_queue.num_free_blocks

    # 2 blocks, store all, without flushing
    # blocks = [0, 1, 2], [3, 4, 5]
    runner.new_request(token_ids=[0] * offloaded_block_size * 2)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0],
        complete_transfers=False,
    )

    # decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
        complete_transfers=False,
    )

    # simulate KV cache running out of space
    free_block_queue.num_free_blocks = 0

    # request should be preempted now
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
        expected_flushed_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
        expected_stored_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
    )

    # restore KV cache space and reset GPU prefix cache
    free_block_queue.num_free_blocks = num_free_blocks_empty
    runner.scheduler.reset_prefix_cache()

    # request should now return from preemption
    # re-load [0, ..., 8] from the CPU and store [9, 10, 11]
    runner.manager.lookup.return_value = 3
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0] * gpu_block_size,
        expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
    )

    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(9, 10, 11),
    )
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: bool):
    """Two in-flight requests for the same offloaded prefix must not start a
    second load: the recorded transfer specs stay unchanged after the second
    request is scheduled and after the transfers complete."""
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
        async_scheduling=async_scheduling,
    )

    def store_everything(block_hashes):
        return generate_store_output(block_hashes)

    def store_nothing(block_hashes):
        return generate_store_output([])

    def recorded_transfer_specs():
        # snapshot of the transfer specs currently held by the handler
        return list(runner.offloading_spec.handler.transfer_specs)

    # store 1 block
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # start a request to load the first block, but don't complete
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # request triggered a load
    jobs_after_first_request = recorded_transfer_specs()
    assert jobs_after_first_request

    # start a new request to load the same first block
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # request did not trigger a load
    assert jobs_after_first_request == recorded_transfer_specs()

    # complete transfers
    runner.manager.prepare_store.side_effect = store_nothing
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # second request will use the GPU prefix cache
    assert jobs_after_first_request == recorded_transfer_specs()
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_abort_loading_requests(request_runner, async_scheduling: bool):
    """A request aborted while its load is still in flight must stay
    registered in the scheduler until that load completes, and be removed
    afterwards."""
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
        async_scheduling=async_scheduling,
    )

    def store_everything(block_hashes):
        return generate_store_output(block_hashes)

    # store 1 block
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # start a request to load the first block, but don't complete
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # request triggered a load
    pending_jobs = list(runner.offloading_spec.handler.transfer_specs)
    assert pending_jobs

    # abort request
    aborted_id = str(runner.req_id)
    runner.scheduler.finish_requests((aborted_id,), RequestStatus.FINISHED_ABORTED)

    # verify request is not deleted
    assert aborted_id in runner.scheduler.requests

    # complete loading request
    runner.run(
        decoded_tokens=[],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # assert request is deleted
    assert aborted_id not in runner.scheduler.requests
tests/v1/kv_connector/unit/offloading_connector/test_worker.py
0 → 100644
View file @
7cc302dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
get_dtype_size
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.utils
import
set_kv_cache_layout
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
MambaSpec
,
MLAAttentionSpec
,
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.kv_offload.spec
import
(
CanonicalKVCacheRef
,
CanonicalKVCaches
,
OffloadingSpec
,
)
# Dimensions of the synthetic KV caches built by the tests in this module.
NUM_BLOCKS = 10
BLOCK_SIZE = 16
NUM_KV_HEADS = 4
HEAD_SIZE = 64
DTYPE = torch.float16

# Attention backends to test, chosen by platform. On platforms that are
# neither CUDA nor ROCm the list is empty, so the parametrized tests
# generate no cases.
ATTN_BACKENDS: list[str]
if current_platform.is_cuda():
    ATTN_BACKENDS = [
        "FLASH_ATTN",
        "FLEX_ATTENTION",
        "FLASHINFER",
        "TRITON_ATTN",
    ]
elif current_platform.is_rocm():
    ATTN_BACKENDS = ["TRITON_ATTN"]
else:
    ATTN_BACKENDS = []
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _allocate_and_reshape_kv_caches(
    kv_cache_config: KVCacheConfig,
    attn_groups: list[list],
    device: torch.device,
):
    """
    Produce kv_caches through GPUModelRunner.initialize_kv_cache_tensors.

    A GPUModelRunner instance is created without running its ``__init__``;
    only the attributes assigned below are populated (mocks where a real
    object is not needed) before the real initialization method is invoked.
    Returns whatever ``initialize_kv_cache_tensors`` returns.
    """
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

    # Some backends (e.g. FlashAttention) query the KV cache layout during
    # reshape, which ultimately calls get_current_vllm_config(). Setting
    # the layout override avoids needing a full VllmConfig context.
    set_kv_cache_layout("NHD")
    try:
        # Allocate the instance only; the constructor is deliberately skipped.
        bare_runner = object.__new__(GPUModelRunner)

        # Real inputs supplied by the caller.
        bare_runner.device = device
        bare_runner.runner_only_attn_layers = set()
        bare_runner.attn_groups = attn_groups
        bare_runner.kv_cache_config = kv_cache_config

        # Stand-ins for configuration objects.
        bare_runner.cache_config = MagicMock(cache_dtype="auto")
        bare_runner.shared_kv_cache_layers = {}
        bare_runner.model_config = MagicMock()
        bare_runner.model_config.hf_config.model_type = ""
        bare_runner.compilation_config = MagicMock(
            static_forward_context=defaultdict(MagicMock)
        )
        bare_runner.kv_caches = []

        # One kernel block size per KV cache group, all equal to BLOCK_SIZE.
        group_count = len(kv_cache_config.kv_cache_groups)
        kernel_block_sizes = [BLOCK_SIZE for _ in range(group_count)]
        return bare_runner.initialize_kv_cache_tensors(
            kv_cache_config, kernel_block_sizes
        )
    finally:
        # Clear the layout override again, whether or not allocation succeeded.
        set_kv_cache_layout(None)
def
_make_mock_layer
(
backend_cls
:
type
[
AttentionBackend
]):
"""
Create a mock AttentionLayerBase whose get_attn_backend returns backend_cls.
"""
layer
=
MagicMock
()
layer
.
get_attn_backend
.
return_value
=
backend_cls
return
layer
def _make_worker(kv_cache_config: KVCacheConfig):
    """
    Build an OffloadingConnectorWorker around a mocked OffloadingSpec.

    Returns the ``(worker, spec)`` pair so callers can inspect the calls
    recorded on the spec mock (e.g. ``spec.get_handlers.call_args``).
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.offloading.worker import (
        OffloadingConnectorWorker,
    )

    offloading_spec = MagicMock(spec=OffloadingSpec)
    offloading_spec.kv_cache_config = kv_cache_config
    offloading_spec.vllm_config = MagicMock()
    # get_handlers() yields no handlers
    offloading_spec.get_handlers.return_value = iter([])

    connector_worker = OffloadingConnectorWorker(spec=offloading_spec)
    connector_worker.worker = MagicMock()
    return connector_worker, offloading_spec
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("backend", ATTN_BACKENDS)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.offloading"
    ".worker.get_layers_from_vllm_config"
)
def test_register_kv_caches(mock_get_layers, backend):
    """Test register_kv_caches with multiple groups covering all layer types.

    Creates four KV cache groups:
      * a FullAttention group (GROUP_SIZE layers),
      * an MLA group (GROUP_SIZE layers),
      * an "unaligned" Mamba group whose state is smaller than the padded
        page (GROUP_SIZE layers),
      * an "aligned" Mamba group whose state fills the page exactly
        (GROUP_SIZE - 1 layers).

    All four specs report the same page_size_bytes (asserted below), so
    KVCacheTensor i is shared by layer i of every group that has an i-th
    layer; no group gets dedicated tensors.

    kv_caches are produced by _allocate_and_reshape_kv_caches, which calls
    the real GPUModelRunner.initialize_kv_cache_tensors (the original author
    notes this applies _update_hybrid_attention_mamba_layout for hybrid
    models -- not visible from this file).

    Verifies that the CanonicalKVCaches passed to spec.get_handlers has the
    expected block tensors, tensor_idx references, and page sizes across all
    groups.
    """
    from vllm.v1.attention.backends.mla.indexer import (
        DeepseekV32IndexerBackend,
    )
    from vllm.v1.worker.utils import AttentionGroup

    # Chosen so that the MLA spec (num_kv_heads=1) has the same page size as
    # the full-attention spec (asserted below).
    MLA_HEAD_SIZE = NUM_KV_HEADS * HEAD_SIZE * 2

    # padded mamba (missing HEAD_SIZE)
    CONV_STATE_SHAPE = (BLOCK_SIZE * NUM_KV_HEADS, HEAD_SIZE)
    UNALIGNED_SSM_STATE_SHAPE = (BLOCK_SIZE * NUM_KV_HEADS - 1, HEAD_SIZE)
    PAGE_SIZE_BYTES = 2 * BLOCK_SIZE * NUM_KV_HEADS * HEAD_SIZE * get_dtype_size(DTYPE)
    # real (unpadded) size of the unaligned mamba state: one row of HEAD_SIZE
    # elements short of a full page
    unaligned_mamba_page_size = PAGE_SIZE_BYTES - HEAD_SIZE * get_dtype_size(DTYPE)

    # unpadded mamba (fills page exactly)
    ALIGNED_SSM_STATE_SHAPE = (BLOCK_SIZE * NUM_KV_HEADS, HEAD_SIZE)

    backend_cls = AttentionBackendEnum[backend].get_class()

    attn_spec = FullAttentionSpec(
        block_size=BLOCK_SIZE,
        num_kv_heads=NUM_KV_HEADS,
        head_size=HEAD_SIZE,
        dtype=DTYPE,
    )
    mla_spec = MLAAttentionSpec(
        block_size=BLOCK_SIZE,
        num_kv_heads=1,
        head_size=MLA_HEAD_SIZE,
        dtype=DTYPE,
    )
    unaligned_mamba_spec = MambaSpec(
        block_size=BLOCK_SIZE,
        shapes=(CONV_STATE_SHAPE, UNALIGNED_SSM_STATE_SHAPE),
        dtypes=(DTYPE, DTYPE),
        page_size_padded=PAGE_SIZE_BYTES,
    )
    aligned_mamba_spec = MambaSpec(
        block_size=BLOCK_SIZE,
        shapes=(CONV_STATE_SHAPE, ALIGNED_SSM_STATE_SHAPE),
        dtypes=(DTYPE, DTYPE),
        page_size_padded=PAGE_SIZE_BYTES,
    )

    # Every spec must report the same page size for the shared-tensor layout
    # built below to be valid.
    assert attn_spec.page_size_bytes == PAGE_SIZE_BYTES
    assert mla_spec.page_size_bytes == PAGE_SIZE_BYTES
    assert unaligned_mamba_spec.page_size_bytes == PAGE_SIZE_BYTES
    assert aligned_mamba_spec.page_size_bytes == PAGE_SIZE_BYTES

    GROUP_SIZE = 3

    # -- Build per-group layer info ----------------------------------------
    # layer_idx keeps increasing across groups, so every layer name is unique.
    layer_idx = 0
    attn_layer_names = []
    for _ in range(GROUP_SIZE):
        attn_layer_names.append(f"model.layers.{layer_idx}.self_attn")
        layer_idx += 1

    mla_layer_names = []
    for _ in range(GROUP_SIZE):
        mla_layer_names.append(f"model.layers.{layer_idx}.self_attn")
        layer_idx += 1

    # NOTE(review): the ".mamba_unpadded" / ".mamba_padded" suffixes look
    # swapped relative to the specs they are paired with (the unaligned spec
    # is the one that needs padding). The names are only used as keys here.
    unaligned_mamba_layer_names = []
    for _ in range(GROUP_SIZE):
        unaligned_mamba_layer_names.append(f"model.layers.{layer_idx}.mamba_unpadded")
        layer_idx += 1

    # this group intentionally has one layer fewer than the others
    aligned_mamba_layer_names = []
    for _ in range(GROUP_SIZE - 1):
        aligned_mamba_layer_names.append(f"model.layers.{layer_idx}.mamba_padded")
        layer_idx += 1

    layer_groups = [
        attn_layer_names,
        mla_layer_names,
        unaligned_mamba_layer_names,
        aligned_mamba_layer_names,
    ]

    # Tensor i is shared by the i-th layer of every group that has one.
    kv_cache_tensors: list[KVCacheTensor] = []
    for i in range(GROUP_SIZE):
        shared_by: list[str] = []
        for group_layer_names in layer_groups:
            if len(group_layer_names) > i:
                shared_by.append(group_layer_names[i])
        kv_cache_tensors.append(
            KVCacheTensor(
                size=PAGE_SIZE_BYTES * NUM_BLOCKS,
                shared_by=shared_by,
            )
        )

    kv_cache_groups = [
        KVCacheGroupSpec(layer_names=attn_layer_names, kv_cache_spec=attn_spec),
        KVCacheGroupSpec(layer_names=mla_layer_names, kv_cache_spec=mla_spec),
        KVCacheGroupSpec(
            layer_names=unaligned_mamba_layer_names, kv_cache_spec=unaligned_mamba_spec
        ),
        KVCacheGroupSpec(
            layer_names=aligned_mamba_layer_names, kv_cache_spec=aligned_mamba_spec
        ),
    ]

    attn_groups = [
        [
            AttentionGroup(
                backend=backend_cls,
                layer_names=attn_layer_names,
                kv_cache_spec=attn_spec,
                kv_cache_group_id=0,
            ),
            AttentionGroup(
                backend=DeepseekV32IndexerBackend,
                layer_names=mla_layer_names,
                kv_cache_spec=mla_spec,
                kv_cache_group_id=1,
            ),
            AttentionGroup(
                backend=DeepseekV32IndexerBackend,  # unused for mamba
                layer_names=unaligned_mamba_layer_names,
                kv_cache_spec=unaligned_mamba_spec,
                kv_cache_group_id=2,
            ),
            AttentionGroup(
                backend=DeepseekV32IndexerBackend,  # unused for mamba
                layer_names=aligned_mamba_layer_names,
                kv_cache_spec=aligned_mamba_spec,
                kv_cache_group_id=3,
            ),
        ]
    ]

    kv_cache_config = KVCacheConfig(
        num_blocks=NUM_BLOCKS,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=kv_cache_groups,
    )

    kv_caches = _allocate_and_reshape_kv_caches(
        kv_cache_config,
        attn_groups,
        device=torch.device("cuda:0"),
    )

    # Only the attention and MLA layers get a mock layer entry; the mamba
    # layers are not returned by the patched get_layers_from_vllm_config.
    mock_layers: dict[str, MagicMock] = {}
    for layer_name in attn_layer_names:
        mock_layers[layer_name] = _make_mock_layer(backend_cls)
    for layer_name in mla_layer_names:
        mock_layers[layer_name] = _make_mock_layer(DeepseekV32IndexerBackend)
    mock_get_layers.return_value = mock_layers

    worker, spec = _make_worker(kv_cache_config)
    worker.register_kv_caches(kv_caches)

    # first positional argument of the recorded spec.get_handlers(...) call
    canonical = spec.get_handlers.call_args[0][0]
    assert isinstance(canonical, CanonicalKVCaches)

    # -- Expected block tensors ----------------------------------------------
    # All tensors have the same padded page size (PAGE_SIZE_BYTES).
    # Tensor 0: shared by attn[0], mla[0], mamba_unaligned[0], mamba_aligned[0]
    # Tensor 1: shared by attn[1], mla[1], mamba_unaligned[1], mamba_aligned[1]
    # Tensor 2: shared by attn[2], mla[2], mamba_unaligned[2]
    #   (mamba_aligned has only GROUP_SIZE-1 = 2 layers)
    expected_tensors = [
        (NUM_BLOCKS, PAGE_SIZE_BYTES),
        (NUM_BLOCKS, PAGE_SIZE_BYTES),
        (NUM_BLOCKS, PAGE_SIZE_BYTES),
    ]

    # -- Expected group data refs (order matches kv_cache_groups) -------------
    ref = CanonicalKVCacheRef
    expected_group_refs = [
        # attn group: layers attn[0..2] → tensors 0,1,2 with full page size
        [
            ref(tensor_idx=0, page_size_bytes=PAGE_SIZE_BYTES),
            ref(tensor_idx=1, page_size_bytes=PAGE_SIZE_BYTES),
            ref(tensor_idx=2, page_size_bytes=PAGE_SIZE_BYTES),
        ],
        # mla group: layers mla[0..2] → tensors 0,1,2 with full page size
        [
            ref(tensor_idx=0, page_size_bytes=PAGE_SIZE_BYTES),
            ref(tensor_idx=1, page_size_bytes=PAGE_SIZE_BYTES),
            ref(tensor_idx=2, page_size_bytes=PAGE_SIZE_BYTES),
        ],
        # unaligned mamba group: layers [0..2] → tensors 0,1,2 with unaligned page
        [
            ref(tensor_idx=0, page_size_bytes=unaligned_mamba_page_size),
            ref(tensor_idx=1, page_size_bytes=unaligned_mamba_page_size),
            ref(tensor_idx=2, page_size_bytes=unaligned_mamba_page_size),
        ],
        # aligned mamba group: layers [0..1] → tensors 0,1 with full page size
        [
            ref(tensor_idx=0, page_size_bytes=PAGE_SIZE_BYTES),
            ref(tensor_idx=1, page_size_bytes=PAGE_SIZE_BYTES),
        ],
    ]

    # Verify block tensors
    assert len(canonical.tensors) == len(expected_tensors)
    for block_tensor, (exp_num_blocks, exp_page_size) in zip(
        canonical.tensors, expected_tensors
    ):
        tensor = block_tensor.tensor
        # canonical tensors are expected as int8 with one row per block
        assert tensor.dtype == torch.int8
        assert tensor.shape == (exp_num_blocks, exp_page_size)
        assert block_tensor.page_size_bytes == exp_page_size

    # Verify group data refs
    assert len(canonical.group_data_refs) == len(expected_group_refs)
    for actual_refs, exp_refs in zip(canonical.group_data_refs, expected_group_refs):
        assert len(actual_refs) == len(exp_refs)
        for actual, expected in zip(actual_refs, exp_refs):
            assert actual.tensor_idx == expected.tensor_idx
            assert actual.page_size_bytes == expected.page_size_bytes
@pytest.mark.parametrize("backend", ATTN_BACKENDS)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.offloading"
    ".worker.get_layers_from_vllm_config"
)
def test_register_kv_caches_uniform_type(mock_get_layers, backend):
    """Test register_kv_caches with UniformTypeKVCacheSpecs.

    Two attention layers share a backend but differ in num_kv_heads, so
    their per-layer page sizes differ. Each layer owns its own
    KVCacheTensor, and both are wrapped in one UniformTypeKVCacheSpecs
    group.

    Checks that every layer ends up with the expected tensor_idx and
    page_size_bytes in its block data ref. For the FLASH_ATTN and
    FLEX_ATTENTION backends two canonical tensors of half the page size are
    expected per layer; for the other backends one full-page tensor.
    """
    from vllm.v1.worker.utils import AttentionGroup

    backend_cls = AttentionBackendEnum[backend].get_class()

    layer_a = "model.layers.0.self_attn"
    layer_b = "model.layers.1.self_attn"

    spec_a = FullAttentionSpec(
        block_size=BLOCK_SIZE,
        num_kv_heads=NUM_KV_HEADS,
        head_size=HEAD_SIZE,
        dtype=DTYPE,
    )
    spec_b = FullAttentionSpec(
        block_size=BLOCK_SIZE,
        num_kv_heads=NUM_KV_HEADS * 2,
        head_size=HEAD_SIZE,
        dtype=DTYPE,
    )
    assert spec_a.page_size_bytes != spec_b.page_size_bytes

    # (layer name, spec) pairs in group order; drives the setup below.
    layers = [(layer_a, spec_a), (layer_b, spec_b)]

    uniform_spec = UniformTypeKVCacheSpecs(
        block_size=BLOCK_SIZE,
        kv_cache_specs=dict(layers),
    )

    kv_cache_config = KVCacheConfig(
        num_blocks=NUM_BLOCKS,
        kv_cache_tensors=[
            KVCacheTensor(
                size=layer_spec.page_size_bytes * NUM_BLOCKS,
                shared_by=[layer_name],
            )
            for layer_name, layer_spec in layers
        ],
        kv_cache_groups=[
            KVCacheGroupSpec(
                layer_names=[layer_a, layer_b],
                kv_cache_spec=uniform_spec,
            )
        ],
    )

    # One attention group per layer, both belonging to KV cache group 0.
    attn_groups = [
        [
            AttentionGroup(
                backend=backend_cls,
                layer_names=[layer_name],
                kv_cache_spec=layer_spec,
                kv_cache_group_id=0,
            )
            for layer_name, layer_spec in layers
        ]
    ]

    kv_caches = _allocate_and_reshape_kv_caches(
        kv_cache_config,
        attn_groups,
        device=torch.device("cuda:0"),
    )

    mock_get_layers.return_value = {
        layer_name: _make_mock_layer(backend_cls) for layer_name, _ in layers
    }

    worker, spec = _make_worker(kv_cache_config)
    worker.register_kv_caches(kv_caches)

    canonical = spec.get_handlers.call_args[0][0]
    assert isinstance(canonical, CanonicalKVCaches)

    # Expected page size of each canonical tensor, in tensor order.
    splits_in_two = backend_cls.get_name() in ("FLASH_ATTN", "FLEX_ATTENTION")
    if splits_in_two:
        half_a = spec_a.page_size_bytes // 2
        half_b = spec_b.page_size_bytes // 2
        expected_page_sizes = [half_a, half_a, half_b, half_b]
        tensors_per_layer = 2
    else:
        expected_page_sizes = [spec_a.page_size_bytes, spec_b.page_size_bytes]
        tensors_per_layer = 1

    for block_tensor in canonical.tensors:
        assert block_tensor.tensor.dtype == torch.int8

    # Single group with refs from both layers
    assert len(canonical.group_data_refs) == 1
    group_refs = canonical.group_data_refs[0]
    assert len(group_refs) == 2 * tensors_per_layer

    assert len(canonical.tensors) == len(expected_page_sizes)
    for idx, page_size in enumerate(expected_page_sizes):
        assert canonical.tensors[idx].page_size_bytes == page_size
    for idx, page_size in enumerate(expected_page_sizes):
        assert canonical.tensors[idx].tensor.shape == (NUM_BLOCKS, page_size)
    for idx, page_size in enumerate(expected_page_sizes):
        assert group_refs[idx] == CanonicalKVCacheRef(
            tensor_idx=idx, page_size_bytes=page_size
        )
tests/v1/kv_connector/unit/
test_
offloading_connector.py
→
tests/v1/kv_connector/unit/offloading_connector
/utils
.py
View file @
7cc302dd
...
@@ -9,16 +9,17 @@ from unittest.mock import MagicMock
...
@@ -9,16 +9,17 @@ from unittest.mock import MagicMock
import
pytest
import
pytest
import
torch
import
torch
from
tests.v1.kv_connector.unit.utils
import
(
EOS_TOKEN_ID
,
create_model_runner_output
,
create_vllm_config
,
)
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.config
import
KVTransferConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorRole
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorRole
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.common
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.common
import
(
OffloadingConnectorMetadata
,
OffloadingConnectorMetadata
,
)
)
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics
import
(
OffloadingConnectorStats
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector
import
(
OffloadingConnector
,
OffloadingConnector
,
)
)
...
@@ -39,7 +40,6 @@ from vllm.v1.kv_cache_interface import (
...
@@ -39,7 +40,6 @@ from vllm.v1.kv_cache_interface import (
)
)
from
vllm.v1.kv_offload.abstract
import
(
from
vllm.v1.kv_offload.abstract
import
(
LoadStoreSpec
,
LoadStoreSpec
,
OffloadingEvent
,
OffloadingManager
,
OffloadingManager
,
PrepareStoreOutput
,
PrepareStoreOutput
,
)
)
...
@@ -51,15 +51,9 @@ from vllm.v1.kv_offload.worker.worker import (
...
@@ -51,15 +51,9 @@ from vllm.v1.kv_offload.worker.worker import (
TransferSpec
,
TransferSpec
,
)
)
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
from
.utils
import
(
EOS_TOKEN_ID
,
create_model_runner_output
,
create_vllm_config
,
)
class
MockLoadStoreSpec
(
LoadStoreSpec
):
class
MockLoadStoreSpec
(
LoadStoreSpec
):
def
__init__
(
self
,
block_hashes
:
Iterable
[
BlockHash
]):
def
__init__
(
self
,
block_hashes
:
Iterable
[
BlockHash
]):
...
@@ -125,7 +119,7 @@ class MockOffloadingSpec(OffloadingSpec):
...
@@ -125,7 +119,7 @@ class MockOffloadingSpec(OffloadingSpec):
return
self
.
manager
return
self
.
manager
def
get_handlers
(
def
get_handlers
(
self
,
_
,
__
self
,
_
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
yield
GPULoadStoreSpec
,
MockLoadStoreSpec
,
self
.
handler
yield
GPULoadStoreSpec
,
MockLoadStoreSpec
,
self
.
handler
yield
MockLoadStoreSpec
,
GPULoadStoreSpec
,
self
.
handler
yield
MockLoadStoreSpec
,
GPULoadStoreSpec
,
self
.
handler
...
@@ -179,7 +173,7 @@ class RequestRunner:
...
@@ -179,7 +173,7 @@ class RequestRunner:
kv_role
=
"kv_both"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
kv_connector_extra_config
=
{
"spec_name"
:
"MockOffloadingSpec"
,
"spec_name"
:
"MockOffloadingSpec"
,
"spec_module_path"
:
"tests.v1.kv_connector.unit.
test_
offloading_connector"
,
# noqa: E501
"spec_module_path"
:
"tests.v1.kv_connector.unit.offloading_connector
.utils
"
,
# noqa: E501
"block_size"
:
offloaded_block_size
,
"block_size"
:
offloaded_block_size
,
},
},
)
)
...
@@ -217,10 +211,12 @@ class RequestRunner:
...
@@ -217,10 +211,12 @@ class RequestRunner:
)
)
# register worker kv_caches to enable OffloadingWorker creations
# register worker kv_caches to enable OffloadingWorker creations
self
.
worker_connector
.
register_cross_layers_kv_cache
(
# set_current_vllm_config is needed for get_kv_cache_layout() to work
kv_cache
=
torch
.
empty
(
0
),
with
set_current_vllm_config
(
vllm_config
):
attn_backend
=
FlashAttentionBackend
,
self
.
worker_connector
.
register_cross_layers_kv_cache
(
)
kv_cache
=
torch
.
empty
(
0
),
attn_backend
=
FlashAttentionBackend
,
)
# extract connector of scheduler
# extract connector of scheduler
scheduler_connector
=
self
.
scheduler
.
connector
scheduler_connector
=
self
.
scheduler
.
connector
...
@@ -521,471 +517,3 @@ def generate_store_output(block_hashes: Iterable[BlockHash]):
...
@@ -521,471 +517,3 @@ def generate_store_output(block_hashes: Iterable[BlockHash]):
store_spec
=
MockLoadStoreSpec
(
block_hashes
),
store_spec
=
MockLoadStoreSpec
(
block_hashes
),
block_hashes_evicted
=
[],
block_hashes_evicted
=
[],
)
)
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_offloading_connector(request_runner, async_scheduling: bool):
    """Drive a single runner through store, touch, lookup/load and event flows.

    The scenario is executed once with async scheduling enabled and once
    with it disabled. Offloaded blocks span 12 tokens while GPU blocks span
    4 tokens, so every offloaded block covers three GPU blocks.
    """
    cpu_block_tokens = 12
    gpu_block_tokens = 4
    gpu_blocks_per_cpu_block = cpu_block_tokens // gpu_block_tokens

    runner = request_runner(
        offloaded_block_size=cpu_block_tokens,
        gpu_block_size=gpu_block_tokens,
        num_gpu_blocks=100,
        async_scheduling=async_scheduling,
    )

    # Store policies installed on the mocked manager's prepare_store.
    def store_second_block_only(block_hashes):
        return generate_store_output(list(block_hashes)[1:2])

    def refuse_store(block_hashes):
        return None

    def store_nothing(block_hashes):
        return generate_store_output([])

    def store_everything(block_hashes):
        return generate_store_output(block_hashes)

    # Prompt of 3 offloaded blocks; only the middle one is stored.
    # GPU blocks: [0, 1, 2], [3, 4, 5], [6, 7, 8]
    runner.new_request(token_ids=[0] * cpu_block_tokens * 3)
    runner.manager.prepare_store.side_effect = store_second_block_only
    runner.run(decoded_tokens=[0])

    # One token short of another full block -> nothing new is offloaded.
    runner.run(
        decoded_tokens=[0] * (cpu_block_tokens - 1),
        expected_stored_gpu_block_indexes=(3, 4, 5),
    )
    runner.manager.prepare_store.assert_not_called()

    # The missing token completes a block, but prepare_store fails.
    runner.manager.prepare_store.side_effect = refuse_store
    runner.run(decoded_tokens=[0])
    runner.manager.prepare_store.assert_called()

    # Another block (+ a token for async scheduling),
    # this time with an empty list of block hashes to store.
    runner.manager.prepare_store.side_effect = store_nothing
    runner.run(decoded_tokens=[0] * (cpu_block_tokens + 1))

    # Another block (+ a token to kick off offloading);
    # touch must now have seen all 6 blocks.
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[0] * (cpu_block_tokens + 1),
        expected_stored_gpu_block_indexes=(15, 16, 17),
    )
    runner.manager.touch.assert_called()
    first_touched = list(runner.manager.touch.call_args.args[0])
    assert len(first_touched) == 6

    # Finish the request.
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # Second request: identical prompt except for its very last token.
    runner.new_request(token_ids=[0] * (cpu_block_tokens * 6 - 1) + [1])
    runner.run(decoded_tokens=[0])
    runner.manager.touch.assert_called()
    second_touched = list(runner.manager.touch.call_args.args[0])
    assert len(second_touched) == 6

    # Only the final block hash may differ between the two requests.
    assert first_touched[:5] == second_touched[:5]
    assert first_touched[5] != second_touched[5]

    # Finish the request.
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=tuple(range(6 * gpu_blocks_per_cpu_block)),
    )

    # full_block_tokens - num_computed_tokens < offloaded_block_size
    runner.new_request(
        token_ids=[0] * gpu_block_tokens
        + [1] * (cpu_block_tokens - gpu_block_tokens)
    )
    runner.manager.prepare_store.side_effect = store_nothing
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_not_called()

    # Lookup of a single block, no hit.
    runner.new_request(token_ids=[1] * cpu_block_tokens)
    runner.manager.prepare_store.side_effect = store_nothing
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_called()
    looked_up = list(runner.manager.lookup.call_args.args[0])
    assert len(looked_up) == 1

    # Lookup of a single block, with a hit.
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * cpu_block_tokens)
    runner.manager.prepare_store.side_effect = store_nothing
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # Lookup of a single block, with the hit landing on a middle block.
    runner.new_request(
        token_ids=[0] * cpu_block_tokens * 2 + [1] * cpu_block_tokens
    )
    runner.manager.prepare_store.side_effect = store_nothing
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded_gpu_block_indexes=(3, 4, 5),
    )

    # take_events: manager events are surfaced by the scheduler connector.
    def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
        return [BlockHash(str(value).encode()) for value in int_hashes]

    def take_events() -> Iterable[OffloadingEvent]:
        for raw_hashes, size, medium, removed in (
            ([1, 2, 3], 16, "A", False),
            ([4, 5, 6], 32, "B", True),
        ):
            yield OffloadingEvent(
                block_hashes=to_hashes(raw_hashes),
                block_size=size,
                medium=medium,
                removed=removed,
            )

    runner.manager.take_events.side_effect = take_events
    events = list(runner.scheduler_connector.take_events())
    assert len(events) == 2
    stored_event, removed_event = events

    assert isinstance(stored_event, BlockStored)
    assert stored_event.block_hashes == to_hashes([1, 2, 3])
    assert stored_event.block_size == 16
    assert stored_event.medium == "A"
    assert stored_event.token_ids == []
    assert stored_event.parent_block_hash is None
    assert stored_event.lora_id is None
    assert stored_event.lora_name is None

    assert isinstance(removed_event, BlockRemoved)
    assert removed_event.block_hashes == to_hashes([4, 5, 6])
    assert removed_event.medium == "B"
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_request_preemption(request_runner, async_scheduling: bool):
    """Preempt a request with unfinished stores, then resume it via a reload.

    Runs with async scheduling both enabled and disabled.
    """
    cpu_block_tokens = 12
    gpu_block_tokens = 4

    runner = request_runner(
        offloaded_block_size=cpu_block_tokens,
        gpu_block_size=gpu_block_tokens,
        num_gpu_blocks=100,
        async_scheduling=async_scheduling,
    )

    def store_everything(block_hashes):
        return generate_store_output(block_hashes)

    first_nine_gpu_blocks = (0, 1, 2, 3, 4, 5, 6, 7, 8)

    free_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
    initial_free_blocks = free_queue.num_free_blocks

    # Prompt of 2 offloaded blocks; store both but leave transfers in flight.
    # GPU blocks: [0, 1, 2], [3, 4, 5]
    runner.new_request(token_ids=[0] * cpu_block_tokens * 2)
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(decoded_tokens=[0], complete_transfers=False)

    # Decode two offloaded blocks minus one GPU block, storing [6, 7, 8]
    # (still without flushing).
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[0] * (2 * cpu_block_tokens - gpu_block_tokens),
        complete_transfers=False,
    )

    # Pretend the KV cache has no free blocks left.
    free_queue.num_free_blocks = 0

    # The request is expected to be preempted on this step.
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
        expected_flushed_gpu_block_indexes=first_nine_gpu_blocks,
        expected_stored_gpu_block_indexes=first_nine_gpu_blocks,
    )

    # Give the space back and clear the GPU prefix cache.
    free_queue.num_free_blocks = initial_free_blocks
    runner.scheduler.reset_prefix_cache()

    # The request comes back from preemption:
    # blocks [0, ..., 8] are re-loaded from the CPU and [9, 10, 11] get stored.
    runner.manager.lookup.return_value = 3
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[0] * gpu_block_tokens,
        expected_loaded_gpu_block_indexes=first_nine_gpu_blocks,
    )

    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(9, 10, 11),
    )
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: bool):
    """Two requests hitting the same offloaded prefix must share one load.

    Runs with async scheduling both enabled and disabled.
    """
    cpu_block_tokens = 12

    runner = request_runner(
        offloaded_block_size=cpu_block_tokens,
        gpu_block_size=4,
        num_gpu_blocks=100,
        async_scheduling=async_scheduling,
    )

    def one_block_prompt():
        # A fresh list per request, exactly one offloaded block long.
        return [0] * cpu_block_tokens

    def submitted_transfers():
        # Snapshot of the transfer specs the mock handler has recorded so far.
        return list(runner.offloading_spec.handler.transfer_specs)

    def store_everything(block_hashes):
        return generate_store_output(block_hashes)

    def store_nothing(block_hashes):
        return generate_store_output([])

    # Offload a single block.
    runner.new_request(token_ids=one_block_prompt())
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # First request begins loading that block; the load is left pending.
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=one_block_prompt())
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # A load must have been submitted.
    transfer_jobs = submitted_transfers()
    assert transfer_jobs

    # Second request targets the very same block.
    runner.new_request(token_ids=one_block_prompt())
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # No additional load was submitted for it.
    assert transfer_jobs == submitted_transfers()

    # Let the pending transfers finish.
    runner.manager.prepare_store.side_effect = store_nothing
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # The second request is served by the GPU prefix cache.
    assert transfer_jobs == submitted_transfers()
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_abort_loading_requests(request_runner, async_scheduling: bool):
    """An aborted request stays tracked until its in-flight load completes.

    Runs with async scheduling both enabled and disabled.
    """
    cpu_block_tokens = 12

    runner = request_runner(
        offloaded_block_size=cpu_block_tokens,
        gpu_block_size=4,
        num_gpu_blocks=100,
        async_scheduling=async_scheduling,
    )

    def store_everything(block_hashes):
        return generate_store_output(block_hashes)

    # Offload a single block.
    runner.new_request(token_ids=[0] * cpu_block_tokens)
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # A new request begins loading that block; the load is left pending.
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * cpu_block_tokens)
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # A load must have been submitted.
    pending_loads = list(runner.offloading_spec.handler.transfer_specs)
    assert pending_loads

    # Abort the request while its load is still in flight.
    aborted_id = str(runner.req_id)
    runner.scheduler.finish_requests(
        (aborted_id,), RequestStatus.FINISHED_ABORTED
    )

    # The scheduler must still be tracking it.
    assert aborted_id in runner.scheduler.requests

    # Let the load complete.
    runner.run(
        decoded_tokens=[],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # Only now is the request dropped.
    assert aborted_id not in runner.scheduler.requests
class TestOffloadingConnectorStats:
    """Tests for OffloadingConnector stats reconstruction and operations."""

    def test_build_kv_connector_stats_with_none(self):
        """Test that build_kv_connector_stats returns empty stats when given None."""
        stats = OffloadingConnector.build_kv_connector_stats(data=None)

        assert stats is not None
        assert isinstance(stats, OffloadingConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_with_empty_dict(self):
        """Test that build_kv_connector_stats returns empty stats with empty dict."""
        stats = OffloadingConnector.build_kv_connector_stats(data={})

        assert stats is not None
        assert isinstance(stats, OffloadingConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_reconstructs_offload_stats(self):
        """Test that OffloadingConnector stats are properly reconstructed with
        correct data."""
        serialized_data = {
            "CPU_to_GPU": [
                {"op_size": 16, "op_time": 1.0},
                {"op_size": 8, "op_time": 0.5},
            ],
            "GPU_to_CPU": [
                {"op_size": 1, "op_time": 0.1},
                {"op_size": 2, "op_time": 0.2},
            ],
        }

        stats = OffloadingConnector.build_kv_connector_stats(data=serialized_data)

        offload_connector_stats = stats
        assert isinstance(offload_connector_stats, OffloadingConnectorStats)
        assert offload_connector_stats.data["CPU_to_GPU"] == [
            {"op_size": 16, "op_time": 1.0},
            {"op_size": 8, "op_time": 0.5},
        ]
        assert offload_connector_stats.data["GPU_to_CPU"] == [
            {"op_size": 1, "op_time": 0.1},
            {"op_size": 2, "op_time": 0.2},
        ]

    def test_aggregate_same_connector(self):
        """Test aggregating stats from the same connector type."""
        stats1 = OffloadingConnectorStats(
            data={
                "CPU_to_GPU": [
                    {"op_size": 16, "op_time": 1.0},
                    {"op_size": 8, "op_time": 0.5},
                ],
                "GPU_to_CPU": [
                    {"op_size": 1, "op_time": 0.1},
                    {"op_size": 2, "op_time": 0.2},
                ],
            }
        )

        stats2 = OffloadingConnectorStats(
            data={
                "CPU_to_GPU": [
                    {"op_size": 3, "op_time": 0.2},
                    {"op_size": 7, "op_time": 0.9},
                ],
                "GPU_to_CPU": [{"op_size": 16, "op_time": 2}],
            }
        )

        result = stats1.aggregate(stats2)

        assert result is stats1  # Should return self

        # The other object's entries are appended after the existing ones.
        offload_connector_stats = result
        assert offload_connector_stats.data["CPU_to_GPU"] == [
            {"op_size": 16, "op_time": 1.0},
            {"op_size": 8, "op_time": 0.5},
            {"op_size": 3, "op_time": 0.2},
            {"op_size": 7, "op_time": 0.9},
        ]
        assert offload_connector_stats.data["GPU_to_CPU"] == [
            {"op_size": 1, "op_time": 0.1},
            {"op_size": 2, "op_time": 0.2},
            {"op_size": 16, "op_time": 2},
        ]

    def test_reduce(self):
        """Test that reduce() produces per-direction byte and time totals."""
        stats = OffloadingConnectorStats(
            data={
                "CPU_to_GPU": [
                    {"op_size": 16, "op_time": 1.0},
                    {"op_size": 8, "op_time": 0.5},
                    {"op_size": 3, "op_time": 0.2},
                    {"op_size": 7, "op_time": 0.9},
                ],
                "GPU_to_CPU": [
                    {"op_size": 1, "op_time": 0.1},
                    {"op_size": 2, "op_time": 0.2},
                    {"op_size": 16, "op_time": 2},
                ],
            }
        )

        reduced = stats.reduce()

        assert isinstance(reduced, dict)
        # Check that the stats were reduced (should have aggregated values)
        assert "CPU_to_GPU_total_bytes" in reduced
        assert "CPU_to_GPU_total_time" in reduced
        assert "GPU_to_CPU_total_bytes" in reduced
        assert "GPU_to_CPU_total_time" in reduced

        # Byte totals are integer sums and can be compared exactly.
        assert reduced["CPU_to_GPU_total_bytes"] == 34
        assert reduced["GPU_to_CPU_total_bytes"] == 19

        # Time totals are float sums: their last bit depends on the order in
        # which the terms are added (0.1 + 0.2 + 2 sits exactly on a rounding
        # tie), so compare with a tolerance instead of exact equality.
        assert reduced["CPU_to_GPU_total_time"] == pytest.approx(2.6)
        assert reduced["GPU_to_CPU_total_time"] == pytest.approx(2.3)

    def test_reset(self):
        """Test that reset() clears all recorded stats."""
        offload_connector_stats = OffloadingConnectorStats(
            data={
                "CPU_to_GPU": [
                    {"op_size": 3, "op_time": 0.2},
                    {"op_size": 7, "op_time": 0.9},
                ],
                "GPU_to_CPU": [{"op_size": 16, "op_time": 2}],
            }
        )

        assert not offload_connector_stats.is_empty()

        offload_connector_stats.reset()

        # After reset, stats should be empty
        assert offload_connector_stats.is_empty()
        assert len(offload_connector_stats.data) == 0
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
7cc302dd
...
@@ -91,6 +91,9 @@ def clear_kv_transfer():
...
@@ -91,6 +91,9 @@ def clear_kv_transfer():
yield
yield
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
ensure_kv_transfer_shutdown
()
ensure_kv_transfer_shutdown
()
# Reset any KV cache layout override set during tests so it doesn't
# leak into tests in other modules.
set_kv_cache_layout
(
None
)
def
get_default_xfer_telemetry
(
def
get_default_xfer_telemetry
(
...
...
tests/v1/kv_offload/test_cpu_gpu.py
View file @
7cc302dd
...
@@ -6,32 +6,20 @@ import time
...
@@ -6,32 +6,20 @@ import time
import
pytest
import
pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.spec
import
(
CanonicalKVCacheRef
,
CanonicalKVCaches
,
CanonicalKVCacheTensor
,
)
from
vllm.v1.kv_offload.worker.cpu_gpu
import
CpuGpuOffloadingHandlers
from
vllm.v1.kv_offload.worker.cpu_gpu
import
CpuGpuOffloadingHandlers
BACKENDS_TO_TEST
=
[
FlashAttentionBackend
]
if
not
current_platform
.
is_rocm
():
from
vllm.v1.attention.backends.flashinfer
import
FlashInferBackend
BACKENDS_TO_TEST
.
append
(
FlashInferBackend
)
from
vllm.v1.attention.backends.mla.flashattn_mla
import
FlashAttnMLABackend
BACKENDS_TO_TEST
.
append
(
FlashAttnMLABackend
)
NUM_GPU_BLOCKS
=
[
64
]
NUM_GPU_BLOCKS
=
[
64
]
NUM_CPU_BLOCKS
=
[
256
]
NUM_CPU_BLOCKS
=
[
256
]
KERNEL_BLOCK_SIZES
=
[
16
]
GPU_PAGE_SIZES
=
[
512
,
1024
]
LOGICAL_BLOCK_SIZES
=
[
16
,
32
]
BLOCK_SIZE_FACTORS
=
[
1
,
3
]
LOGICAL_BLOCKS_PER_CPU_BLOCK
=
[
1
,
3
]
NUM_TENSORS
=
[
4
]
HEAD_SIZES
=
[
64
]
NUM_HEADS
=
[
8
]
NUM_LAYERS
=
[
4
]
DTYPES
=
[
torch
.
bfloat16
]
SEEDS
=
[
0
]
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
"cuda:0"
]
CUDA_DEVICES
=
[
"cuda:0"
]
NUM_MAPPINGS
=
[
3
]
NUM_MAPPINGS
=
[
3
]
...
@@ -39,15 +27,11 @@ NUM_MAPPINGS = [3]
...
@@ -39,15 +27,11 @@ NUM_MAPPINGS = [3]
@
pytest
.
mark
.
parametrize
(
"gpu_to_cpu"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"gpu_to_cpu"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_mappings"
,
NUM_MAPPINGS
)
@
pytest
.
mark
.
parametrize
(
"num_mappings"
,
NUM_MAPPINGS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"gpu_page_size_bytes"
,
GPU_PAGE_SIZES
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"block_size_factor"
,
BLOCK_SIZE_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"kernel_block_size"
,
KERNEL_BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"logical_block_size"
,
LOGICAL_BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"logical_blocks_per_cpu_block"
,
LOGICAL_BLOCKS_PER_CPU_BLOCK
)
@
pytest
.
mark
.
parametrize
(
"num_gpu_blocks"
,
NUM_GPU_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_gpu_blocks"
,
NUM_GPU_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_cpu_blocks"
,
NUM_CPU_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_cpu_blocks"
,
NUM_CPU_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_layers"
,
NUM_LAYERS
)
@
pytest
.
mark
.
parametrize
(
"num_tensors"
,
NUM_TENSORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -55,113 +39,89 @@ def test_transfer(
...
@@ -55,113 +39,89 @@ def test_transfer(
default_vllm_config
,
default_vllm_config
,
gpu_to_cpu
:
bool
,
gpu_to_cpu
:
bool
,
num_mappings
:
int
,
num_mappings
:
int
,
head_size
:
int
,
gpu_page_size_bytes
:
int
,
num_heads
:
int
,
block_size_factor
:
int
,
kernel_block_size
:
int
,
logical_block_size
:
int
,
logical_blocks_per_cpu_block
:
int
,
num_gpu_blocks
:
int
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
,
num_cpu_blocks
:
int
,
num_layers
:
int
,
num_tensors
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
set_random_seed
(
seed
)
set_random_seed
(
seed
)
# create per-layer GPU KV caches based on available attn_backends
# build CanonicalKVCacheTensor list: one per tensor
attn_backends_list
=
BACKENDS_TO_TEST
kv_cache_tensors
:
list
[
CanonicalKVCacheTensor
]
=
[]
for
i
in
range
(
num_tensors
):
assert
logical_block_size
%
kernel_block_size
==
0
gpu_tensor
=
torch
.
randint
(
kernel_blocks_per_gpu_block
=
logical_block_size
//
kernel_block_size
-
128
,
num_gpu_kernel_blocks
=
num_gpu_blocks
*
kernel_blocks_per_gpu_block
127
,
(
num_gpu_blocks
,
gpu_page_size_bytes
),
gpu_caches
=
{}
dtype
=
torch
.
int8
,
attn_backends
=
{}
device
=
device
,
for
i
in
range
(
num_layers
):
)
layer_name
=
f
"layer
{
i
}
"
kv_cache_tensors
.
append
(
CanonicalKVCacheTensor
(
attn_backend
=
attn_backends_list
[
i
%
len
(
attn_backends_list
)]
tensor
=
gpu_tensor
,
attn_backends
[
layer_name
]
=
attn_backend
page_size_bytes
=
gpu_page_size_bytes
,
)
gpu_cache_shape
=
attn_backend
.
get_kv_cache_shape
(
num_gpu_kernel_blocks
,
kernel_block_size
,
num_heads
,
head_size
)
)
gpu_caches
[
layer_name
]
=
torch
.
rand
(
gpu_cache_shape
,
dtype
=
dtype
,
device
=
device
)
# create handler
# one group containing all tensors, one data ref per tensor
cpu_block_size
=
logical_blocks_per_cpu_block
*
logical_block_size
kv_cache_groups_data_refs
:
list
[
list
[
CanonicalKVCacheRef
]]
=
[
kernel_blocks_per_cpu_block
=
cpu_block_size
//
kernel_block_size
[
CanonicalKVCacheRef
(
tensor_idx
=
i
,
page_size_bytes
=
gpu_page_size_bytes
,
)
for
i
in
range
(
num_tensors
)
]
]
kv_caches
=
CanonicalKVCaches
(
tensors
=
kv_cache_tensors
,
group_data_refs
=
kv_cache_groups_data_refs
,
)
handlers
=
CpuGpuOffloadingHandlers
(
handlers
=
CpuGpuOffloadingHandlers
(
attn_backends
=
attn_backends
,
kv_caches
=
kv_caches
,
gpu_block_size
=
logical_block_size
,
block_size_factor
=
block_size_factor
,
cpu_block_size
=
cpu_block_size
,
num_cpu_blocks
=
num_cpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
,
gpu_caches
=
gpu_caches
,
)
)
# select block mappings
# select block mappings
gpu_blocks
=
random
.
sample
(
gpu_blocks
=
random
.
sample
(
range
(
num_gpu_blocks
),
num_mappings
*
block_size_factor
)
range
(
num_gpu_blocks
),
num_mappings
*
logical_blocks_per_cpu_block
)
cpu_blocks
=
random
.
sample
(
range
(
num_cpu_blocks
),
num_mappings
)
cpu_blocks
=
random
.
sample
(
range
(
num_cpu_blocks
),
num_mappings
)
# convert gpu blocks to kernel block size
# expand cpu blocks to gpu-page granularity for uniform comparison:
gpu_blocks_in_kernel_block_size
=
[]
# each cpu block maps to block_size_factor consecutive sub-blocks
for
gpu_block
in
gpu_blocks
:
cpu_blocks_expanded
=
[
base_block_id
=
gpu_block
*
kernel_blocks_per_gpu_block
cpu_block
*
block_size_factor
+
j
for
i
in
range
(
kernel_blocks_per_gpu_block
):
for
cpu_block
in
cpu_blocks
gpu_blocks_in_kernel_block_size
.
append
(
i
+
base_block_id
)
for
j
in
range
(
block_size_factor
)
]
# convert cpu blocks to gpu block size
cpu_blocks_in_kernel_block_size
=
[]
# maybe skip some GPU blocks to test reading from the middle of a CPU block
for
cpu_block
in
cpu_blocks
:
base_block_id
=
cpu_block
*
kernel_blocks_per_cpu_block
for
i
in
range
(
kernel_blocks_per_cpu_block
):
cpu_blocks_in_kernel_block_size
.
append
(
i
+
base_block_id
)
# maybe skip some GPU block to test reading from the middle of a CPU block
if
not
gpu_to_cpu
:
if
not
gpu_to_cpu
:
gpu_blocks_to_skip
=
logical_blocks_per_cpu_block
-
1
blocks_to_skip
=
block_size_factor
-
1
gpu_blocks
=
gpu_blocks
[
gpu_blocks_to_skip
:]
gpu_blocks
=
gpu_blocks
[
blocks_to_skip
:]
kernel_blocks_to_skip
=
gpu_blocks_to_skip
*
kernel_blocks_per_gpu_block
cpu_blocks_expanded
=
cpu_blocks_expanded
[
blocks_to_skip
:]
gpu_blocks_in_kernel_block_size
=
gpu_blocks_in_kernel_block_size
[
kernel_blocks_to_skip
:
]
cpu_blocks_in_kernel_block_size
=
cpu_blocks_in_kernel_block_size
[
kernel_blocks_to_skip
:
]
# set transfer direction
# set transfer direction
if
gpu_to_cpu
:
if
gpu_to_cpu
:
handler
=
handlers
.
gpu_to_cpu_handler
handler
=
handlers
.
gpu_to_cpu_handler
src_blocks
=
gpu_blocks
src_spec
=
GPULoadStoreSpec
(
gpu_blocks
,
group_sizes
=
(
len
(
gpu_blocks
),))
dst_blocks
=
cpu_blocks
dst_spec
=
CPULoadStoreSpec
(
cpu_blocks
)
src_spec
=
GPULoadStoreSpec
(
src_blocks
,
group_sizes
=
(
len
(
src_blocks
),))
dst_to_src
=
dict
(
zip
(
cpu_blocks_expanded
,
gpu_blocks
))
dst_spec
=
CPULoadStoreSpec
(
dst_blocks
)
num_dst_sub_blocks
=
num_cpu_blocks
*
block_size_factor
src_blocks_in_kernel_block_size
=
gpu_blocks_in_kernel_block_size
dst_blocks_in_kernel_block_size
=
cpu_blocks_in_kernel_block_size
dst_size_in_kernel_blocks
=
num_cpu_blocks
*
kernel_blocks_per_cpu_block
else
:
else
:
handler
=
handlers
.
cpu_to_gpu_handler
handler
=
handlers
.
cpu_to_gpu_handler
src_blocks
=
cpu_blocks
src_spec
=
CPULoadStoreSpec
(
cpu_blocks
)
dst_blocks
=
gpu_blocks
dst_spec
=
GPULoadStoreSpec
(
gpu_blocks
,
group_sizes
=
(
len
(
gpu_blocks
),))
src_spec
=
CPULoadStoreSpec
(
src_blocks
)
dst_to_src
=
dict
(
zip
(
gpu_blocks
,
cpu_blocks_expanded
))
dst_spec
=
GPULoadStoreSpec
(
dst_blocks
,
group_sizes
=
(
len
(
dst_blocks
),))
num_dst_sub_blocks
=
num_gpu_blocks
src_blocks_in_kernel_block_size
=
cpu_blocks_in_kernel_block_size
dst_blocks_in_kernel_block_size
=
gpu_blocks_in_kernel_block_size
dst_size_in_kernel_blocks
=
num_gpu_blocks
*
kernel_blocks_per_gpu_block
# build dst -> src mapping
dst_to_src
=
{}
for
src_block
,
dst_block
in
zip
(
src_blocks_in_kernel_block_size
,
dst_blocks_in_kernel_block_size
):
dst_to_src
[
dst_block
]
=
src_block
# clone src and dst tensors before transfer
# clone src and dst tensors before transfer
orig_src_
cache
s
=
[
x
.
clone
()
for
x
in
handler
.
src_tensors
]
orig_src_
tensor
s
=
[
x
.
clone
()
for
x
in
handler
.
src_tensors
]
orig_dst_
cache
s
=
[
x
.
clone
()
for
x
in
handler
.
dst_tensors
]
orig_dst_
tensor
s
=
[
x
.
clone
()
for
x
in
handler
.
dst_tensors
]
# call transfer function
# call transfer function
start_time
=
time
.
time
()
start_time
=
time
.
time
()
...
@@ -180,11 +140,8 @@ def test_transfer(
...
@@ -180,11 +140,8 @@ def test_transfer(
if
gpu_to_cpu
if
gpu_to_cpu
else
(
"CPU"
,
"GPU"
)
else
(
"CPU"
,
"GPU"
)
)
)
assert
(
assert
finished
[
0
].
transfer_size
==
(
finished
[
0
].
transfer_size
len
(
gpu_blocks
)
*
handler
.
group_block_size_in_bytes
[
0
]
==
handler
.
total_block_size_in_bytes
*
handler
.
dst_block_size_factor
*
len
(
dst_blocks
)
)
)
assert
finished
[
0
].
transfer_time
>
0
assert
finished
[
0
].
transfer_time
>
0
assert
finished
[
0
].
transfer_time
<
(
time
.
time
()
-
start_time
)
assert
finished
[
0
].
transfer_time
<
(
time
.
time
()
-
start_time
)
...
@@ -192,19 +149,23 @@ def test_transfer(
...
@@ -192,19 +149,23 @@ def test_transfer(
time
.
sleep
(
0.1
)
time
.
sleep
(
0.1
)
# verify src tensors did not change
# verify src tensors did not change
for
orig_tensor
,
tensor
in
zip
(
orig_src_
cache
s
,
handler
.
src_tensors
):
for
orig_tensor
,
tensor
in
zip
(
orig_src_
tensor
s
,
handler
.
src_tensors
):
assert
torch
.
equal
(
orig_tensor
,
tensor
)
assert
torch
.
equal
(
orig_tensor
,
tensor
)
# verify dst tensors
# verify dst tensors at gpu-page granularity.
for
dst_block
in
range
(
dst_size_in_kernel_blocks
):
for
src_tensor
,
dst_tensor
,
orig_dst_tensor
in
zip
(
src_block_candidate
=
dst_to_src
.
get
(
dst_block
)
handler
.
src_tensors
,
for
src_cache
,
dst_cache
,
orig_dst_cache
in
zip
(
handler
.
dst_tensors
,
handler
.
src_tensors
,
orig_dst_tensors
,
handler
.
dst_tensors
,
):
orig_dst_caches
,
# view both GPU and CPU tensors as (n, gpu_page_size_bytes) for comparison.
):
src_view
=
src_tensor
.
view
(
-
1
,
gpu_page_size_bytes
)
if
src_block_candidate
is
not
None
:
dst_view
=
dst_tensor
.
view
(
-
1
,
gpu_page_size_bytes
)
expected_value
=
src_cache
[
src_block_candidate
]
orig_dst_view
=
orig_dst_tensor
.
view
(
-
1
,
gpu_page_size_bytes
)
for
dst_sub_block
in
range
(
num_dst_sub_blocks
):
src_sub_block
=
dst_to_src
.
get
(
dst_sub_block
)
if
src_sub_block
is
not
None
:
expected
=
src_view
[
src_sub_block
]
else
:
else
:
expected
_value
=
orig_dst_
cache
[
dst
_block
]
expected
=
orig_dst_
view
[
dst_sub
_block
]
torch
.
testing
.
assert_close
(
dst_
cache
[
dst
_block
].
cpu
(),
expected
_value
.
cpu
())
torch
.
testing
.
assert_close
(
dst_
view
[
dst_sub
_block
].
cpu
(),
expected
.
cpu
())
vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
View file @
7cc302dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
replace
import
torch
import
torch
...
@@ -18,7 +19,17 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
...
@@ -18,7 +19,17 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.kv_offload.spec
import
OffloadingSpec
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
MambaSpec
,
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.kv_offload.spec
import
(
CanonicalKVCacheRef
,
CanonicalKVCaches
,
CanonicalKVCacheTensor
,
OffloadingSpec
,
)
from
vllm.v1.kv_offload.worker.worker
import
(
from
vllm.v1.kv_offload.worker.worker
import
(
OffloadingWorker
,
OffloadingWorker
,
TransferSpec
,
TransferSpec
,
...
@@ -53,17 +64,13 @@ class OffloadingConnectorWorker:
...
@@ -53,17 +64,13 @@ class OffloadingConnectorWorker:
self
.
_job_counter
=
job_id
+
1
self
.
_job_counter
=
job_id
+
1
return
job_id
return
job_id
def
_register_handlers
(
def
_register_handlers
(
self
,
kv_caches
:
CanonicalKVCaches
):
self
,
for
src_cls
,
dst_cls
,
handler
in
self
.
spec
.
get_handlers
(
kv_caches
):
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]],
):
for
src_cls
,
dst_cls
,
handler
in
self
.
spec
.
get_handlers
(
kv_caches
,
attn_backends
):
self
.
worker
.
register_handler
(
src_cls
,
dst_cls
,
handler
)
self
.
worker
.
register_handler
(
src_cls
,
dst_cls
,
handler
)
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
|
list
[
torch
.
Tensor
]]
):
layer_names
=
list
(
kv_caches
.
keys
())
layer_names
=
list
(
kv_caches
.
keys
())
layers
=
get_layers_from_vllm_config
(
layers
=
get_layers_from_vllm_config
(
self
.
spec
.
vllm_config
,
self
.
spec
.
vllm_config
,
...
@@ -73,16 +80,221 @@ class OffloadingConnectorWorker:
...
@@ -73,16 +80,221 @@ class OffloadingConnectorWorker:
attn_backends
=
{
attn_backends
=
{
layer_name
:
layers
[
layer_name
].
get_attn_backend
()
layer_name
:
layers
[
layer_name
].
get_attn_backend
()
for
layer_name
in
layer_names
for
layer_name
in
layer_names
if
layer_name
in
layers
}
}
self
.
_register_handlers
(
kv_caches
,
attn_backends
)
# layer_name -> list of matching KV cache tensors
# such that each tensor starts with the num_blocks dimension.
# FlashAttention layers which use the (2, num_blocks, ...) layout
# will possibly map to 2 tensors, one per K and one per V.
# All other layers will probably map to a single tensor.
tensors_per_block
:
dict
[
str
,
tuple
[
torch
.
Tensor
,
...]]
=
{}
# layer_name -> size of (un-padded) page in bytes
unpadded_page_size_bytes
:
dict
[
str
,
int
]
=
{}
# layer_name -> size of page in bytes
page_size_bytes
:
dict
[
str
,
int
]
=
{}
for
kv_cache_group
in
self
.
spec
.
kv_cache_config
.
kv_cache_groups
:
group_layer_names
=
kv_cache_group
.
layer_names
group_kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
if
isinstance
(
group_kv_cache_spec
,
UniformTypeKVCacheSpecs
):
per_layer_specs
=
group_kv_cache_spec
.
kv_cache_specs
else
:
per_layer_specs
=
{}
for
layer_name
in
group_layer_names
:
layer_kv_cache_spec
=
per_layer_specs
.
get
(
layer_name
,
group_kv_cache_spec
)
if
isinstance
(
layer_kv_cache_spec
,
AttentionSpec
):
layer_kv_cache
=
kv_caches
[
layer_name
]
assert
isinstance
(
layer_kv_cache
,
torch
.
Tensor
)
assert
layer_kv_cache
.
storage_offset
()
==
0
# get the logical dimension for num_blocks
test_shape
=
attn_backends
[
layer_name
].
get_kv_cache_shape
(
num_blocks
=
1234
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
256
,
)
num_blocks_logical_dim
=
test_shape
.
index
(
1234
)
# sort the logical dimensions by stride (high to low)
# to get a physical-to-logical mapping:
# physical_to_logical[physical_pos] = logical_dim
logical_strides
=
layer_kv_cache
.
stride
()
physical_to_logical
=
sorted
(
range
(
len
(
logical_strides
)),
key
=
lambda
idx
:
logical_strides
[
idx
],
reverse
=
True
,
)
num_blocks_physical_dim
=
physical_to_logical
.
index
(
num_blocks_logical_dim
)
if
num_blocks_physical_dim
==
0
:
num_blocks
=
layer_kv_cache
.
shape
[
num_blocks_logical_dim
]
storage
=
layer_kv_cache
.
untyped_storage
()
page
=
layer_kv_cache_spec
.
page_size_bytes
tensors_per_block
[
layer_name
]
=
(
torch
.
tensor
(
[],
dtype
=
torch
.
int8
,
device
=
layer_kv_cache
.
device
,
)
.
set_
(
storage
)
.
view
(
num_blocks
,
page
),
)
page_size_bytes
[
layer_name
]
=
(
layer_kv_cache_spec
.
page_size_bytes
)
else
:
# Flash Attention case: (2, num_blocks, ...)
assert
test_shape
[
0
]
==
2
assert
physical_to_logical
[
0
]
==
0
assert
num_blocks_physical_dim
==
1
# unbind the tensor to separate K and V tensors
num_blocks
=
layer_kv_cache
.
shape
[
num_blocks_logical_dim
]
half_page_size
=
layer_kv_cache_spec
.
page_size_bytes
//
2
storage
=
layer_kv_cache
.
untyped_storage
()
raw
=
(
torch
.
tensor
(
[],
dtype
=
torch
.
int8
,
device
=
layer_kv_cache
.
device
,
)
.
set_
(
storage
)
.
view
(
2
,
num_blocks
,
half_page_size
)
)
tensors_per_block
[
layer_name
]
=
tuple
(
raw
.
unbind
(
0
))
page_size_bytes
[
layer_name
]
=
half_page_size
unpadded_page_size_bytes
[
layer_name
]
=
page_size_bytes
[
layer_name
]
elif
isinstance
(
layer_kv_cache_spec
,
MambaSpec
):
state_tensors
=
kv_caches
[
layer_name
]
assert
isinstance
(
state_tensors
,
list
)
# re-construct the raw (num_blocks, page_size) tensor
# from the first state tensor
assert
len
(
state_tensors
)
>
0
first_state_tensor
=
state_tensors
[
0
]
assert
first_state_tensor
.
storage_offset
()
==
0
num_blocks
=
first_state_tensor
.
shape
[
0
]
tensor
=
(
torch
.
tensor
(
[],
dtype
=
torch
.
int8
,
device
=
first_state_tensor
.
device
,
)
.
set_
(
first_state_tensor
.
untyped_storage
())
.
view
((
num_blocks
,
layer_kv_cache_spec
.
page_size_bytes
))
)
tensors_per_block
[
layer_name
]
=
(
tensor
,)
page_size_bytes
[
layer_name
]
=
layer_kv_cache_spec
.
page_size_bytes
unpadded_page_size_bytes
[
layer_name
]
=
replace
(
layer_kv_cache_spec
,
page_size_padded
=
None
).
page_size_bytes
else
:
raise
NotImplementedError
block_tensors
:
list
[
CanonicalKVCacheTensor
]
=
[]
block_data_refs
:
dict
[
str
,
list
[
CanonicalKVCacheRef
]]
=
defaultdict
(
list
)
for
kv_cache_tensor
in
self
.
spec
.
kv_cache_config
.
kv_cache_tensors
:
tensor_layer_names
=
kv_cache_tensor
.
shared_by
# verify all layers in the group reference the exact same tensors
assert
len
({
len
(
tensors_per_block
[
n
])
for
n
in
tensor_layer_names
})
==
1
assert
(
len
({
tensors_per_block
[
n
][
0
].
data_ptr
()
for
n
in
tensor_layer_names
})
==
1
)
assert
(
len
({
tensors_per_block
[
n
][
0
].
stride
()
for
n
in
tensor_layer_names
})
==
1
)
# pick the first layer to represent the group
first_layer_name
=
tensor_layer_names
[
0
]
for
tensor
in
tensors_per_block
[
first_layer_name
]:
block_tensors
.
append
(
CanonicalKVCacheTensor
(
tensor
=
tensor
,
page_size_bytes
=
page_size_bytes
[
first_layer_name
],
)
)
curr_tensor_idx
=
len
(
block_tensors
)
-
1
for
layer_name
in
tensor_layer_names
:
block_data_refs
[
layer_name
].
append
(
CanonicalKVCacheRef
(
tensor_idx
=
curr_tensor_idx
,
page_size_bytes
=
(
unpadded_page_size_bytes
[
layer_name
]),
)
)
group_data_refs
:
list
[
list
[
CanonicalKVCacheRef
]]
=
[]
for
kv_cache_group
in
self
.
spec
.
kv_cache_config
.
kv_cache_groups
:
group_refs
:
list
[
CanonicalKVCacheRef
]
=
[]
for
layer_name
in
kv_cache_group
.
layer_names
:
group_refs
+=
block_data_refs
[
layer_name
]
group_data_refs
.
append
(
group_refs
)
canonical_kv_caches
=
CanonicalKVCaches
(
tensors
=
block_tensors
,
group_data_refs
=
group_data_refs
,
)
self
.
_register_handlers
(
canonical_kv_caches
)
def
register_cross_layers_kv_cache
(
def
register_cross_layers_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
):
):
cross_layer_name
=
"ALL_LAYERS"
# verify that num_blocks is at physical position 0 in the cross-layers
kv_caches
=
{
cross_layer_name
:
kv_cache
}
# tensor layout.
attn_backends
=
{
cross_layer_name
:
attn_backend
}
test_shape
=
attn_backend
.
get_kv_cache_shape
(
self
.
_register_handlers
(
kv_caches
,
attn_backends
)
num_blocks
=
1234
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
256
)
num_blocks_logical_dim
=
test_shape
.
index
(
1234
)
+
1
physical_to_logical
=
attn_backend
.
get_kv_cache_stride_order
(
include_num_layers_dimension
=
True
)
num_blocks_physical_dim
=
physical_to_logical
.
index
(
num_blocks_logical_dim
)
assert
num_blocks_physical_dim
==
0
kv_cache_groups
=
self
.
spec
.
kv_cache_config
.
kv_cache_groups
assert
len
(
kv_cache_groups
)
==
1
kv_cache_spec
=
kv_cache_groups
[
0
].
kv_cache_spec
num_layers
=
len
(
kv_cache_groups
[
0
].
layer_names
)
page_size_bytes
=
kv_cache_spec
.
page_size_bytes
*
num_layers
assert
kv_cache
.
storage_offset
()
==
0
storage
=
kv_cache
.
untyped_storage
()
assert
len
(
storage
)
%
page_size_bytes
==
0
num_blocks
=
len
(
storage
)
//
page_size_bytes
tensor
=
(
torch
.
tensor
(
[],
dtype
=
torch
.
int8
,
device
=
kv_cache
.
device
,
)
.
set_
(
storage
)
.
view
(
num_blocks
,
page_size_bytes
)
)
kv_cache_tensor
=
CanonicalKVCacheTensor
(
tensor
=
tensor
,
page_size_bytes
=
page_size_bytes
)
# in cross layers layout, there's currently only a single group
kv_cache_data_ref
=
CanonicalKVCacheRef
(
tensor_idx
=
0
,
page_size_bytes
=
page_size_bytes
)
canonical_kv_caches
=
CanonicalKVCaches
(
tensors
=
[
kv_cache_tensor
],
group_data_refs
=
[[
kv_cache_data_ref
]]
)
self
.
_register_handlers
(
canonical_kv_caches
)
def
handle_preemptions
(
self
,
kv_connector_metadata
:
OffloadingConnectorMetadata
):
def
handle_preemptions
(
self
,
kv_connector_metadata
:
OffloadingConnectorMetadata
):
for
job_id
,
transfer_spec
in
self
.
_unsubmitted_store_jobs
:
for
job_id
,
transfer_spec
in
self
.
_unsubmitted_store_jobs
:
...
...
vllm/v1/attention/backends/utils.py
View file @
7cc302dd
...
@@ -78,9 +78,10 @@ def get_kv_cache_layout():
...
@@ -78,9 +78,10 @@ def get_kv_cache_layout():
return
cache_layout
return
cache_layout
def
set_kv_cache_layout
(
cache_layout
:
KVCacheLayoutType
):
def
set_kv_cache_layout
(
cache_layout
:
KVCacheLayoutType
|
None
):
global
_KV_CACHE_LAYOUT_OVERRIDE
global
_KV_CACHE_LAYOUT_OVERRIDE
_KV_CACHE_LAYOUT_OVERRIDE
=
cache_layout
_KV_CACHE_LAYOUT_OVERRIDE
=
cache_layout
get_kv_cache_layout
.
cache_clear
()
@
dataclass
@
dataclass
...
...
vllm/v1/kv_offload/cpu/spec.py
View file @
7cc302dd
...
@@ -2,17 +2,14 @@
...
@@ -2,17 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterator
from
collections.abc
import
Iterator
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_offload.abstract
import
LoadStoreSpec
,
OffloadingManager
from
vllm.v1.kv_offload.abstract
import
LoadStoreSpec
,
OffloadingManager
from
vllm.v1.kv_offload.cpu.manager
import
CPUOffloadingManager
from
vllm.v1.kv_offload.cpu.manager
import
CPUOffloadingManager
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.reuse_manager
import
FilterReusedOffloadingManager
from
vllm.v1.kv_offload.reuse_manager
import
FilterReusedOffloadingManager
from
vllm.v1.kv_offload.spec
import
OffloadingSpec
from
vllm.v1.kv_offload.spec
import
CanonicalKVCaches
,
OffloadingSpec
from
vllm.v1.kv_offload.worker.cpu_gpu
import
CpuGpuOffloadingHandlers
from
vllm.v1.kv_offload.worker.cpu_gpu
import
CpuGpuOffloadingHandlers
from
vllm.v1.kv_offload.worker.worker
import
OffloadingHandler
from
vllm.v1.kv_offload.worker.worker
import
OffloadingHandler
...
@@ -90,9 +87,7 @@ class CPUOffloadingSpec(OffloadingSpec):
...
@@ -90,9 +87,7 @@ class CPUOffloadingSpec(OffloadingSpec):
return
self
.
_manager
return
self
.
_manager
def
get_handlers
(
def
get_handlers
(
self
,
self
,
kv_caches
:
CanonicalKVCaches
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]],
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
if
not
self
.
_handlers
:
if
not
self
.
_handlers
:
if
not
current_platform
.
is_cuda_alike
():
if
not
current_platform
.
is_cuda_alike
():
...
@@ -100,15 +95,10 @@ class CPUOffloadingSpec(OffloadingSpec):
...
@@ -100,15 +95,10 @@ class CPUOffloadingSpec(OffloadingSpec):
"CPU Offloading is currently only supported on CUDA-alike GPUs"
"CPU Offloading is currently only supported on CUDA-alike GPUs"
)
)
assert
len
(
self
.
gpu_block_size
)
==
1
gpu_block_size
=
self
.
gpu_block_size
[
0
]
self
.
_handlers
=
CpuGpuOffloadingHandlers
(
self
.
_handlers
=
CpuGpuOffloadingHandlers
(
attn_backends
=
attn_backends
,
kv_caches
=
kv_caches
,
gpu_block_size
=
gpu_block_size
,
block_size_factor
=
self
.
block_size_factor
,
cpu_block_size
=
gpu_block_size
*
self
.
block_size_factor
,
num_cpu_blocks
=
self
.
num_blocks
,
num_cpu_blocks
=
self
.
num_blocks
,
gpu_caches
=
kv_caches
,
)
)
assert
self
.
_handlers
is
not
None
assert
self
.
_handlers
is
not
None
...
...
vllm/v1/kv_offload/spec.py
View file @
7cc302dd
...
@@ -2,12 +2,12 @@
...
@@ -2,12 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterator
from
collections.abc
import
Iterator
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.kv_offload.abstract
import
LoadStoreSpec
,
OffloadingManager
from
vllm.v1.kv_offload.abstract
import
LoadStoreSpec
,
OffloadingManager
from
vllm.v1.kv_offload.worker.worker
import
OffloadingHandler
from
vllm.v1.kv_offload.worker.worker
import
OffloadingHandler
...
@@ -18,6 +18,56 @@ if TYPE_CHECKING:
...
@@ -18,6 +18,56 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
dataclass
class
CanonicalKVCacheTensor
:
"""
A canonicalized KV cache tensor whose first dimension is num_blocks.
For attention backends where the raw tensor has num_blocks at a
non-leading physical dimension (e.g. FlashAttention's
(2, num_blocks, ...) layout), the tensor is split so that each
resulting CanonicalKVCacheTensor starts with (num_blocks, ...).
"""
# The KV cache tensor with shape (num_blocks, ...)
tensor
:
torch
.
Tensor
# The (possibly padded) page size per block in bytes
page_size_bytes
:
int
@
dataclass
class
CanonicalKVCacheRef
:
"""
Per-layer (or group of layers) reference to a specific (by index)
CanonicalKVCacheTensor and records the un-padded page size used by that layer.
"""
# Index into the list of CanonicalKVCacheTensor objects
tensor_idx
:
int
# The un-padded page size per block in bytes
page_size_bytes
:
int
@
dataclass
class
CanonicalKVCaches
:
"""
Canonicalized block-level representation of the KV caches.
Composed of:
- Unique list of KV cache data tensors,
each with shape (num_blocks, page_size_in_bytes) and int8 dtype.
- Per-group data references of the tensors.
i.e. how each KV cache group maps to the tensors.
"""
# Ordered list of unique block tensors, each with shape
# (num_blocks, ...).
tensors
:
list
[
CanonicalKVCacheTensor
]
# Per-KV-cache-group list of data references that map each layer
# in the group to the appropriate entry in the tensors list.
group_data_refs
:
list
[
list
[
CanonicalKVCacheRef
]]
class
OffloadingSpec
(
ABC
):
class
OffloadingSpec
(
ABC
):
"""Spec for an offloading connector"""
"""Spec for an offloading connector"""
...
@@ -73,16 +123,13 @@ class OffloadingSpec(ABC):
...
@@ -73,16 +123,13 @@ class OffloadingSpec(ABC):
@
abstractmethod
@
abstractmethod
def
get_handlers
(
def
get_handlers
(
self
,
self
,
kv_caches
:
CanonicalKVCaches
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]],
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
"""
"""
Get offloading handlers along with their respective src and dst types.
Get offloading handlers along with their respective src and dst types.
Args:
Args:
kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
kv_caches: Canonicalized KV caches.
attn_backends: A dictionary of layer_name -> AttentionBackend.
Yields:
Yields:
Tuples of (src_type, dst_type, offloading_handler).
Tuples of (src_type, dst_type, offloading_handler).
...
...
vllm/v1/kv_offload/worker/cpu_gpu.py
View file @
7cc302dd
...
@@ -9,8 +9,8 @@ import torch
...
@@ -9,8 +9,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.kv_offload.mediums
import
BlockIDsLoadStoreSpec
from
vllm.v1.kv_offload.mediums
import
BlockIDsLoadStoreSpec
from
vllm.v1.kv_offload.spec
import
CanonicalKVCacheRef
,
CanonicalKVCaches
from
vllm.v1.kv_offload.worker.worker
import
(
from
vllm.v1.kv_offload.worker.worker
import
(
OffloadingHandler
,
OffloadingHandler
,
TransferResult
,
TransferResult
,
...
@@ -73,39 +73,72 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
...
@@ -73,39 +73,72 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
def
__init__
(
def
__init__
(
self
,
self
,
src_tensors
:
list
[
torch
.
Tensor
],
gpu_tensors
:
list
[
torch
.
Tensor
],
dst_tensors
:
list
[
torch
.
Tensor
],
cpu_tensors
:
list
[
torch
.
Tensor
],
src_block_size_factor
:
int
,
block_size_factor
:
int
,
dst_block_size_factor
:
int
,
kv_cache_groups_data_refs
:
list
[
list
[
CanonicalKVCacheRef
]],
gpu_to_cpu
:
bool
,
):
):
"""
"""
Initialize a SingleDirectionOffloadingHandler.
Initialize a SingleDirectionOffloadingHandler.
Args:
Args:
src
_tensors: list of KV cache tensors
to copy from
.
gpu
_tensors: list of
GPU
KV cache tensors.
dst_tensors: list of KV cache tensors to copy to
.
Each of shape (num_gpu_blocks, gpu_page_size_bytes) with dtype int8
.
Order should match src_
tensors.
cpu_tensors: list of CPU KV cache
tensors.
src_block_size_factor: The number of kernel blocks
Each of shape (num_cpu_blocks, cpu_page_size_bytes) with dtype int8.
per KV block in a source
tensor.
Order should match gpu_
tensor
s
.
dst_block_size_factor: The number o
f
k
er
nel blocks
kv_cache_groups_data_refs: list of CanonicalKVCacheRe
f
p
er
group.
per KV block in a destination tensor
.
gpu_to_cpu: if True, transfer from GPU to CPU; otherwise CPU to GPU
.
"""
"""
assert
len
(
src_tensors
)
==
len
(
dst_tensors
)
assert
len
(
gpu_tensors
)
==
len
(
cpu_tensors
)
assert
len
(
gpu_tensors
)
>
0
# assert a single KV group until transfer_async supports multiple groups
assert
len
(
kv_cache_groups_data_refs
)
==
1
# assert input tensors are as expected
for
gpu_tensor
,
cpu_tensor
in
zip
(
gpu_tensors
,
cpu_tensors
):
assert
gpu_tensor
.
dtype
==
torch
.
int8
assert
gpu_tensor
.
ndim
==
2
assert
gpu_tensor
.
is_cuda
assert
cpu_tensor
.
dtype
==
torch
.
int8
assert
cpu_tensor
.
ndim
==
2
assert
cpu_tensor
.
device
.
type
==
"cpu"
_
,
gpu_page_size
=
gpu_tensor
.
shape
_
,
cpu_page_size
=
cpu_tensor
.
shape
assert
cpu_page_size
==
gpu_page_size
*
block_size_factor
self
.
src_tensors
:
list
[
torch
.
Tensor
]
=
(
gpu_tensors
if
gpu_to_cpu
else
cpu_tensors
)
self
.
dst_tensors
:
list
[
torch
.
Tensor
]
=
(
cpu_tensors
if
gpu_to_cpu
else
gpu_tensors
)
self
.
gpu_to_cpu
:
bool
=
gpu_to_cpu
self
.
src_tensors
:
list
[
torch
.
Tensor
]
=
src_tensors
# GPU blocks may be smaller
self
.
dst_tensors
:
list
[
torch
.
Tensor
]
=
dst_tensors
# cpu_page_size = gpu_page_size * block_size_factor.
min_block_size_factor
=
min
(
src_block_size_factor
,
dst_block_size_factor
)
self
.
src_block_size_factor
=
1
if
self
.
gpu_to_cpu
else
block_size_factor
self
.
src_block_size_factor
:
int
=
src_block_size_factor
//
min_block_size_factor
self
.
dst_block_size_factor
=
block_size_factor
if
self
.
gpu_to_cpu
else
1
self
.
dst_block_size_factor
:
int
=
dst_block_size_factor
//
min_block_size_factor
self
.
block
_
size
_
in
_
byte
s
=
[
# per-tensor
block
size
in
byte
tensor
.
element_size
()
*
tensor
.
stride
(
0
)
*
min_block_size_factor
self
.
tensor_block_size_in_bytes
=
[
for
tensor
in
src
_tensors
gpu_tensor
.
shape
[
1
]
for
gpu_
tensor
in
gpu
_tensors
]
]
self
.
total_block_size_in_bytes
=
sum
(
self
.
block_size_in_bytes
)
assert
len
(
src_tensors
)
>
0
# per-group block size in bytes
self
.
gpu_to_cpu
:
bool
=
self
.
src_tensors
[
0
].
is_cuda
self
.
group_block_size_in_bytes
=
[]
for
kv_cache_group_data_refs
in
kv_cache_groups_data_refs
:
group_block_size_in_bytes
=
0
for
kv_cache_data_ref
in
kv_cache_group_data_refs
:
# TODO(orozery): use kv_cache_data_ref.page_size_bytes
# once swap_blocks support it
group_block_size_in_bytes
+=
self
.
tensor_block_size_in_bytes
[
kv_cache_data_ref
.
tensor_idx
]
self
.
group_block_size_in_bytes
.
append
(
group_block_size_in_bytes
)
self
.
transfer_type
=
(
"GPU"
,
"CPU"
)
if
self
.
gpu_to_cpu
else
(
"CPU"
,
"GPU"
)
self
.
transfer_type
=
(
"GPU"
,
"CPU"
)
if
self
.
gpu_to_cpu
else
(
"CPU"
,
"GPU"
)
# job_id -> event
# job_id -> event
self
.
_transfer_events
:
dict
[
int
,
torch
.
Event
]
=
{}
self
.
_transfer_events
:
dict
[
int
,
torch
.
Event
]
=
{}
...
@@ -167,7 +200,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
...
@@ -167,7 +200,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
for
src_tensor
,
dst_tensor
,
block_size_in_bytes
in
zip
(
for
src_tensor
,
dst_tensor
,
block_size_in_bytes
in
zip
(
self
.
src_tensors
,
self
.
src_tensors
,
self
.
dst_tensors
,
self
.
dst_tensors
,
self
.
block_size_in_bytes
,
self
.
tensor_
block_size_in_bytes
,
):
):
ops
.
swap_blocks
(
ops
.
swap_blocks
(
src_tensor
,
src_tensor
,
...
@@ -184,7 +217,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
...
@@ -184,7 +217,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
stream
=
stream
,
stream
=
stream
,
start_event
=
start_event
,
start_event
=
start_event
,
end_event
=
end_event
,
end_event
=
end_event
,
num_bytes
=
dst_sub_block_count
*
self
.
total
_block_size_in_bytes
,
num_bytes
=
dst_sub_block_count
*
self
.
group
_block_size_in_bytes
[
0
]
,
)
)
)
)
...
@@ -223,102 +256,42 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
...
@@ -223,102 +256,42 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
class
CpuGpuOffloadingHandlers
:
class
CpuGpuOffloadingHandlers
:
def
__init__
(
def
__init__
(
self
,
self
,
gpu_block_size
:
int
,
kv_caches
:
CanonicalKVCaches
,
cpu_
block_size
:
int
,
block_size
_factor
:
int
,
num_cpu_blocks
:
int
,
num_cpu_blocks
:
int
,
gpu_caches
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]],
):
):
assert
gpu_caches
assert
cpu_block_size
%
gpu_block_size
==
0
# find kernel block size and determine layout per each gpu tensor
kernel_block_size
:
int
|
None
=
None
# list of (gpu_tensor, split_k_and_v)
parsed_gpu_tensors
:
list
[
tuple
[
torch
.
Tensor
,
bool
]]
=
[]
for
layer_name
,
gpu_tensor
in
gpu_caches
.
items
():
gpu_shape
=
gpu_tensor
.
shape
attn_backend
=
attn_backends
[
layer_name
]
test_shape
=
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
1234
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
256
)
has_layers_dim
=
False
split_k_and_v
=
False
if
len
(
gpu_shape
)
!=
len
(
test_shape
):
# cross-layers tensor
# shape is (num_blocks, ...)
assert
len
(
gpu_shape
)
==
len
(
test_shape
)
+
1
has_layers_dim
=
True
# prepend a dummy num_layers=80 to test_shape
test_shape
=
(
80
,)
+
test_shape
elif
test_shape
[
0
]
!=
1234
:
# shape should be (2, num_blocks, ...)
assert
test_shape
[
0
]
==
2
assert
test_shape
[
1
]
==
1234
assert
gpu_shape
[
0
]
==
2
split_k_and_v
=
True
if
has_layers_dim
:
# in the cross layers case, the registered kv cache tensor
# shape matches the physical layout, whereas test_shape
# is the logical layout.
# To match them, we need to permute test_shape
try
:
kv_cache_stride_order
=
attn_backend
.
get_kv_cache_stride_order
(
include_num_layers_dimension
=
has_layers_dim
)
assert
len
(
kv_cache_stride_order
)
==
len
(
gpu_shape
)
except
(
AttributeError
,
NotImplementedError
):
kv_cache_stride_order
=
tuple
(
range
(
len
(
gpu_shape
)))
test_shape
=
tuple
(
test_shape
[
i
]
for
i
in
kv_cache_stride_order
)
# find block_size (16) dimension index
block_size_idx
=
test_shape
.
index
(
16
)
if
kernel_block_size
is
not
None
:
assert
kernel_block_size
==
gpu_shape
[
block_size_idx
]
else
:
kernel_block_size
=
gpu_shape
[
block_size_idx
]
assert
gpu_block_size
%
kernel_block_size
==
0
parsed_gpu_tensors
.
append
((
gpu_tensor
,
split_k_and_v
))
assert
kernel_block_size
is
not
None
cpu_block_size_factor
=
cpu_block_size
//
kernel_block_size
gpu_block_size_factor
=
gpu_block_size
//
kernel_block_size
num_cpu_kernel_blocks
=
num_cpu_blocks
*
cpu_block_size_factor
# allocate cpu tensors
pin_memory
=
is_pin_memory_available
()
pin_memory
=
is_pin_memory_available
()
logger
.
info
(
"Allocating %d CPU tensors..."
,
len
(
parsed_gpu_
tensors
))
logger
.
info
(
"Allocating %d CPU tensors..."
,
len
(
kv_caches
.
tensors
))
gpu_tensors
:
list
[
torch
.
Tensor
]
=
[]
gpu_tensors
:
list
[
torch
.
Tensor
]
=
[]
cpu_tensors
:
list
[
torch
.
Tensor
]
=
[]
cpu_tensors
:
list
[
torch
.
Tensor
]
=
[]
for
gpu_tensor
,
split_k_and_v
in
parsed_gpu_tensors
:
for
kv_cache_tensor
in
kv_caches
.
tensors
:
cpu_shape
=
list
(
gpu_tensor
.
shape
)
gpu_page_size_bytes
=
kv_cache_tensor
.
page_size_bytes
cpu_shape
[
1
if
split_k_and_v
else
0
]
=
num_cpu_kernel_blocks
gpu_tensor
=
kv_cache_tensor
.
tensor
.
view
(
torch
.
int8
).
view
(
(
-
1
,
gpu_page_size_bytes
)
logger
.
debug
(
"Allocating CPU tensor of shape %r"
,
cpu_shape
)
)
cpu_page_size_bytes
=
gpu_page_size_bytes
*
block_size_factor
cpu_tensor
=
torch
.
zeros
(
cpu_tensor
=
torch
.
zeros
(
cpu_shape
,
(
num_cpu_blocks
,
cpu_page_size_bytes
)
,
dtype
=
gpu_tensor
.
dtype
,
dtype
=
torch
.
int8
,
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
gpu_tensors
.
extend
(
gpu_tensor
.
unbind
(
0
)
if
split_k_and_v
else
[
gpu_tensor
]
)
gpu_tensors
.
append
(
gpu_tensor
)
cpu_tensors
.
extend
(
cpu_tensor
.
unbind
(
0
)
if
split_k_and_v
else
[
cpu_tensor
]
)
cpu_tensors
.
append
(
cpu_tensor
)
self
.
gpu_to_cpu_handler
=
SingleDirectionOffloadingHandler
(
self
.
gpu_to_cpu_handler
=
SingleDirectionOffloadingHandler
(
src_tensors
=
gpu_tensors
,
gpu_tensors
=
gpu_tensors
,
dst_tensors
=
cpu_tensors
,
cpu_tensors
=
cpu_tensors
,
src_block_size_factor
=
gpu_block_size_factor
,
block_size_factor
=
block_size_factor
,
dst_block_size_factor
=
cpu_block_size_factor
,
kv_cache_groups_data_refs
=
kv_caches
.
group_data_refs
,
gpu_to_cpu
=
True
,
)
)
self
.
cpu_to_gpu_handler
=
SingleDirectionOffloadingHandler
(
self
.
cpu_to_gpu_handler
=
SingleDirectionOffloadingHandler
(
src_tensors
=
cpu_tensors
,
gpu_tensors
=
gpu_tensors
,
dst_tensors
=
gpu_tensors
,
cpu_tensors
=
cpu_tensors
,
src_block_size_factor
=
cpu_block_size_factor
,
block_size_factor
=
block_size_factor
,
dst_block_size_factor
=
gpu_block_size_factor
,
kv_cache_groups_data_refs
=
kv_caches
.
group_data_refs
,
gpu_to_cpu
=
False
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment