test_utils.py 2.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.v1.worker.utils import bind_kv_cache


def test_bind_kv_cache():
    """Verify bind_kv_cache attaches each tensor to its layer and the runner.

    After the call, every attention layer in the context must expose the
    exact tensor object registered under its name as ``kv_cache[0]``, and the
    runner list must hold those same tensor objects in layer order.
    """
    from vllm.attention import Attention

    layer_names = (
        'layers.0.self_attn',
        'layers.1.self_attn',
        'layers.2.self_attn',
        'layers.3.self_attn',
    )
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = {name: torch.zeros((1, )) for name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (not equality) checks: the layer must reference the very same
    # tensor object that was passed in.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    # The runner list is expected to follow the layer order.
    for position, name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[name]


def test_bind_kv_cache_non_attention():
    """Verify bind_kv_cache when layer indices are not contiguous.

    Only layers 20 and 28 carry a KV cache here (presumably the remaining
    layers of the model are not attention layers -- confirm against the
    Jamba architecture). The runner list must still be filled densely, in
    layer order, starting at index 0.
    """
    from vllm.attention import Attention

    # example from Jamba PP=2
    attn_layer_names = ('model.layers.20.attn', 'model.layers.28.attn')
    ctx = {name: Attention(32, 128, 0.1) for name in attn_layer_names}
    kv_cache = {name: torch.zeros((1, )) for name in attn_layer_names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Each layer must reference the exact tensor registered under its name.
    for name in attn_layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    # Runner slots are indexed by position, not by the layer's own index.
    for slot, name in enumerate(attn_layer_names):
        assert runner_kv_caches[slot] is kv_cache[name]