"vllm/tool_parsers/abstract_tool_parser.py" did not exist on "08287ef6751e79a89bf4f060f5f9545560a6de12"
test_utils.py 3.68 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.v1.worker.utils import bind_kv_cache


9
def test_bind_kv_cache(default_vllm_config):
    """Check that ``bind_kv_cache`` binds four consecutive attention layers.

    After the call, each layer's ``kv_cache[0]`` must be the very tensor
    registered under that layer's name, and ``runner_kv_caches`` must hold
    the same tensors in layer order.

    Args:
        default_vllm_config: Not referenced in the body; presumably a pytest
            fixture that prepares the configuration ``Attention`` needs at
            construction time -- TODO confirm against conftest.
    """
    # Function-local import, kept exactly where the original test had it.
    from vllm.model_executor.layers.attention import Attention

    layer_names = [f"layers.{idx}.self_attn" for idx in range(4)]
    ctx = {name: Attention(32, 128, 0.1, prefix=name) for name in layer_names}
    kv_cache = {name: torch.zeros((1,)) for name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (``is``) rather than equality: the layer must hold the same
    # tensor object that was passed in, not a copy.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name], name

    # The runner-side list is expected in layer order.
    for position, name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[name], name
35
36


37
def test_bind_kv_cache_non_attention(default_vllm_config):
    """Check ``bind_kv_cache`` when layer indices are sparse and non-zero-based.

    Only layers 20 and 28 carry a KV cache here. Each layer's ``kv_cache[0]``
    must be the tensor registered under its name, and ``runner_kv_caches``
    must contain exactly those tensors at positions 0 and 1, in that order.

    Args:
        default_vllm_config: Not referenced in the body; presumably a pytest
            fixture that prepares the configuration ``Attention`` needs at
            construction time -- TODO confirm against conftest.
    """
    # Function-local import, kept exactly where the original test had it.
    from vllm.model_executor.layers.attention import Attention

    # example from Jamba PP=2
    layer_names = ["model.layers.20.attn", "model.layers.28.attn"]
    ctx = {name: Attention(32, 128, 0.1, prefix=name) for name in layer_names}
    kv_cache = {name: torch.zeros((1,)) for name in layer_names}

    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (``is``) rather than equality: the layer must hold the same
    # tensor object that was passed in, not a copy.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name], name

    # Positions are dense (0, 1) even though the layer indices are 20 and 28.
    for position, name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[name], name
58
59
60


def test_bind_kv_cache_draft_model(default_vllm_config):
    """Check ``bind_kv_cache`` with target-model and draft-model layers mixed.

    Two ``model.*`` layers and two ``draft_model.*`` layers are bound. Each
    layer's ``kv_cache[0]`` must be the tensor registered under its name, and
    ``runner_kv_caches`` must be ordered by layer index so that target and
    draft layers with the same index sit next to each other.

    Args:
        default_vllm_config: Not referenced in the body; presumably a pytest
            fixture that prepares the configuration ``Attention`` needs at
            construction time -- TODO confirm against conftest.
    """
    # Function-local import, kept exactly where the original test had it.
    from vllm.model_executor.layers.attention import Attention

    # Insertion order is part of the scenario: both target layers are
    # registered before either draft layer.
    layer_names = [
        "model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.0.attn",
        "draft_model.layers.1.attn",
    ]
    ctx = {name: Attention(32, 128, 0.1, prefix=name) for name in layer_names}
    kv_cache = {name: torch.zeros((1,)) for name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (``is``) rather than equality: the layer must hold the same
    # tensor object that was passed in, not a copy.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name], name

    # caches are ordered by layer_index, interleaving target and draft model
    expected_order = [
        "model.layers.0.attn",
        "draft_model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.1.attn",
    ]
    for position, name in enumerate(expected_order):
        assert runner_kv_caches[position] is kv_cache[name], name