Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdf13965
Unverified
Commit
bdf13965
authored
Jun 03, 2025
by
Yong Hoon Shin
Committed by
GitHub
Jun 03, 2025
Browse files
[V1] Support cross-layer KV sharing (#18212)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
fa98d777
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
542 additions
and
28 deletions
+542
-28
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+226
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+237
-7
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-0
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+3
-0
vllm/attention/backends/cpu_mla.py
vllm/attention/backends/cpu_mla.py
+2
-1
vllm/attention/backends/dual_chunk_flash_attn.py
vllm/attention/backends/dual_chunk_flash_attn.py
+3
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+3
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+3
-0
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+2
-1
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+3
-0
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+3
-0
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+3
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+3
-0
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+2
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+3
-0
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+3
-0
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+2
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+3
-0
vllm/attention/layer.py
vllm/attention/layer.py
+16
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+21
-15
No files found.
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
bdf13965
...
@@ -4,8 +4,13 @@ import unittest.mock as mock
...
@@ -4,8 +4,13 @@ import unittest.mock as mock
import
pytest
import
pytest
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
from
vllm.v1.core.kv_cache_utils
import
(
estimate_max_model_len
,
get_kv_cache_config
)
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
SchedulerOutput
)
from
vllm.v1.worker.tpu_model_runner
import
(
from
vllm.v1.worker.tpu_model_runner
import
(
...
@@ -363,3 +368,223 @@ def test_get_req_paddings():
...
@@ -363,3 +368,223 @@ def test_get_req_paddings():
assert
_get_req_paddings
(
1
,
32
)
==
[
8
,
16
,
32
]
assert
_get_req_paddings
(
1
,
32
)
==
[
8
,
16
,
32
]
assert
_get_req_paddings
(
8
,
32
)
==
[
8
,
16
,
32
]
assert
_get_req_paddings
(
8
,
32
)
==
[
8
,
16
,
32
]
assert
_get_req_paddings
(
8
,
36
)
==
[
8
,
16
,
32
,
36
]
assert
_get_req_paddings
(
8
,
36
)
==
[
8
,
16
,
32
,
36
]
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_order
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
error_msg
=
f
"
{
layer_1
}
must come before the current layer"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
fwd_context
=
{
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
kv_sharing_target_layer_name
=
layer_1
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
def
test_init_kv_cache_with_kv_sharing_target_layer_not_exist
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
invalid_layer
=
"model.layers.0.cross_attn.attn"
error_msg
=
f
"
{
invalid_layer
}
is not a valid Attention layer in the model"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
fwd_context
=
{
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
# invalid layer: cross_attn.atn doesn't exist!
kv_sharing_target_layer_name
=
invalid_layer
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
def
test_init_kv_cache_with_kv_sharing_target_same_as_current
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
error_msg
=
f
"
{
layer_1
}
cannot be the same as the current layer"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
fwd_context
=
{
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
kv_sharing_target_layer_name
=
layer_1
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
def
test_init_kv_cache_without_kv_sharing
(
model_runner
):
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
vllm_config
=
model_runner
.
vllm_config
with
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
# Set high context length to test max context length estimation
vllm_config
.
model_config
.
max_model_len
=
3_000_000
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
kv_cache_spec
=
model_runner
.
get_kv_cache_spec
()
assert
len
(
kv_cache_spec
)
==
2
assert
len
(
model_runner
.
shared_kv_cache_layers
)
==
0
available_memory
=
20
*
GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks
=
327680
# 20GB / 32KB / 2 (num layers)
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
2
assert
kv_cache_config
.
tensors
[
layer_0
].
size
==
available_memory
//
2
assert
kv_cache_config
.
tensors
[
layer_1
].
size
==
available_memory
//
2
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
# max context len with KV sharing should be 2x as large as without
assert
max_context_len
==
1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
kv_cache_config
.
num_blocks
=
1
for
layer
in
kv_cache_config
.
tensors
:
kv_cache_config
.
tensors
[
layer
].
size
=
\
kv_cache_spec
[
layer
].
page_size_bytes
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
layer_0_kv
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
]
layer_1_kv
=
vllm_ctx
[
layer_1
].
kv_cache
[
0
]
# check layer 1 kv cache does NOT share memory with layer 0
assert
id
(
layer_1_kv
)
!=
id
(
layer_0_kv
)
# check layer 1 added to kv cache group's layer names
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
assert
len
(
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
)
==
2
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
0
]
==
layer_0
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
def
test_init_kv_cache_with_kv_sharing_valid
(
model_runner
):
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
vllm_config
=
model_runner
.
vllm_config
with
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
kv_sharing_target_layer_name
=
"model.layers.0.self_attn.attn"
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
# Set high context length to test max context length estimation
vllm_config
.
model_config
.
max_model_len
=
3_000_000
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
kv_cache_spec
=
model_runner
.
get_kv_cache_spec
()
assert
len
(
kv_cache_spec
)
==
1
assert
layer_0
in
kv_cache_spec
assert
model_runner
.
shared_kv_cache_layers
[
layer_1
]
==
layer_0
available_memory
=
20
*
GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks
=
655360
# 20GB / 32KB
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert
kv_cache_config
.
tensors
[
layer_0
].
size
==
available_memory
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
# max context len with KV sharing should be 2x as large as without
assert
max_context_len
==
2
*
1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
kv_cache_config
.
num_blocks
=
1
kv_cache_config
.
tensors
[
layer_0
].
size
=
\
kv_cache_spec
[
layer_0
].
page_size_bytes
model_runner
.
initialize_kv_cache
(
kv_cache_config
)
layer_0_kv
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
]
layer_1_kv
=
vllm_ctx
[
layer_1
].
kv_cache
[
0
]
# check layer 1 kv cache shares memory with layer 0
assert
id
(
layer_1_kv
)
==
id
(
layer_0_kv
)
# check layer 1 added to kv cache group's layer names
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
assert
len
(
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
)
==
2
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
0
]
==
layer_0
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
tests/v1/worker/test_gpu_model_runner.py
View file @
bdf13965
...
@@ -7,8 +7,11 @@ import pytest
...
@@ -7,8 +7,11 @@ import pytest
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
SchedulerConfig
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
from
vllm.v1.core.kv_cache_utils
import
(
estimate_max_model_len
,
get_kv_cache_config
)
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
SchedulerOutput
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
...
@@ -19,6 +22,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
...
@@ -19,6 +22,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
BLOCK_SIZE
=
16
BLOCK_SIZE
=
16
NUM_BLOCKS
=
10
NUM_BLOCKS
=
10
DEVICE
=
"cuda"
def
initialize_kv_cache
(
runner
:
GPUModelRunner
):
def
initialize_kv_cache
(
runner
:
GPUModelRunner
):
...
@@ -55,8 +59,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
...
@@ -55,8 +59,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
runner
.
initialize_attn_backend
(
kv_cache_config
)
runner
.
initialize_attn_backend
(
kv_cache_config
)
@
pytest
.
fixture
def
get_vllm_config
():
def
model_runner
():
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
10
,
max_num_seqs
=
10
,
max_num_batched_tokens
=
512
,
max_num_batched_tokens
=
512
,
...
@@ -84,13 +87,18 @@ def model_runner():
...
@@ -84,13 +87,18 @@ def model_runner():
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
parallel_config
=
parallel_config
,
parallel_config
=
parallel_config
,
)
)
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
return
vllm_config
@
pytest
.
fixture
def
model_runner
():
vllm_config
=
get_vllm_config
()
model_config
=
vllm_config
.
model_config
num_heads
=
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
head_size
=
model_config
.
get_head_size
()
head_size
=
model_config
.
get_head_size
()
vllm_config
.
compilation_config
.
static_forward_context
[
vllm_config
.
compilation_config
.
static_forward_context
[
"layer.0"
]
=
Attention
(
num_heads
,
head_size
,
0.1
)
"layer.0"
]
=
Attention
(
num_heads
,
head_size
,
0.1
)
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
device
=
"cuda"
runner
=
GPUModelRunner
(
vllm_config
,
device
)
initialize_kv_cache
(
runner
)
initialize_kv_cache
(
runner
)
return
runner
return
runner
...
@@ -385,3 +393,225 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
...
@@ -385,3 +393,225 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
model_runner_2
.
load_model
()
# Load real weights inplace
model_runner_2
.
load_model
()
# Load real weights inplace
assert
str
(
model_runner
.
get_model
().
state_dict
())
==
str
(
assert
str
(
model_runner
.
get_model
().
state_dict
())
==
str
(
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
get_model
().
state_dict
())
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_order
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
error_msg
=
f
"
{
layer_1
}
must come before the current layer"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
fwd_context
=
{
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
kv_sharing_target_layer_name
=
layer_1
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
def
test_init_kv_cache_with_kv_sharing_target_layer_not_exist
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
invalid_layer
=
"model.layers.0.cross_attn.attn"
error_msg
=
f
"
{
invalid_layer
}
is not a valid Attention layer in the model"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
fwd_context
=
{
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
# invalid layer: cross_attn.atn doesn't exist!
kv_sharing_target_layer_name
=
invalid_layer
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
def
test_init_kv_cache_with_kv_sharing_target_same_as_current
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
error_msg
=
f
"
{
layer_1
}
cannot be the same as the current layer"
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
fwd_context
=
{
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
kv_sharing_target_layer_name
=
layer_1
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
def
test_init_kv_cache_without_kv_sharing
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
vllm_config
=
get_vllm_config
()
with
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
# Set high context length to test max context length estimation
vllm_config
.
model_config
.
max_model_len
=
3_000_000
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
kv_cache_spec
=
runner
.
get_kv_cache_spec
()
assert
len
(
kv_cache_spec
)
==
2
assert
len
(
runner
.
shared_kv_cache_layers
)
==
0
available_memory
=
20
*
GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks
=
327680
# 20GB / 32KB / 2 (num layers)
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
2
assert
kv_cache_config
.
tensors
[
layer_0
].
size
==
available_memory
//
2
assert
kv_cache_config
.
tensors
[
layer_1
].
size
==
available_memory
//
2
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
# max context len with KV sharing should be 2x as large as without
assert
max_context_len
==
1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
kv_cache_config
.
num_blocks
=
1
for
layer
in
kv_cache_config
.
tensors
:
kv_cache_config
.
tensors
[
layer
].
size
=
\
kv_cache_spec
[
layer
].
page_size_bytes
runner
.
initialize_kv_cache
(
kv_cache_config
)
layer_0_kv
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
]
layer_1_kv
=
vllm_ctx
[
layer_1
].
kv_cache
[
0
]
# check layer 1 kv cache does NOT share memory with layer 0
assert
id
(
layer_1_kv
)
!=
id
(
layer_0_kv
)
# check layer 1 added to kv cache group's layer names
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
assert
len
(
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
)
==
2
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
0
]
==
layer_0
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
def
test_init_kv_cache_with_kv_sharing_valid
():
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
vllm_config
=
get_vllm_config
()
with
set_current_vllm_config
(
vllm_config
):
fwd_context
=
{
layer_0
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_0
,
),
layer_1
:
Attention
(
num_heads
=
8
,
head_size
=
64
,
scale
=
1.0
,
prefix
=
layer_1
,
kv_sharing_target_layer_name
=
"model.layers.0.self_attn.attn"
,
)
}
# suppress var not used error
assert
fwd_context
is
not
None
# Set high context length to test max context length estimation
vllm_config
.
model_config
.
max_model_len
=
3_000_000
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
kv_cache_spec
=
runner
.
get_kv_cache_spec
()
assert
len
(
kv_cache_spec
)
==
1
assert
layer_0
in
kv_cache_spec
assert
runner
.
shared_kv_cache_layers
[
layer_1
]
==
layer_0
available_memory
=
20
*
GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks
=
655360
# 20GB / 32KB
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert
kv_cache_config
.
tensors
[
layer_0
].
size
==
available_memory
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
# max context len with KV sharing should be 2x as large as without
assert
max_context_len
==
2
*
1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
kv_cache_config
.
num_blocks
=
1
kv_cache_config
.
tensors
[
layer_0
].
size
=
\
kv_cache_spec
[
layer_0
].
page_size_bytes
runner
.
initialize_kv_cache
(
kv_cache_config
)
layer_0_kv
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
]
layer_1_kv
=
vllm_ctx
[
layer_1
].
kv_cache
[
0
]
# check layer 1 kv cache shares memory with layer 0
assert
id
(
layer_1_kv
)
==
id
(
layer_0_kv
)
# check layer 1 added to kv cache group's layer names
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
assert
len
(
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
)
==
2
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
0
]
==
layer_0
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
vllm/attention/backends/abstract.py
View file @
bdf13965
...
@@ -270,6 +270,7 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -270,6 +270,7 @@ class AttentionImpl(ABC, Generic[T]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
bdf13965
...
@@ -306,7 +306,10 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -306,7 +306,10 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
assert
blocksparse_params
is
not
None
assert
blocksparse_params
is
not
None
assert
alibi_slopes
is
None
,
ValueError
(
assert
alibi_slopes
is
None
,
ValueError
(
"Alibi not support for blocksparse flash attention."
)
"Alibi not support for blocksparse flash attention."
)
...
...
vllm/attention/backends/cpu_mla.py
View file @
bdf13965
...
@@ -206,12 +206,13 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
...
@@ -206,12 +206,13 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
...
...
vllm/attention/backends/dual_chunk_flash_attn.py
View file @
bdf13965
...
@@ -290,9 +290,12 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -290,9 +290,12 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
layer_idx
:
int
=
-
1
,
layer_idx
:
int
=
-
1
,
dual_chunk_attention_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dual_chunk_attention_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/flash_attn.py
View file @
bdf13965
...
@@ -618,8 +618,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -618,8 +618,11 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
"FlashAttention does not support block-sparse attention."
)
"FlashAttention does not support block-sparse attention."
)
...
...
vllm/attention/backends/flashinfer.py
View file @
bdf13965
...
@@ -936,8 +936,11 @@ class FlashInferImpl(AttentionImpl):
...
@@ -936,8 +936,11 @@ class FlashInferImpl(AttentionImpl):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
use_irope
:
if
use_irope
:
logger
.
warning_once
(
logger
.
warning_once
(
"Using irope in FlashInfer is not supported yet, it will fall"
"Using irope in FlashInfer is not supported yet, it will fall"
...
...
vllm/attention/backends/flashmla.py
View file @
bdf13965
...
@@ -184,12 +184,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -184,12 +184,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
# MLA Specific Arguments
# MLA Specific Arguments
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
assert
is_flashmla_supported
(),
\
assert
is_flashmla_supported
(),
\
"FlashMLA is not supported on this device"
"FlashMLA is not supported on this device"
...
...
vllm/attention/backends/hpu_attn.py
View file @
bdf13965
...
@@ -110,9 +110,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -110,9 +110,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_seq_len
:
int
=
4096
,
max_seq_len
:
int
=
4096
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
(
AttentionImpl
,
self
).
__init__
()
super
(
AttentionImpl
,
self
).
__init__
()
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
use_irope
:
if
use_irope
:
logger
.
warning_once
(
logger
.
warning_once
(
"Using irope in HPU is not supported yet, it will fall back "
"Using irope in HPU is not supported yet, it will fall back "
...
...
vllm/attention/backends/ipex_attn.py
View file @
bdf13965
...
@@ -123,8 +123,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -123,8 +123,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
use_irope
:
if
use_irope
:
logger
.
warning_once
(
logger
.
warning_once
(
"Using irope in Ipex is not supported yet, it will fall"
"Using irope in Ipex is not supported yet, it will fall"
...
...
vllm/attention/backends/mla/common.py
View file @
bdf13965
...
@@ -1000,6 +1000,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1000,6 +1000,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
q_lora_rank
:
Optional
[
int
],
q_lora_rank
:
Optional
[
int
],
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -1009,6 +1010,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1009,6 +1010,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_head_dim
:
int
,
v_head_dim
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
kv_b_proj
:
ColumnParallelLinear
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing not supported in V0."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/pallas.py
View file @
bdf13965
...
@@ -109,8 +109,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -109,8 +109,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
use_irope
:
if
use_irope
:
logger
.
warning_once
(
logger
.
warning_once
(
"Using irope in Pallas is not supported yet, it will fall back "
"Using irope in Pallas is not supported yet, it will fall back "
...
...
vllm/attention/backends/rocm_aiter_mla.py
View file @
bdf13965
...
@@ -370,12 +370,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -370,12 +370,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
bdf13965
...
@@ -494,8 +494,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -494,8 +494,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
use_irope
:
if
use_irope
:
logger
.
warning_once
(
logger
.
warning_once
(
"Using irope in ROCm Flash Attention is not supported yet, it "
"Using irope in ROCm Flash Attention is not supported yet, it "
...
...
vllm/attention/backends/torch_sdpa.py
View file @
bdf13965
...
@@ -405,8 +405,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -405,8 +405,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
"Torch SPDA does not support block-sparse attention."
)
...
...
vllm/attention/backends/triton_mla.py
View file @
bdf13965
...
@@ -38,12 +38,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -38,12 +38,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
...
...
vllm/attention/backends/xformers.py
View file @
bdf13965
...
@@ -390,8 +390,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -390,8 +390,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
"XFormers does not support block-sparse attention."
)
"XFormers does not support block-sparse attention."
)
...
...
vllm/attention/layer.py
View file @
bdf13965
...
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
...
@@ -50,6 +51,7 @@ class Attention(nn.Module):
...
@@ -50,6 +51,7 @@ class Attention(nn.Module):
use_mla
:
bool
=
False
,
use_mla
:
bool
=
False
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
**
extra_impl_args
,
**
extra_impl_args
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -135,7 +137,7 @@ class Attention(nn.Module):
...
@@ -135,7 +137,7 @@ class Attention(nn.Module):
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
extra_impl_args
)
kv_sharing_target_layer_name
,
**
extra_impl_args
)
self
.
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
self
.
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
self
.
dtype
=
dtype
self
.
dtype
=
dtype
...
@@ -153,6 +155,19 @@ class Attention(nn.Module):
...
@@ -153,6 +155,19 @@ class Attention(nn.Module):
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
layer_name
=
prefix
self
.
attn_type
=
attn_type
self
.
attn_type
=
attn_type
if
kv_sharing_target_layer_name
is
not
None
:
if
not
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Cross-layer KV sharing is not supported in V0."
)
validate_kv_sharing_target
(
prefix
,
kv_sharing_target_layer_name
,
compilation_config
.
static_forward_context
,
)
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
# this variable will not be accessed if use_direct_call is True
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
bdf13965
...
@@ -485,6 +485,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -485,6 +485,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
...
@@ -506,6 +507,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -506,6 +507,7 @@ class FlashAttentionImpl(AttentionImpl):
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -569,22 +571,26 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -569,22 +571,26 @@ class FlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
# performance to make sure it does not introduce any overhead.
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
# Reshape the input keys and values and store them in the cache.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens] and
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
# the slot_mapping's shape to determine the number of actual tokens.
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
if
self
.
kv_sharing_target_layer_name
is
None
:
value
,
# Reshape the input keys and values and store them in the cache.
key_cache
,
# Skip this if sharing KV cache with an earlier attention layer.
value_cache
,
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
attn_metadata
.
slot_mapping
,
# not padded. However, we don't need to do key[:num_actual_tokens]
self
.
kv_cache_dtype
,
# and value[:num_actual_tokens] because the reshape_and_cache_flash
layer
.
_k_scale
,
# op uses the slot_mapping's shape to determine the number of
layer
.
_v_scale
,
# actual tokens.
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment