Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64746471
Unverified
Commit
64746471
authored
Nov 20, 2025
by
Or Ozeri
Committed by
GitHub
Nov 20, 2025
Browse files
[KVConnector][Core] Support cross-layer KV blocks (#27743)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
e5bfcb6a
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
453 additions
and
90 deletions
+453
-90
tests/v1/kv_connector/unit/test_offloading_connector.py
tests/v1/kv_connector/unit/test_offloading_connector.py
+6
-2
tests/v1/kv_offload/test_cpu_offloading.py
tests/v1/kv_offload/test_cpu_offloading.py
+93
-52
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+4
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+28
-1
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+31
-2
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+38
-5
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+10
-2
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+10
-2
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+9
-0
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+5
-1
vllm/v1/kv_offload/cpu.py
vllm/v1/kv_offload/cpu.py
+5
-12
vllm/v1/kv_offload/spec.py
vllm/v1/kv_offload/spec.py
+5
-1
vllm/v1/kv_offload/worker/cpu_gpu.py
vllm/v1/kv_offload/worker/cpu_gpu.py
+10
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+34
-7
vllm/v1/worker/kv_connector_model_runner_mixin.py
vllm/v1/worker/kv_connector_model_runner_mixin.py
+165
-0
No files found.
tests/v1/kv_connector/unit/test_offloading_connector.py
View file @
64746471
...
@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
...
@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
)
)
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.utils.hashing
import
sha256
from
vllm.utils.hashing
import
sha256
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.core.kv_cache_utils
import
(
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHash
,
get_request_block_hasher
,
get_request_block_hasher
,
...
@@ -92,7 +93,7 @@ class MockOffloadingSpec(OffloadingSpec):
...
@@ -92,7 +93,7 @@ class MockOffloadingSpec(OffloadingSpec):
return
self
.
manager
return
self
.
manager
def
get_handlers
(
def
get_handlers
(
self
,
_
self
,
_
,
__
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
yield
GPULoadStoreSpec
,
MockLoadStoreSpec
,
self
.
handler
yield
GPULoadStoreSpec
,
MockLoadStoreSpec
,
self
.
handler
yield
MockLoadStoreSpec
,
GPULoadStoreSpec
,
self
.
handler
yield
MockLoadStoreSpec
,
GPULoadStoreSpec
,
self
.
handler
...
@@ -138,7 +139,10 @@ class RequestRunner:
...
@@ -138,7 +139,10 @@ class RequestRunner:
self
.
worker_connector
=
OffloadingConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
self
.
worker_connector
=
OffloadingConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
# register worker kv_caches to enable OffloadingWorker creations
# register worker kv_caches to enable OffloadingWorker creations
self
.
worker_connector
.
register_kv_caches
(
kv_caches
=
{
"a"
:
torch
.
empty
(
0
)})
self
.
worker_connector
.
register_cross_layers_kv_cache
(
kv_cache
=
torch
.
empty
(
0
),
attn_backend
=
FlashAttentionBackend
,
)
# extract connector of scheduler
# extract connector of scheduler
scheduler_connector
=
self
.
scheduler
.
connector
scheduler_connector
=
self
.
scheduler
.
connector
...
...
tests/v1/kv_offload/test_cpu_offloading.py
View file @
64746471
...
@@ -12,8 +12,10 @@ from tqdm import tqdm
...
@@ -12,8 +12,10 @@ from tqdm import tqdm
from
vllm
import
LLM
,
SamplingParams
,
TokensPrompt
from
vllm
import
LLM
,
SamplingParams
,
TokensPrompt
from
vllm.config
import
KVEventsConfig
,
KVTransferConfig
from
vllm.config
import
KVEventsConfig
,
KVTransferConfig
from
vllm.distributed.kv_events
import
BlockStored
,
KVEventBatch
from
vllm.distributed.kv_events
import
BlockStored
,
KVEventBatch
from
vllm.utils.system_utils
import
set_env_var
CPU_BLOCK_SIZES
=
[
16
,
48
]
CPU_BLOCK_SIZES
=
[
48
]
ATTN_BACKENDS
=
[
"FLASH_ATTN"
,
"FLASHINFER"
]
class
MockSubscriber
:
class
MockSubscriber
:
...
@@ -63,8 +65,88 @@ class MockSubscriber:
...
@@ -63,8 +65,88 @@ class MockSubscriber:
self
.
sub
.
close
()
self
.
sub
.
close
()
def
_latency_test
(
llm
:
LLM
,
subscriber
:
MockSubscriber
):
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
num_times_cpu_better_than_cold
=
0
num_tests
=
10
total_cold_time
=
0.0
total_gpu_hit_time
=
0.0
total_cpu_hit_time
=
0.0
prompt_token_ids
=
[
0
]
*
10001
for
i
in
tqdm
(
range
(
num_tests
),
desc
=
"Running tests"
):
prompt_token_ids
[
0
]
=
i
prompts
=
[
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
)]
# run generation - this should trigger saving KV cache
start_time
=
time
.
time
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
)
cold_time
=
time
.
time
()
-
start_time
total_cold_time
+=
cold_time
# run generation again - should hit the GPU prefix cache
start_time
=
time
.
time
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
)
gpu_hit_time
=
time
.
time
()
-
start_time
total_gpu_hit_time
+=
gpu_hit_time
# reset prefix cache to avoid GPU hit.
llm
.
reset_prefix_cache
()
assert
subscriber
.
get_new_cpu_stored_events
()
# run generation again - this should trigger loading from CPU
start_time
=
time
.
time
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
)
cpu_hit_time
=
time
.
time
()
-
start_time
total_cpu_hit_time
+=
cpu_hit_time
if
cpu_hit_time
<
cold_time
:
num_times_cpu_better_than_cold
+=
1
print
(
"Average times:"
)
print
(
f
" Cold:
{
total_cold_time
*
1000
/
num_tests
:.
2
f
}
ms"
)
print
(
f
" GPU hit:
{
total_gpu_hit_time
*
1000
/
num_tests
:.
2
f
}
ms"
)
print
(
f
" CPU hit:
{
total_cpu_hit_time
*
1000
/
num_tests
:.
2
f
}
ms"
)
assert
num_times_cpu_better_than_cold
>=
0.8
*
num_tests
def
_accuracy_test
(
llm
:
LLM
,
subscriber
:
MockSubscriber
):
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
cpu_block_size
=
(
llm
.
llm_engine
.
vllm_config
.
kv_transfer_config
.
kv_connector_extra_config
[
"block_size"
]
)
subscriber
.
get_new_cpu_stored_events
()
# prepend prompt to be cpu block aligned
prompt
=
"Let's count to 10. One, two, three, four,"
while
(
len
(
llm
.
generate
(
prompt
,
use_tqdm
=
False
)[
0
].
prompt_token_ids
)
%
cpu_block_size
!=
0
):
prompt
=
". "
+
prompt
assert
subscriber
.
get_new_cpu_stored_events
()
test_count
=
100
success_count
=
0
for
i
in
range
(
test_count
):
if
(
llm
.
generate
(
prompt
,
sampling_params
,
use_tqdm
=
False
)[
0
].
outputs
[
0
].
text
==
" five"
):
success_count
+=
1
assert
success_count
>=
0.5
*
test_count
@
pytest
.
mark
.
parametrize
(
"cpu_block_size"
,
CPU_BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"cpu_block_size"
,
CPU_BLOCK_SIZES
)
def
test_cpu_offloading
(
cpu_block_size
:
int
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
ATTN_BACKENDS
)
def
test_cpu_offloading
(
cpu_block_size
:
int
,
attn_backend
:
str
)
->
None
:
"""
"""
Tests OffloadingConnector with CPUOffloadingSpec.
Tests OffloadingConnector with CPUOffloadingSpec.
"""
"""
...
@@ -92,61 +174,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
...
@@ -92,61 +174,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
topic
=
"test"
,
topic
=
"test"
,
)
)
llm
=
LLM
(
with
set_env_var
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
):
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
llm
=
LLM
(
gpu_memory_utilization
=
0.5
,
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
kv_events_config
=
kv_events_config
,
gpu_memory_utilization
=
0.5
,
kv_transfer_config
=
kv_transfer_config
,
kv_events_config
=
kv_events_config
,
)
kv_transfer_config
=
kv_transfer_config
,
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
1
)
events_endpoint
=
events_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
events_endpoint
=
events_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
subscriber
=
MockSubscriber
(
events_endpoint
,
topic
=
kv_events_config
.
topic
)
subscriber
=
MockSubscriber
(
events_endpoint
,
topic
=
kv_events_config
.
topic
)
try
:
try
:
num_times_cpu_better_than_cold
=
0
_latency_test
(
llm
,
subscriber
)
num_tests
=
10
_accuracy_test
(
llm
,
subscriber
)
total_cold_time
=
0.0
total_gpu_hit_time
=
0.0
total_cpu_hit_time
=
0.0
prompt_token_ids
=
[
0
]
*
10001
for
i
in
tqdm
(
range
(
num_tests
),
desc
=
"Running tests"
):
prompt_token_ids
[
0
]
=
i
prompts
=
[
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
)]
# run generation - this should trigger saving KV cache
start_time
=
time
.
time
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
)
cold_time
=
time
.
time
()
-
start_time
total_cold_time
+=
cold_time
# run generation again - should hit the GPU prefix cache
start_time
=
time
.
time
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
)
gpu_hit_time
=
time
.
time
()
-
start_time
total_gpu_hit_time
+=
gpu_hit_time
# reset prefix cache to avoid GPU hit.
llm
.
reset_prefix_cache
()
assert
subscriber
.
get_new_cpu_stored_events
()
# run generation again - this should trigger loading from CPU
start_time
=
time
.
time
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
False
)
cpu_hit_time
=
time
.
time
()
-
start_time
total_cpu_hit_time
+=
cpu_hit_time
if
cpu_hit_time
<
cold_time
:
num_times_cpu_better_than_cold
+=
1
print
(
"Average times:"
)
print
(
f
" Cold:
{
total_cold_time
*
1000
/
num_tests
:.
2
f
}
ms"
)
print
(
f
" GPU hit:
{
total_gpu_hit_time
*
1000
/
num_tests
:.
2
f
}
ms"
)
print
(
f
" CPU hit:
{
total_cpu_hit_time
*
1000
/
num_tests
:.
2
f
}
ms"
)
assert
num_times_cpu_better_than_cold
>=
0.8
*
num_tests
finally
:
finally
:
subscriber
.
close
()
subscriber
.
close
()
del
llm
del
llm
tests/v1/worker/test_gpu_model_runner.py
View file @
64746471
...
@@ -483,7 +483,10 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
...
@@ -483,7 +483,10 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# Permutation that gets you back to expected kv shape
# Permutation that gets you back to expected kv shape
for
test_stride
in
((
1
,
4
,
0
,
2
,
3
),
(
0
,
1
,
2
,
3
,
4
)):
for
test_stride
in
((
1
,
4
,
0
,
2
,
3
),
(
0
,
1
,
2
,
3
,
4
)):
def
rnd_stride_order
(
test_stride
=
test_stride
):
def
rnd_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
test_stride
=
test_stride
):
assert
not
include_num_layers_dimension
return
test_stride
return
test_stride
# Patch the attention backend class and re-trigger the KV cache creation
# Patch the attention backend class and re-trigger the KV cache creation
...
...
vllm/attention/backends/abstract.py
View file @
64746471
...
@@ -76,7 +76,34 @@ class AttentionBackend(ABC):
...
@@ -76,7 +76,34 @@ class AttentionBackend(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
int
,
...]:
"""
Get the physical (memory layout) ordering of the kv cache dimensions.
e.g. if the KV cache shape is
[2, num_blocks, block_size, num_heads, head_size],
and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
ordering of dimensions is
[num_blocks, num_heads, 2, block_size, head_size].
If this function is unimplemented / raises NotImplementedError,
the physical layout of the KV cache will match the logical shape.
Args:
include_num_layers_dimension: if True, includes an additional
num_layers dimension, which is assumed to be prepended
to the logical KV cache shape.
With the above example, a return value (2, 4, 0, 1, 3, 5)
corresponds to
[num_blocks, num_heads, num_layers, 2, block_size, head_size].
If an additional dimension is NOT included in the returned
tuple, the physical layout will not include a layers dimension.
Returns:
A tuple of ints which is a permutation of range(len(shape)).
"""
raise
NotImplementedError
raise
NotImplementedError
@
classmethod
@
classmethod
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
64746471
...
@@ -38,7 +38,7 @@ The class provides the following primitives:
...
@@ -38,7 +38,7 @@ The class provides the following primitives:
import
enum
import
enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
,
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Literal
,
Optional
import
torch
import
torch
...
@@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
...
@@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
...
@@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC): # noqa: B024
...
@@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC): # noqa: B024
class
KVConnectorBase_V1
(
ABC
):
class
KVConnectorBase_V1
(
ABC
):
"""
Base class for KV connectors.
Attributes:
prefer_cross_layer_blocks (bool): Indicates whether this connector
prefers KV blocks that hold KV data for all layers (for speeding
up KV data transfers).
Defaults to False.
"""
prefer_cross_layer_blocks
:
ClassVar
[
bool
]
=
False
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
"VllmConfig"
,
vllm_config
:
"VllmConfig"
,
...
@@ -226,6 +238,23 @@ class KVConnectorBase_V1(ABC):
...
@@ -226,6 +238,23 @@ class KVConnectorBase_V1(ABC):
"""
"""
return
return
def
register_cross_layers_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
"AttentionBackend"
]
):
"""
Initialize with a single KV cache tensor used by all layers.
The first dimension should be num_layers.
This function will only be called for models with uniform layers,
and only if prefer_cross_layer_blocks is set to True.
Only one of the functions
{register_kv_caches, register_cross_layers_kv_cache} will be called.
Args:
kv_cache: a cross-layers kv cache tensor
attn_backend: The attention backend that corresponds to all layers
"""
return
def
set_host_xfer_buffer_ops
(
self
,
copy_operation
:
CopyBlocksOp
):
def
set_host_xfer_buffer_ops
(
self
,
copy_operation
:
CopyBlocksOp
):
"""
"""
Set the xPU-specific ops for copying KV between host and device.
Set the xPU-specific ops for copying KV between host and device.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
64746471
...
@@ -4,12 +4,12 @@ from collections import defaultdict
...
@@ -4,12 +4,12 @@ from collections import defaultdict
from
collections.abc
import
Iterable
,
Iterator
from
collections.abc
import
Iterable
,
Iterator
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
islice
from
itertools
import
islice
from
typing
import
Any
from
typing
import
Any
,
ClassVar
import
torch
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionBackend
,
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
,
KVCacheEvent
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
,
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorBase_V1
,
...
@@ -42,6 +42,8 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
...
@@ -42,6 +42,8 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class
OffloadingConnector
(
KVConnectorBase_V1
):
class
OffloadingConnector
(
KVConnectorBase_V1
):
prefer_cross_layer_blocks
:
ClassVar
[
bool
]
=
True
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
...
@@ -63,6 +65,12 @@ class OffloadingConnector(KVConnectorBase_V1):
...
@@ -63,6 +65,12 @@ class OffloadingConnector(KVConnectorBase_V1):
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
self
.
connector_worker
.
register_kv_caches
(
kv_caches
)
self
.
connector_worker
.
register_kv_caches
(
kv_caches
)
def
register_cross_layers_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
):
assert
self
.
connector_worker
is
not
None
self
.
connector_worker
.
register_cross_layers_kv_cache
(
kv_cache
,
attn_backend
)
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
assert
isinstance
(
self
.
_connector_metadata
,
OffloadingConnectorMetadata
)
assert
isinstance
(
self
.
_connector_metadata
,
OffloadingConnectorMetadata
)
...
@@ -422,10 +430,35 @@ class OffloadingConnectorWorker:
...
@@ -422,10 +430,35 @@ class OffloadingConnectorWorker:
self
.
_job_counter
=
job_id
+
1
self
.
_job_counter
=
job_id
+
1
return
job_id
return
job_id
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
def
_register_handlers
(
for
src_cls
,
dst_cls
,
handler
in
self
.
spec
.
get_handlers
(
kv_caches
):
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]],
):
for
src_cls
,
dst_cls
,
handler
in
self
.
spec
.
get_handlers
(
kv_caches
,
attn_backends
):
self
.
worker
.
register_handler
(
src_cls
,
dst_cls
,
handler
)
self
.
worker
.
register_handler
(
src_cls
,
dst_cls
,
handler
)
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
layer_names
=
list
(
kv_caches
.
keys
())
layers
=
get_layers_from_vllm_config
(
self
.
spec
.
vllm_config
,
Attention
,
layer_names
)
attn_backends
=
{
layer_name
:
layers
[
layer_name
].
get_attn_backend
()
for
layer_name
in
layer_names
}
self
.
_register_handlers
(
kv_caches
,
attn_backends
)
def
register_cross_layers_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
):
cross_layer_name
=
"ALL_LAYERS"
kv_caches
=
{
cross_layer_name
:
kv_cache
}
attn_backends
=
{
cross_layer_name
:
attn_backend
}
self
.
_register_handlers
(
kv_caches
,
attn_backends
)
def
start_load_kv
(
self
,
metadata
:
OffloadingConnectorMetadata
):
def
start_load_kv
(
self
,
metadata
:
OffloadingConnectorMetadata
):
for
req_id
,
transfer_spec
in
metadata
.
reqs_to_load
.
items
():
for
req_id
,
transfer_spec
in
metadata
.
reqs_to_load
.
items
():
job_id
=
self
.
_generate_job_id
()
job_id
=
self
.
_generate_job_id
()
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
64746471
...
@@ -99,12 +99,20 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -99,12 +99,20 @@ class FlashAttentionBackend(AttentionBackend):
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return
(
2
,
0
,
1
,
3
,
4
,
5
)
elif
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return
(
2
,
4
,
0
,
1
,
3
,
5
)
elif
cache_layout
==
"HND"
:
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
else
:
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
64746471
...
@@ -309,12 +309,20 @@ class FlashInferBackend(AttentionBackend):
...
@@ -309,12 +309,20 @@ class FlashInferBackend(AttentionBackend):
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets us from
# `stride_order` indicates the permutation that gets us from
# `get_kv_cache_shape` to the actual memory layout we want.
# `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
2
,
3
,
4
,
5
)
elif
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
return
(
1
,
2
,
4
,
0
,
3
,
5
)
elif
cache_layout
==
"HND"
:
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
else
:
...
...
vllm/v1/attention/backends/mla/common.py
View file @
64746471
...
@@ -308,6 +308,15 @@ class MLACommonBackend(AttentionBackend):
...
@@ -308,6 +308,15 @@ class MLACommonBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
return
(
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# (num_blocks, num_layers, block_size, head_size)
return
(
1
,
0
,
2
,
3
)
if
include_num_layers_dimension
else
(
0
,
1
,
2
)
@
classmethod
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
576
]
return
[
576
]
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
64746471
...
@@ -48,7 +48,11 @@ class DeepseekV32IndexerBackend(AttentionBackend):
...
@@ -48,7 +48,11 @@ class DeepseekV32IndexerBackend(AttentionBackend):
return
(
num_blocks
,
block_size
,
head_size
)
return
(
num_blocks
,
block_size
,
head_size
)
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
def
get_kv_cache_stride_order
(
include_num_layers_dimension
:
bool
=
False
,
)
->
tuple
[
int
,
...]:
if
include_num_layers_dimension
:
return
(
0
,
1
,
2
,
3
)
return
(
0
,
1
,
2
)
return
(
0
,
1
,
2
)
...
...
vllm/v1/kv_offload/cpu.py
View file @
64746471
...
@@ -4,8 +4,8 @@ from collections.abc import Iterator
...
@@ -4,8 +4,8 @@ from collections.abc import Iterator
import
torch
import
torch
from
vllm.
config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.
attention
import
AttentionBackend
from
vllm.
model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.
config
import
VllmConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.kv_offload.abstract
import
LoadStoreSpec
,
OffloadingManager
from
vllm.v1.kv_offload.abstract
import
LoadStoreSpec
,
OffloadingManager
from
vllm.v1.kv_offload.arc_manager
import
ARCOffloadingManager
from
vllm.v1.kv_offload.arc_manager
import
ARCOffloadingManager
...
@@ -63,7 +63,9 @@ class CPUOffloadingSpec(OffloadingSpec):
...
@@ -63,7 +63,9 @@ class CPUOffloadingSpec(OffloadingSpec):
return
self
.
_manager
return
self
.
_manager
def
get_handlers
(
def
get_handlers
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
type
[
AttentionBackend
]],
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
if
not
self
.
_handler
:
if
not
self
.
_handler
:
if
not
current_platform
.
is_cuda_alike
():
if
not
current_platform
.
is_cuda_alike
():
...
@@ -71,15 +73,6 @@ class CPUOffloadingSpec(OffloadingSpec):
...
@@ -71,15 +73,6 @@ class CPUOffloadingSpec(OffloadingSpec):
"CPU Offloading is currently only supported on CUDA-alike GPUs"
"CPU Offloading is currently only supported on CUDA-alike GPUs"
)
)
layer_names
=
list
(
kv_caches
.
keys
())
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
layer_names
)
attn_backends
=
{
layer_name
:
layers
[
layer_name
].
get_attn_backend
()
for
layer_name
in
layer_names
}
self
.
_handler
=
CpuGpuOffloadingHandler
(
self
.
_handler
=
CpuGpuOffloadingHandler
(
attn_backends
=
attn_backends
,
attn_backends
=
attn_backends
,
gpu_block_size
=
self
.
gpu_block_size
,
gpu_block_size
=
self
.
gpu_block_size
,
...
...
vllm/v1/kv_offload/spec.py
View file @
64746471
...
@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
...
@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from
vllm.v1.kv_offload.worker.worker
import
OffloadingHandler
from
vllm.v1.kv_offload.worker.worker
import
OffloadingHandler
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -48,13 +49,16 @@ class OffloadingSpec(ABC):
...
@@ -48,13 +49,16 @@ class OffloadingSpec(ABC):
@
abstractmethod
@
abstractmethod
def
get_handlers
(
def
get_handlers
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
type
[
"AttentionBackend"
]],
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
)
->
Iterator
[
tuple
[
type
[
LoadStoreSpec
],
type
[
LoadStoreSpec
],
OffloadingHandler
]]:
"""
"""
Get offloading handlers along with their respective src and dst types.
Get offloading handlers along with their respective src and dst types.
Args:
Args:
kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
attn_backends: A dictionary of layer_name -> AttentionBackend.
Yields:
Yields:
Tuples of (src_type, dst_type, offloading_handler).
Tuples of (src_type, dst_type, offloading_handler).
...
...
vllm/v1/kv_offload/worker/cpu_gpu.py
View file @
64746471
...
@@ -83,10 +83,18 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
...
@@ -83,10 +83,18 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
self
.
gpu_tensors
.
append
(
gpu_tensor
)
self
.
gpu_tensors
.
append
(
gpu_tensor
)
gpu_shape
=
gpu_tensor
.
shape
gpu_shape
=
gpu_tensor
.
shape
test_shape
=
attn_backends
[
layer_name
].
get_kv_cache_shape
(
attn_backend
=
attn_backends
[
layer_name
]
test_shape
=
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
1234
,
block_size
=
16
,
num_kv_heads
=
8
,
head_size
=
256
num_blocks
=
1234
,
block_size
=
16
,
num_kv_heads
=
8
,
head_size
=
256
)
)
if
test_shape
[
0
]
==
1234
:
if
len
(
gpu_shape
)
!=
len
(
test_shape
):
# cross-layers tensor
# shape is (num_blocks, ...)
assert
len
(
gpu_shape
)
==
len
(
test_shape
)
+
1
num_blocks_idx
=
0
self
.
kv_dim_before_num_blocks
.
append
(
False
)
elif
test_shape
[
0
]
==
1234
:
# shape is (num_blocks, ...)
# shape is (num_blocks, ...)
num_blocks_idx
=
0
num_blocks_idx
=
0
self
.
kv_dim_before_num_blocks
.
append
(
False
)
self
.
kv_dim_before_num_blocks
.
append
(
False
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
64746471
...
@@ -349,6 +349,9 @@ class GPUModelRunner(
...
@@ -349,6 +349,9 @@ class GPUModelRunner(
# self.model: nn.Module # Set after load_model
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
# Initialize in initialize_kv_cache
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
# Initialize in initialize_kv_cache_tensors
self
.
cross_layers_kv_cache
:
torch
.
Tensor
|
None
=
None
self
.
cross_layers_attn_backend
:
type
[
AttentionBackend
]
|
None
=
None
# indexes: [kv_cache_group_id][attn_group]
# indexes: [kv_cache_group_id][attn_group]
self
.
attn_groups
:
list
[
list
[
AttentionGroup
]]
=
[]
self
.
attn_groups
:
list
[
list
[
AttentionGroup
]]
=
[]
# self.kv_cache_config: KVCacheConfig
# self.kv_cache_config: KVCacheConfig
...
@@ -4930,12 +4933,30 @@ class GPUModelRunner(
...
@@ -4930,12 +4933,30 @@ class GPUModelRunner(
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
corresponding memory buffer for KV cache.
"""
"""
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors
=
self
.
_allocate_kv_cache_tensors
(
kv_cache_config
)
# Try creating KV caches optimized for kv-connector transfers
# Change the memory buffer to the desired shape
cache_dtype
=
self
.
cache_config
.
cache_dtype
kv_caches
=
self
.
_reshape_kv_cache_tensors
(
if
self
.
use_uniform_kv_cache
(
self
.
attn_groups
,
cache_dtype
):
kv_cache_config
,
kv_cache_raw_tensors
,
kernel_block_sizes
kv_caches
,
cross_layers_kv_cache
,
attn_backend
=
(
)
self
.
allocate_uniform_kv_caches
(
kv_cache_config
,
self
.
attn_groups
,
cache_dtype
,
self
.
device
,
kernel_block_sizes
,
)
)
self
.
cross_layers_kv_cache
=
cross_layers_kv_cache
self
.
cross_layers_attn_backend
=
attn_backend
else
:
# Fallback to the general case
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors
=
self
.
_allocate_kv_cache_tensors
(
kv_cache_config
)
# Change the memory buffer to the desired shape
kv_caches
=
self
.
_reshape_kv_cache_tensors
(
kv_cache_config
,
kv_cache_raw_tensors
,
kernel_block_sizes
)
# Set up cross-layer KV cache sharing
# Set up cross-layer KV cache sharing
for
layer_name
,
target_layer_name
in
self
.
shared_kv_cache_layers
.
items
():
for
layer_name
,
target_layer_name
in
self
.
shared_kv_cache_layers
.
items
():
...
@@ -5017,7 +5038,13 @@ class GPUModelRunner(
...
@@ -5017,7 +5038,13 @@ class GPUModelRunner(
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
kv_transfer_group
=
get_kv_transfer_group
()
kv_transfer_group
=
get_kv_transfer_group
()
kv_transfer_group
.
register_kv_caches
(
kv_caches
)
if
self
.
cross_layers_kv_cache
is
not
None
:
assert
self
.
cross_layers_attn_backend
is
not
None
kv_transfer_group
.
register_cross_layers_kv_cache
(
self
.
cross_layers_kv_cache
,
self
.
cross_layers_attn_backend
)
else
:
kv_transfer_group
.
register_kv_caches
(
kv_caches
)
kv_transfer_group
.
set_host_xfer_buffer_ops
(
copy_kv_blocks
)
kv_transfer_group
.
set_host_xfer_buffer_ops
(
copy_kv_blocks
)
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
...
...
vllm/v1/worker/kv_connector_model_runner_mixin.py
View file @
64746471
...
@@ -11,7 +11,11 @@ from typing import (
...
@@ -11,7 +11,11 @@ from typing import (
TYPE_CHECKING
,
# noqa: UP035
TYPE_CHECKING
,
# noqa: UP035
)
)
import
torch
from
vllm.attention
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.cache
import
CacheDType
from
vllm.distributed.kv_transfer
import
(
from
vllm.distributed.kv_transfer
import
(
ensure_kv_transfer_shutdown
,
ensure_kv_transfer_shutdown
,
get_kv_transfer_group
,
get_kv_transfer_group
,
...
@@ -21,11 +25,13 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
...
@@ -21,11 +25,13 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
KVCacheConfig
from
vllm.v1.outputs
import
(
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
,
KVConnectorOutput
,
ModelRunnerOutput
,
ModelRunnerOutput
,
)
)
from
vllm.v1.worker.utils
import
AttentionGroup
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
@@ -142,3 +148,162 @@ class KVConnectorModelRunnerMixin:
...
@@ -142,3 +148,162 @@ class KVConnectorModelRunnerMixin:
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
return
get_kv_transfer_group
().
get_kv_connector_stats
()
return
get_kv_transfer_group
().
get_kv_connector_stats
()
return
None
return
None
@staticmethod
def use_uniform_kv_cache(
    attn_groups: list[list[AttentionGroup]],
    cache_dtype: CacheDType,
) -> bool:
    """
    Decide whether a single cross-layer ("uniform") KV cache should be used.

    In a uniform layout every layer's KV cache is a view into one shared
    tensor, so that the KV data of all layers for a given block number can
    be laid out together and transferred by a KV connector at once.

    All of the following must hold for this method to return True:
    1. A KV transfer group exists and its prefer_cross_layer_blocks
       attribute is truthy.
    2. attn_groups holds exactly one KV cache group, that group holds
       exactly one attention group, and the attention group's
       kv_cache_spec is an AttentionSpec.
    3. The attention backend's
       get_kv_cache_stride_order(include_num_layers_dimension=True)
       returns a stride order with exactly one more entry than the shape
       returned by get_kv_cache_shape(), i.e. it assigns a position to a
       num_layers dimension. A backend that raises AttributeError or
       NotImplementedError from that call is treated as unsupported.

    The position of the num_layers dimension in the shared tensor is
    chosen by the attention backend's stride order, so the per-block data
    of different layers is contiguous only if the backend places it so.

    Args:
        attn_groups: The attention groups of the model, indexed as
            [kv_cache_group_id][attn_group].
        cache_dtype: The KV cache dtype, forwarded to the backend as
            cache_dtype_str.

    Returns:
        True if a uniform KV cache layout should be used, False otherwise.
    """
    # A connector must exist and must ask for cross-layer blocks.
    if not (
        has_kv_transfer_group()
        and get_kv_transfer_group().prefer_cross_layer_blocks
    ):
        return False

    # Exactly one KV cache group containing exactly one attention group.
    if not (len(attn_groups) == 1 and len(attn_groups[0]) == 1):
        return False
    sole_group = attn_groups[0][0]

    spec = sole_group.kv_cache_spec
    if not isinstance(spec, AttentionSpec):
        return False

    backend = sole_group.backend
    # Only the rank of this shape is used below, so the block count passed
    # here (1234) is just a placeholder value.
    per_layer_shape = backend.get_kv_cache_shape(
        1234,
        spec.block_size,
        spec.num_kv_heads,
        spec.head_size,
        cache_dtype_str=cache_dtype,
    )

    try:
        stride_order = backend.get_kv_cache_stride_order(
            include_num_layers_dimension=True
        )
    except (AttributeError, NotImplementedError):
        return False

    # The backend supports the layout only if its stride order carries one
    # extra entry (the layers dimension) on top of the per-layer shape.
    return len(stride_order) == len(per_layer_shape) + 1
@staticmethod
def allocate_uniform_kv_caches(
    kv_cache_config: KVCacheConfig,
    attn_groups: list[list[AttentionGroup]],
    cache_dtype: CacheDType,
    device: torch.device,
    kernel_block_sizes: list[int],
) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
    """
    Allocate one buffer holding the KV caches of all layers and build the
    per-layer views into it.

    This function assumes use_uniform_kv_cache() returned True, i.e. there
    is a single attention group whose spec is an AttentionSpec.

    The buffer is allocated as tensor_size * num_layers zeroed int8
    elements, reinterpreted as kv_cache_spec.dtype and viewed with the
    backend's KV cache shape extended by a leading num_layers dimension and
    reordered by the backend's stride order. Each layer then receives a
    slice of that buffer permuted back to the original dimension order.

    Args:
        kv_cache_config: The KV cache config. All entries of its
            kv_cache_tensors must have the same size, and that size must
            be a multiple of the spec's page_size_bytes.
        attn_groups: The attention groups of the model; only
            attn_groups[0][0] is used.
        cache_dtype: The KV cache dtype, forwarded to the backend as
            cache_dtype_str.
        device: The torch device to allocate on.
        kernel_block_sizes: The kernel block sizes for each KV cache
            group; must contain exactly one entry.

    Returns:
        A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
            kv_caches maps every layer name to its view of the buffer,
            cross_layers_kv_cache is the tensor covering all layers,
            attn_backend is the attention backend matching this tensor.

    Raises:
        AssertionError: If one of the layout assumptions above is violated.
    """
    attn_group = attn_groups[0][0]
    kv_cache_spec = attn_group.kv_cache_spec
    assert isinstance(kv_cache_spec, AttentionSpec), (
        f"expected an AttentionSpec, got {type(kv_cache_spec).__name__}"
    )

    # Every layer tensor must be the same size for a uniform layout.
    tensor_sizes = {
        kv_cache_tensor.size
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors
    }
    assert len(tensor_sizes) == 1, (
        f"expected a single KV cache tensor size, got {sorted(tensor_sizes)}"
    )
    tensor_size = tensor_sizes.pop()

    page_size = kv_cache_spec.page_size_bytes
    assert tensor_size % page_size == 0, (
        f"tensor size {tensor_size} is not a multiple of "
        f"page size {page_size}"
    )
    num_blocks = tensor_size // page_size
    num_layers = len(kv_cache_config.kv_cache_tensors)
    total_size = tensor_size * num_layers

    assert len(kernel_block_sizes) == 1, (
        f"expected one kernel block size, got {len(kernel_block_sizes)}"
    )
    kernel_block_size = kernel_block_sizes[0]
    # Express the block count in kernel-sized blocks.
    num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
    kernel_num_blocks = num_blocks * num_blocks_per_kv_block

    attn_backend = attn_group.backend
    kv_cache_shape = attn_backend.get_kv_cache_shape(
        kernel_num_blocks,
        kernel_block_size,
        kv_cache_spec.num_kv_heads,
        kv_cache_spec.head_size,
        cache_dtype_str=cache_dtype,
    )

    # prepend a num_layers dimension into the shape
    kv_cache_shape = (num_layers,) + kv_cache_shape

    # Keep the try body to the one call that may raise the caught errors.
    try:
        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
            include_num_layers_dimension=True
        )
    except (AttributeError, NotImplementedError):
        # Backend gives no cross-layer stride order: keep dimensions as-is.
        kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
    else:
        assert len(kv_cache_stride_order) == len(kv_cache_shape), (
            f"stride order {kv_cache_stride_order} does not match "
            f"the rank of shape {kv_cache_shape}"
        )

    # Reorder the logical shape into the layout used for the allocation.
    kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

    logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)

    # allocate one contiguous buffer for all layers
    cross_layers_kv_cache = (
        torch.zeros(total_size, dtype=torch.int8, device=device)
        .view(kv_cache_spec.dtype)
        .view(kv_cache_shape)
    )

    # Maintain original KV shape view: permute back so that dimension 0 is
    # num_layers and the remaining dimensions follow the backend's shape.
    inv_order = [
        kv_cache_stride_order.index(i)
        for i in range(len(kv_cache_stride_order))
    ]
    permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)

    kv_caches = {}
    for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
        # Layers sharing one config entry get the very same view object.
        tensor = permuted_kv_cache[i]
        for layer_name in kv_cache_tensor.shared_by:
            kv_caches[layer_name] = tensor

    return kv_caches, cross_layers_kv_cache, attn_backend
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment