# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses

import torch

8
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
9
from vllm.attention.backends.abstract import AttentionBackend
10
from vllm.attention.backends.utils import CommonAttentionState
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from vllm.model_executor import SamplingMetadata
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class MockAttentionBackend(AttentionBackend):
    """Minimal attention backend used only to drive the round-trip test.

    The metadata, builder, and state accessors return the generic classes
    imported at the top of this file. The name, implementation-class, and
    KV-cache-shape accessors raise ``NotImplementedError``; the block
    swap/copy hooks are no-ops.
    """

    @staticmethod
    def get_name() -> str:
        """Not provided by the mock; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        """Not provided by the mock; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the generic ``AttentionMetadata`` class."""
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        """Return the generic ``AttentionMetadataBuilder`` class."""
        return AttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        """Return the ``CommonAttentionState`` class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Not provided by the mock; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """No-op: the mock owns no KV cache to swap."""
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """No-op: the mock owns no KV cache to copy."""
        pass


def test_model_runner_input():
    """Round-trip a GPU model input through its broadcastable tensor dict.

    Builds a ``ModelInputForGPUWithSamplingMetadata`` from placeholder
    sampling metadata and a small ``AttentionMetadata``, converts it with
    ``as_broadcastable_tensor_dict``, rebuilds it with
    ``from_broadcasted_tensor_dict``, and checks that the rebuilt object
    matches the original field by field.
    """
    # Placeholder strings stand in for the real values; the test only checks
    # that they are carried (or dropped) as expected across the round trip.
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))
    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithSamplingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions == model_input.input_positions
            ).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    # Every declared AttentionMetadata field must match; a field missing on
    # either side is read as None.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None