# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock

import torch

from vllm.v1.attention.backends.flash_attn import (
    FlashAttentionBackend, FlashAttentionMetadataBuilder)
from vllm.v1.attention.backends.flex_attention import (
    FlexAttentionBackend, FlexAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
from vllm.v1.worker.utils import (AttentionGroup,
                                  initialize_kv_cache_for_kv_sharing)


def new_kv_cache_spec():
    """Return a small ``FullAttentionSpec`` for use in the tests below.

    A new spec object is constructed on every call.
    """
    # NOTE(review): positional order assumed to be
    # (block_size, num_kv_heads, head_size, dtype, use_mla) — confirm against
    # FullAttentionSpec. The values are passed positionally exactly as before.
    block_size, num_kv_heads, head_size = 16, 1, 1
    dtype, use_mla = torch.float32, False
    return FullAttentionSpec(block_size, num_kv_heads, head_size, dtype,
                             use_mla)


def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
    """
    Test initializing KV cache sharing with different attention groups.
    Layers in the same KV cache group might be placed in different attn groups
    if they have different attention backends.
    """
    # Layers 2 and 3 reuse the KV caches of layers 0 and 1 respectively.
    shared_kv_cache_layers = {
        "model.layers.2": "model.layers.0",
        "model.layers.3": "model.layers.1",
    }

    # Layers 0 and 1 both belong in KV cache group 0.
    # However, if they have different attention backends, they will be
    # placed in different attention groups for KV cache group 0.
    kv_cache_groups = [
        KVCacheGroupSpec(["model.layers.0", "model.layers.1"],
                         new_kv_cache_spec()),
    ]

    attn_groups = [
        # KV cache group 0 has two attention groups
        [
            AttentionGroup(
                backend=FlashAttentionBackend,
                metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
                layer_names=["model.layers.0"],
            ),
            AttentionGroup(
                backend=FlexAttentionBackend,
                metadata_builder=Mock(spec=FlexAttentionMetadataBuilder),
                layer_names=["model.layers.1"],
            ),
        ],
    ]

    # Only layers 0 and 1 will have KV caches allocated
    kv_caches = {
        "model.layers.0": torch.zeros(1, 2, 3),
        "model.layers.1": torch.ones(1, 2, 3),
    }

    initialize_kv_cache_for_kv_sharing(
        shared_kv_cache_layers=shared_kv_cache_layers,
        kv_cache_groups=kv_cache_groups,
        kv_caches=kv_caches,
        attn_groups=attn_groups,
    )

    # Check that the KV caches were shared correctly (same underlying storage)
    assert (kv_caches["model.layers.2"].data_ptr() ==
            kv_caches["model.layers.0"].data_ptr())
    assert (kv_caches["model.layers.3"].data_ptr() ==
            kv_caches["model.layers.1"].data_ptr())

    # Check that the layers were added to the correct KV cache group
    assert len(kv_cache_groups) == 1
    assert kv_cache_groups[0].layer_names == [
        "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
    ]

    # Check that each sharing layer joined the attention group of its target
    assert len(attn_groups) == 1 and len(attn_groups[0]) == 2
    assert attn_groups[0][0].layer_names == [
        "model.layers.0", "model.layers.2"
    ]
    assert attn_groups[0][1].layer_names == [
        "model.layers.1", "model.layers.3"
    ]


def test_initialize_kv_cache_for_kv_sharing_same_attn_groups():
    """
    Test KV sharing when every layer of a KV cache group uses the same
    attention backend, so that KV cache group has a single attention group.
    This is true for most models.
    """
    # Sharing layer -> layer whose KV cache it reuses.
    sharing_map = {
        "model.layers.2": "model.layers.0",
        "model.layers.3": "model.layers.1",
    }

    # One KV cache group holding the two layers that own real caches.
    groups = [
        KVCacheGroupSpec(["model.layers.0", "model.layers.1"],
                         new_kv_cache_spec()),
    ]

    # All layers share the flash attention backend, so one attention group
    # covers both owning layers. Its layer_names is a separate list from the
    # KV cache group's, since both are extended in place.
    flash_group = AttentionGroup(
        backend=FlashAttentionBackend,
        metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
        layer_names=["model.layers.0", "model.layers.1"],
    )
    per_group_attn = [[flash_group]]

    # Only the owning layers get tensors up front.
    caches = {
        "model.layers.0": torch.zeros(1, 2, 3),
        "model.layers.1": torch.ones(1, 2, 3),
    }

    initialize_kv_cache_for_kv_sharing(
        shared_kv_cache_layers=sharing_map,
        kv_cache_groups=groups,
        kv_caches=caches,
        attn_groups=per_group_attn,
    )

    # Every sharing layer must alias its target's storage.
    for borrower, owner in (("model.layers.2", "model.layers.0"),
                            ("model.layers.3", "model.layers.1")):
        assert caches[borrower].data_ptr() == caches[owner].data_ptr()

    every_layer = [
        "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
    ]

    # The sharing layers join their targets' KV cache group ...
    assert len(groups) == 1
    assert groups[0].layer_names == every_layer

    # ... and that group's single attention group.
    assert len(per_group_attn) == 1 and len(per_group_attn[0]) == 1
    assert per_group_attn[0][0].layer_names == every_layer


def test_initialize_kv_cache_for_kv_sharing_no_attn_groups():
    """
    Test KV sharing setup when no attention groups are passed in.

    This is the case for the TPU model runner, which doesn't have support
    for attention groups yet. Each owning layer sits in its own KV cache group.
    """
    # Sharing layer -> layer whose KV cache it reuses.
    sharing_map = {
        "model.layers.2": "model.layers.0",
        "model.layers.3": "model.layers.1",
    }

    groups = [
        KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()),
        KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()),
    ]

    caches = {
        "model.layers.0": torch.zeros(1, 2, 3),
        "model.layers.1": torch.ones(1, 2, 3),
    }

    # attn_groups is deliberately omitted.
    initialize_kv_cache_for_kv_sharing(
        shared_kv_cache_layers=sharing_map,
        kv_cache_groups=groups,
        kv_caches=caches,
    )

    # (sharing layer, target layer), ordered like the KV cache groups.
    expected_pairs = (
        ("model.layers.2", "model.layers.0"),
        ("model.layers.3", "model.layers.1"),
    )

    # Every sharing layer must alias its target's storage.
    for borrower, owner in expected_pairs:
        assert caches[borrower].data_ptr() == caches[owner].data_ptr()

    # Each sharing layer is appended to its target's own KV cache group.
    assert len(groups) == 2
    for group, (borrower, owner) in zip(groups, expected_pairs):
        assert group.layer_names == [owner, borrower]