Unverified Commit bdf13965 authored by Yong Hoon Shin's avatar Yong Hoon Shin Committed by GitHub
Browse files

[V1] Support cross-layer KV sharing (#18212)


Signed-off-by: default avatarYong Hoon Shin <yhshin@meta.com>
parent fa98d777
......@@ -4,8 +4,13 @@ import unittest.mock as mock
import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.attention.layer import Attention
from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
get_kv_cache_config)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.worker.tpu_model_runner import (
......@@ -363,3 +368,223 @@ def test_get_req_paddings():
assert _get_req_paddings(1, 32) == [8, 16, 32]
assert _get_req_paddings(8, 32) == [8, 16, 32]
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
kv_sharing_target_layer_name=layer_1,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
# invalid layer: cross_attn.atn doesn't exist!
kv_sharing_target_layer_name=invalid_layer,
)
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(model_runner.shared_kv_cache_layers) == 0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 2
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
kv_cache_config.num_blocks = 1
for layer in kv_cache_config.tensors:
kv_cache_config.tensors[layer].size =\
kv_cache_spec[layer].page_size_bytes
model_runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache does NOT share memory with layer 0
assert id(layer_1_kv) != id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
)
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.tensors[layer_0].size == available_memory
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 2 * 1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
kv_cache_config.num_blocks = 1
kv_cache_config.tensors[layer_0].size =\
kv_cache_spec[layer_0].page_size_bytes
model_runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache shares memory with layer 0
assert id(layer_1_kv) == id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
......@@ -7,8 +7,11 @@ import pytest
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
SchedulerConfig, VllmConfig, set_current_vllm_config)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
get_kv_cache_config)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
......@@ -19,6 +22,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
BLOCK_SIZE = 16
NUM_BLOCKS = 10
DEVICE = "cuda"
def initialize_kv_cache(runner: GPUModelRunner):
......@@ -55,8 +59,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
runner.initialize_attn_backend(kv_cache_config)
@pytest.fixture
def model_runner():
def get_vllm_config():
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
......@@ -84,13 +87,18 @@ def model_runner():
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)
num_heads = model_config.get_num_kv_heads(parallel_config)
return vllm_config
@pytest.fixture
def model_runner():
vllm_config = get_vllm_config()
model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context[
"layer.0"] = Attention(num_heads, head_size, 0.1)
device = "cuda"
runner = GPUModelRunner(vllm_config, device)
runner = GPUModelRunner(vllm_config, DEVICE)
initialize_kv_cache(runner)
return runner
......@@ -385,3 +393,225 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
model_runner_2.load_model() # Load real weights inplace
assert str(model_runner.get_model().state_dict()) == str(
model_runner_2.get_model().state_dict())
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
kv_sharing_target_layer_name=layer_1,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
# invalid layer: cross_attn.atn doesn't exist!
kv_sharing_target_layer_name=invalid_layer,
)
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
with pytest.raises(ValueError, match=error_msg):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
)
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(runner.shared_kv_cache_layers) == 0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 2
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
kv_cache_config.num_blocks = 1
for layer in kv_cache_config.tensors:
kv_cache_config.tensors[layer].size =\
kv_cache_spec[layer].page_size_bytes
runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache does NOT share memory with layer 0
assert id(layer_1_kv) != id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_0,
),
layer_1:
Attention(
num_heads=8,
head_size=64,
scale=1.0,
prefix=layer_1,
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
)
}
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert runner.shared_kv_cache_layers[layer_1] == layer_0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.tensors[layer_0].size == available_memory
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 2 * 1310720
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
kv_cache_config.num_blocks = 1
kv_cache_config.tensors[layer_0].size =\
kv_cache_spec[layer_0].page_size_bytes
runner.initialize_kv_cache(kv_cache_config)
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
# check layer 1 kv cache shares memory with layer 0
assert id(layer_1_kv) == id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
......@@ -270,6 +270,7 @@ class AttentionImpl(ABC, Generic[T]):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
raise NotImplementedError
......
......@@ -306,7 +306,10 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.")
......
......@@ -206,12 +206,13 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
......
......@@ -290,9 +290,12 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
layer_idx: int = -1,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......
......@@ -618,8 +618,11 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
......
......@@ -936,8 +936,11 @@ class FlashInferImpl(AttentionImpl):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if use_irope:
logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall"
......
......@@ -184,12 +184,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str] = None,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
kv_sharing_target_layer_name, **mla_args)
assert is_flashmla_supported(), \
"FlashMLA is not supported on this device"
......
......@@ -110,9 +110,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
super(AttentionImpl, self).__init__()
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if use_irope:
logger.warning_once(
"Using irope in HPU is not supported yet, it will fall back "
......
......@@ -123,8 +123,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if use_irope:
logger.warning_once(
"Using irope in Ipex is not supported yet, it will fall"
......
......@@ -1000,6 +1000,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
......@@ -1009,6 +1010,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing not supported in V0.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......
......@@ -109,8 +109,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
......
......@@ -370,12 +370,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
......
......@@ -494,8 +494,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if use_irope:
logger.warning_once(
"Using irope in ROCm Flash Attention is not supported yet, it "
......
......@@ -405,8 +405,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if blocksparse_params is not None:
raise ValueError(
"Torch SPDA does not support block-sparse attention.")
......
......@@ -38,12 +38,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
......
......@@ -390,8 +390,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if blocksparse_params is not None:
raise ValueError(
"XFormers does not support block-sparse attention.")
......
......@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
class Attention(nn.Module):
......@@ -50,6 +51,7 @@ class Attention(nn.Module):
use_mla: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
**extra_impl_args,
) -> None:
"""
......@@ -135,7 +137,7 @@ class Attention(nn.Module):
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**extra_impl_args)
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
......@@ -153,6 +155,19 @@ class Attention(nn.Module):
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
if not envs.VLLM_USE_V1:
raise NotImplementedError(
"Cross-layer KV sharing is not supported in V0.")
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
......
......@@ -485,6 +485,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
......@@ -506,6 +507,7 @@ class FlashAttentionImpl(AttentionImpl):
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
......@@ -569,12 +571,16 @@ class FlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens] and
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
# the slot_mapping's shape to determine the number of actual tokens.
key_cache, value_cache = kv_cache.unbind(0)
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment