# test_model_input.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
import dataclasses

import torch

8
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
9
from vllm.attention.backends.abstract import AttentionBackend
10
from vllm.attention.backends.utils import CommonAttentionState
11
12
13
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
14
from vllm.worker.multi_step_model_runner import StatefulModelInput
15
16
from vllm.worker.pooling_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
17
18
19
20
21
22
23
24
25
26
27
28
29


class MockAttentionBackend(AttentionBackend):
    """Minimal ``AttentionBackend`` stub for the serialization tests below.

    An instance is passed as ``attn_backend`` when a model input is rebuilt
    with ``from_broadcasted_tensor_dict``. Only the metadata, builder and
    state class accessors return real types. ``get_name``, ``get_impl_cls``
    and ``get_kv_cache_shape`` raise ``NotImplementedError``, and the
    cache-block operations are no-ops.
    """

    @staticmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        # The plain base class; the tests compare its dataclass fields after
        # the round trip. Presumably consulted when rebuilding attn_metadata
        # from the tensor dict -- confirm against the model-runner code.
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        return AttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # Intentional no-op: the tests never move KV-cache blocks.
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        # Intentional no-op: the tests never copy KV-cache blocks.
        pass


def test_model_runner_input():
    """Round-trip a sampling model input through its broadcast tensor dict.

    Builds a ``ModelInputForGPUWithSamplingMetadata``, serializes it with
    ``as_broadcastable_tensor_dict`` and rebuilds it with
    ``from_broadcasted_tensor_dict``. It then checks that tensors, optional
    fields, attention metadata and ``selected_token_indices`` survive the
    round trip, and that ``seq_groups`` is not carried over.
    """
    # Placeholder strings stand in for real values; the round trip only
    # needs them to compare equal afterwards.
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))
    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithSamplingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions == model_input.input_positions
            ).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    # Every AttentionMetadata dataclass field must match the original.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None


def test_embedding_model_runner_input():
    """Round-trip a pooling model input through its broadcast tensor dict.

    Builds a ``ModelInputForGPUWithPoolingMetadata``, serializes it with
    ``as_broadcastable_tensor_dict`` and rebuilds it with
    ``from_broadcasted_tensor_dict``. It then checks that tensors, optional
    fields and attention metadata survive, and that the pooling metadata is
    not broadcast.
    """
    pooling_metadata = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))
    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithPoolingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions == model_input.input_positions
            ).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    # Every AttentionMetadata dataclass field must match the original.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None


def test_multi_step_model_runner_input():
    """Round-trip a ``StatefulModelInput`` through its broadcast tensor dict.

    Wraps a ``ModelInputForGPUWithSamplingMetadata`` as the frozen input,
    serializes the wrapper and rebuilds it with
    ``from_broadcasted_tensor_dict``. It then checks both the frozen input's
    fields and the wrapper's own multi-step fields against the originals.
    """
    # Placeholder strings stand in for real values; the round trip only
    # needs them to compare equal afterwards.
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )

    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))

    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    # Compare the received copy against the original, not the original
    # against itself.
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (
        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
    # Every AttentionMetadata dataclass field must match the original.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # check non frozen fields
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step