"vllm/vscode:/vscode.git/clone" did not exist on "dd96465fd7447c0de06af796b6c9bd3949913ea6"
test_utils.py 3.63 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.v1.worker.utils import bind_kv_cache


9
def test_bind_kv_cache(default_vllm_config):
    """Check that ``bind_kv_cache`` attaches every cache tensor to its layer.

    Four attention layers named ``layers.<i>.self_attn`` are created together
    with one placeholder tensor per layer.  After binding, the test expects:

    * each layer's ``kv_cache[0]`` to be the very same tensor object that was
      supplied for that layer name, and
    * ``runner_kv_caches`` to hold those tensors in layer order.

    ``default_vllm_config`` is not used directly in the body; presumably it
    is a pytest fixture that installs a vLLM config needed to construct
    ``Attention`` -- confirm against the project's conftest.
    """
    # Imported lazily, as in the sibling tests in this file.
    from vllm.attention.layer import Attention

    layer_names = [f"layers.{i}.self_attn" for i in range(4)]
    ctx = {
        layer_name: Attention(32, 128, 0.1, prefix=layer_name)
        for layer_name in layer_names
    }
    kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (``is``), not equality: the layer must reference the exact
    # tensor object that was handed to bind_kv_cache.
    for layer_name in layer_names:
        assert ctx[layer_name].kv_cache[0] is kv_cache[layer_name]

    # Index explicitly (rather than zip) so a too-short runner list still
    # fails with an IndexError instead of silently passing.
    for position, layer_name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[layer_name]
35
36


37
def test_bind_kv_cache_non_attention(default_vllm_config):
    """Check ``bind_kv_cache`` when attention layer indices are not contiguous.

    Only layers 20 and 28 carry attention here, so the layer numbers embedded
    in the names have a gap.  After binding, the test expects each layer's
    ``kv_cache[0]`` to be the tensor supplied for it, and
    ``runner_kv_caches`` to contain the two tensors at positions 0 and 1 in
    ascending layer order.

    ``default_vllm_config`` is not used directly in the body; presumably it
    is a pytest fixture that installs a vLLM config needed to construct
    ``Attention`` -- confirm against the project's conftest.
    """
    # Imported lazily, as in the sibling tests in this file.
    from vllm.attention.layer import Attention

    # example from Jamba PP=2
    layer_names = ["model.layers.20.attn", "model.layers.28.attn"]
    ctx = {
        layer_name: Attention(32, 128, 0.1, prefix=layer_name)
        for layer_name in layer_names
    }
    kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names}

    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (``is``), not equality: the layer must reference the exact
    # tensor object that was handed to bind_kv_cache.
    for layer_name in layer_names:
        assert ctx[layer_name].kv_cache[0] is kv_cache[layer_name]

    # Index explicitly (rather than zip) so a too-short runner list still
    # fails with an IndexError instead of silently passing.
    for position, layer_name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[layer_name]
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92


def test_bind_kv_cache_draft_model(default_vllm_config):
    """Check ``bind_kv_cache`` with target-model and draft-model layers mixed.

    Two ``model.*`` layers and two ``draft_model.*`` layers are bound at once.
    Each layer must end up referencing its own tensor, and the runner list is
    expected to interleave target and draft caches by layer index.
    """
    from vllm.attention.layer import Attention

    target_names = ("model.layers.0.attn", "model.layers.1.attn")
    draft_names = ("draft_model.layers.0.attn", "draft_model.layers.1.attn")
    all_names = target_names + draft_names

    # Build every attention layer first, then every placeholder tensor.
    ctx = {}
    for name in all_names:
        ctx[name] = Attention(32, 128, 0.1, prefix=name)
    kv_cache = {}
    for name in all_names:
        kv_cache[name] = torch.zeros((1,))

    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Every layer holds the exact tensor object registered under its name.
    for name in all_names:
        bound = ctx[name].kv_cache[0]
        assert bound is kv_cache[name]

    # caches are ordered by layer_index, interleaving target and draft model
    expected_order = (
        "model.layers.0.attn",
        "draft_model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.1.attn",
    )
    for slot, name in enumerate(expected_order):
        assert runner_kv_caches[slot] is kv_cache[name]