utils.py 9.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Sequence as GenericSequence
5
from itertools import count
6
from typing import Callable, Optional, TypeVar, Union
7
from unittest.mock import MagicMock
8

9
10
import torch

11
from vllm.engine.arg_utils import EngineArgs
12
from vllm.model_executor.layers.sampler import SamplerOutput
13
from vllm.model_executor.utils import set_random_seed
14
from vllm.sampling_params import SamplingParams
15
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
16
                           SequenceData, SequenceGroupMetadata, SequenceOutput)
17
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
18
from vllm.worker.cache_engine import CacheEngine
19
from vllm.worker.model_runner import ModelRunner
20
from vllm.worker.worker import Worker
21

22
23
# Type variable bounded by Worker; create_worker uses it so that the
# returned object is typed as whatever the `cls` callable produces.
T = TypeVar("T", bound=Worker)

24
25
26
27
28

def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return the number of blocks needed to cover ``seq_len`` items.

    Despite the name, the result is a block *count*, not a padded
    length: for a positive ``block_size`` it equals
    ``ceil(seq_len / block_size)``, computed with integer arithmetic only.

    Args:
        seq_len: Length to cover.
        block_size: Capacity of a single block.

    Returns:
        The block count as an int.
    """
    # Bias the numerator so floor division rounds partial blocks upward.
    padded_len = seq_len + (block_size - 1)
    return padded_len // block_size


29
30
31
def mock_worker(cls=None,
                vocab_size: int = 30_000,
                max_model_len: int = 2048,
                rank: int = 0,
                use_spec: bool = True) -> MagicMock:
    """Build a ``MagicMock`` that stands in for a worker object.

    Args:
        cls: Class used as the mock's ``spec``. Defaults to ``Worker``
            when omitted.
        vocab_size: Value assigned to the mock's ``vocab_size`` attribute.
        max_model_len: Value assigned to the mock's ``max_model_len``
            attribute.
        rank: Value assigned to the mock's ``rank`` attribute.
        use_spec: When true the mock is created with ``spec=cls``; when
            false it is created with ``spec=None`` and ``cls`` is ignored.

    Returns:
        The configured mock. Its ``device`` attribute is always the
        string ``'cuda:0'``.
    """
    if cls is None:
        cls = Worker

    spec = cls if use_spec else None

    worker = MagicMock(spec=spec)
    worker.vocab_size = vocab_size
    worker.max_model_len = max_model_len
    worker.rank = rank
    worker.device = 'cuda:0'
    return worker


47
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
    """Wrap ``worker.execute_model`` so the RNG is re-seeded after each call.

    The wrapper is only *returned*; this function does not assign it onto
    ``worker``. The caller decides where to install it.

    Args:
        worker: Object whose current ``execute_model`` attribute is
            captured and delegated to.
        rand_seeds: Seeds consumed in order, one per wrapper invocation.

    Returns:
        A callable that forwards all positional and keyword arguments to
        the captured ``execute_model``, then calls ``set_random_seed``
        with the next seed, and finally returns the original result.

    Raises:
        StopIteration: Raised by the returned wrapper (not by this
            function) when it is called more times than there are seeds.
    """
    seed_iter = iter(rand_seeds)
    # Capture the bound method now so later reassignment of
    # worker.execute_model does not cause the wrapper to call itself.
    original_execute_model = worker.execute_model

    def new_execute_model(*args, **kwargs):
        result = original_execute_model(*args, **kwargs)
        # Seeding happens after the model call, i.e. it affects whatever
        # runs next rather than the call that just completed.
        set_random_seed(next(seed_iter))
        return result

    return new_execute_model


59
def zero_kv_cache(cache_engine: list[CacheEngine]):
    """Zero the GPU cache of the first cache engine in place.

    Only ``cache_engine[0]`` is touched; any further entries in the list
    are ignored.

    Args:
        cache_engine: List whose first element exposes a non-empty
            ``gpu_cache`` iterable of ``(key_blocks, value_blocks)``
            pairs, each supporting an in-place ``zero_()``.

    Raises:
        AssertionError: If ``cache_engine[0].gpu_cache`` is falsy.
    """
    assert cache_engine[0].gpu_cache
    for key_blocks, value_blocks in cache_engine[0].gpu_cache:
        key_blocks.zero_()
        value_blocks.zero_()


66
def create_worker(cls: Callable[..., T],
                  model_name: str,
                  block_size: int,
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True,
                  model_runner_cls: Optional[type[ModelRunner]] = None,
                  dtype: Optional[str] = "auto") -> T:
    """Construct a worker, load its model and size its KV cache.

    Args:
        cls: Callable (typically a worker class) invoked with the keyword
            arguments ``vllm_config``, ``local_rank``, ``rank``,
            ``distributed_init_method``, ``is_driver_worker`` and
            ``model_runner_cls``.
        model_name: Passed to ``EngineArgs`` as ``model``.
        block_size: Passed to ``EngineArgs`` as ``block_size``.
        num_gpu_blocks: Number of GPU blocks written into the cache config
            and handed to ``initialize_cache``.
        seed: Passed to ``EngineArgs`` as ``seed``.
        is_driver_worker: Forwarded unchanged to ``cls``.
        enforce_eager: Passed to ``EngineArgs`` as ``enforce_eager``.
        model_runner_cls: Forwarded unchanged to ``cls``. It is a class
            object rather than an instance, hence ``type[ModelRunner]``.
        dtype: Passed to ``EngineArgs`` as ``dtype``.

    Returns:
        The object returned by ``cls`` after ``init_device()``,
        ``load_model()`` and ``initialize_cache(...)`` have been called
        on it, in that order.
    """
    engine_args = EngineArgs(
        model=model_name,
        seed=seed,
        block_size=block_size,
        enforce_eager=enforce_eager,
        dtype=dtype,
    )
    engine_config = engine_args.create_engine_config()

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())

    # The worker is always created as rank 0 / local rank 0.
    worker = cls(
        vllm_config=engine_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=is_driver_worker,
        model_runner_cls=model_runner_cls,
    )

    worker.init_device()
    worker.load_model()

    # Use the caller-supplied GPU block count and no CPU blocks; the same
    # values are written to the config and passed to the worker.
    engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
    engine_config.cache_config.num_cpu_blocks = 0
    worker.initialize_cache(
        num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
        num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)

    return worker


def create_seq_group_metadata_from_prompts(
    prompts: list[list[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_prompt_lens: list[int],
    continuations: Optional[list[list[int]]] = None,
    seq_ids: Optional[list[int]] = None,
) -> list[SequenceGroupMetadata]:
    """Build one ``SequenceGroupMetadata`` per prompt.

    Each group holds a single sequence keyed by the prompt's position in
    ``prompts``; the same index, as a string, is the ``request_id``.

    Args:
        prompts: Prompt token ids, one list per sequence.
        num_gpu_blocks: Size of the block-id pool ``0..num_gpu_blocks-1``
            from which block tables are drawn.
        block_size: Block capacity used to turn each final length into a
            block count via ``round_up_to_next_block``.
        final_prompt_lens: Per-sequence length used to size the block
            table of the sequence at the same index.
        continuations: Already-generated token ids per sequence. Defaults
            to an empty continuation for every prompt.
        seq_ids: NOTE(review): accepted for interface compatibility but
            not consulted here; sequences are always keyed by index.
            Confirm with callers before wiring it in.

    Returns:
        The metadata objects, in the order of ``prompts``. A group is
        flagged ``is_prompt`` exactly when its continuation is empty.

    Raises:
        IndexError: If the block pool is exhausted while allocating.
    """
    if continuations is None:
        continuations = [[] for _ in prompts]

    free_gpu_blocks = list(range(num_gpu_blocks))

    # list.pop() takes from the end, so block ids are handed out from the
    # highest id downward and no id is shared between sequences.
    block_allocations = {
        i: [
            free_gpu_blocks.pop()
            for _ in range(round_up_to_next_block(final_len, block_size))
        ]
        for i, final_len in enumerate(final_prompt_lens)
    }

    seq_group_metadata_list = []
    for i, (prompt_token_ids,
            cont_token_ids) in enumerate(zip(prompts, continuations)):
        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
        # Report every token except the last one as already computed.
        data.update_num_computed_tokens(
            len(prompt_token_ids) + len(cont_token_ids) - 1)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=str(i),
                is_prompt=len(cont_token_ids) == 0,
                seq_data={i: data},
                sampling_params=SamplingParams(temperature=0.0),
                # Copy so the metadata does not alias the allocation map.
                block_tables={i: block_allocations[i][:]},
            ))
    return seq_group_metadata_list
149
150


151
def create_chunked_seq_group_metadata_from_prompt(
        prompt: list[int],
        num_gpu_blocks: int,
        chunk_size: int,
        block_size: int,
        seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
    """Split one prompt into prefill chunks, one metadata object per chunk.

    Args:
        prompt: Token ids of the full prompt.
        num_gpu_blocks: Size of the block-id pool ``0..num_gpu_blocks-1``.
        chunk_size: Maximum number of prompt tokens per chunk.
        block_size: Block capacity used to size the block table for the
            whole prompt via ``round_up_to_next_block``.
        seq_id: Used (as a string) for every chunk's ``request_id``.
            Defaults to 0.

    Returns:
        One ``SequenceGroupMetadata`` per chunk, in prompt order. Every
        entry has ``is_prompt=True``; ``do_sample`` is true only for the
        chunk that reaches the end of the prompt. An empty prompt yields
        an empty list.

    Raises:
        IndexError: If the block pool is exhausted while allocating.
    """
    if seq_id is None:
        seq_id = 0

    free_gpu_blocks = list(range(num_gpu_blocks))

    # A single block table sized for the full prompt; list.pop() hands out
    # ids from the highest downward.
    block_allocations = [
        free_gpu_blocks.pop()
        for _ in range(round_up_to_next_block(len(prompt), block_size))
    ]

    seq_group_metadata_list = []
    for i, idx in enumerate(range(0, len(prompt), chunk_size)):
        chunk_ids = prompt[idx:idx + chunk_size]
        # Each chunk carries the whole prompt, with the tokens before the
        # chunk start reported as already computed.
        data = SequenceData.from_seqs(prompt)
        data.update_num_computed_tokens(idx)
        # NOTE: seq_data and block_tables are keyed by the chunk index
        # `i`, not by seq_id, and every chunk shares the same
        # block_allocations list object (no copy is made).
        seq_data = {i: data}
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=str(seq_id),
                is_prompt=True,
                do_sample=idx + chunk_size >= len(prompt),  # terminal chunk
                seq_data=seq_data,
                sampling_params=SamplingParams(temperature=0.0),
                block_tables={i: block_allocations},
                token_chunk_size=len(chunk_ids)))
    return seq_group_metadata_list


186
def assert_logprobs_dict_allclose(
        actual_logprobs: list[dict[int, Logprob]],
        expected_logprobs: list[dict[int, Logprob]]) -> None:
    """Assert two per-step logprob lists agree on token ids and values.

    Steps are paired positionally with ``zip``. If the two lists differ
    in length, the surplus trailing steps of the longer one are not
    compared.

    Args:
        actual_logprobs: One ``{token_id: Logprob}`` dict per step.
        expected_logprobs: Reference dicts in the same layout.

    Raises:
        AssertionError: If a step's token-id sets differ, or if
            ``torch.testing.assert_close`` rejects a pair of ``logprob``
            values (default tolerances).
    """
    for single_step_actual_logprobs, single_step_expected_logprobs in zip(
            actual_logprobs, expected_logprobs):
        # Both steps must mention exactly the same token ids.
        assert set(single_step_actual_logprobs.keys()) == set(
            single_step_expected_logprobs.keys())
        for token_id in single_step_actual_logprobs:
            actual = torch.tensor(
                single_step_actual_logprobs[token_id].logprob)
            expected = torch.tensor(
                single_step_expected_logprobs[token_id].logprob)
            torch.testing.assert_close(actual, expected)
199
200
201
202


def create_sampler_output_list(
        token_ids: torch.Tensor,
        probs: GenericSequence[Optional[torch.Tensor]],
        logprobs: GenericSequence[Optional[torch.Tensor]],
        seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
    """Build one ``SamplerOutput`` per step from a grid of token ids.

    Args:
        token_ids: 2-D tensor unpacked as ``(num_steps, batch_size)``.
        probs: Indexed by step; ``probs[step]`` becomes that step's
            ``sampled_token_probs``.
        logprobs: Indexed by step; ``logprobs[step]`` becomes that step's
            ``logprobs``.
        seq_ids: Parent sequence id for each batch position. Defaults to
            ``0..batch_size-1``.

    Returns:
        A list of ``num_steps`` sampler outputs. Each holds one
        ``CompletionSequenceGroupOutput`` per batch position with a
        single sample whose logprobs dict is ``{token_id: Logprob(0)}``,
        and ``sampled_token_ids`` set to the matching row of
        ``token_ids``.
    """
    num_steps, batch_size = token_ids.shape
    # Plain Python ints are used for the per-sample fields; the tensor
    # rows themselves are kept for sampled_token_ids.
    token_ids_by_step = token_ids.tolist()

    if seq_ids is None:
        seq_ids = list(range(batch_size))

    return [
        SamplerOutput(outputs=[
            CompletionSequenceGroupOutput(
                samples=[
                    SequenceOutput(
                        output_token=token_id,
                        parent_seq_id=seq_ids[seq_index],
                        logprobs={token_id: Logprob(0)},
                    )
                ],
                prompt_logprobs=None,
            ) for seq_index, token_id in enumerate(token_ids_by_step[step])
        ],
                      sampled_token_probs=probs[step],
                      logprobs=logprobs[step],
                      sampled_token_ids=token_ids[step])
        for step in range(num_steps)
    ]


def create_batch(batch_size,
                 k,
                 prompt_len: Union[int, list[int]] = 10,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[list[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None,
                 prefill_chunk_size: Optional[int] = None):
    """Create a batch of sequence-group metadata with synthetic token ids.

    Token ids are drawn from a single ``itertools.count()``, so every
    generated token id in the batch is distinct.

    Args:
        batch_size: Number of sequences. Also caps the number of metadata
            entries kept in the chunked-prefill case.
        k: Extra length added (together with 1) to each sequence's final
            length when sizing block tables in the non-chunked case.
        prompt_len: A single prompt length applied to ``batch_size``
            prompts, or an explicit list of per-prompt lengths.
        prev_output_token_len: Number of already-generated tokens per
            sequence in the non-chunked case.
        seq_ids: Optional ids forwarded to the metadata builders. In the
            chunked case a falsy value defaults to ``0..len(prompts)-1``.
        num_gpu_blocks: Block pool size. Defaults to
            ``2048 // block_size``.
        block_size: Block capacity. Defaults to 8.
        prefill_chunk_size: When truthy, prompts are split into chunks of
            this size and no previous output tokens are generated.

    Returns:
        A tuple ``(seq_group_metadata_list, prompts, prev_output_tokens)``.
        ``prev_output_tokens`` is an empty list in the chunked case.
    """
    if block_size is None:
        block_size = 8

    if num_gpu_blocks is None:
        num_gpu_blocks = 2048 // block_size

    iterator = count()

    if isinstance(prompt_len, int):
        prompt_lens = [prompt_len for _ in range(batch_size)]
    else:
        prompt_lens = prompt_len

    prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]

    if prefill_chunk_size:
        # Create a batch of chunked prompts.
        if not seq_ids:
            seq_ids = list(range(len(prompts)))
        seq_group_metadata_list = []
        for p, sid in zip(prompts, seq_ids):
            seq_group_metadata_list.extend(
                create_chunked_seq_group_metadata_from_prompt(
                    p, num_gpu_blocks, prefill_chunk_size, block_size, sid))
        # Chunking can produce more entries than sequences; keep only the
        # first batch_size of them.
        seq_group_metadata_list = seq_group_metadata_list[:batch_size]
        prev_output_tokens = []
    else:
        prev_output_tokens = [[
            next(iterator) for _ in range(prev_output_token_len)
        ] for _ in range(batch_size)]
        # Size block tables for prompt + previous output + k + 1 tokens.
        final_prompt_lens = [
            len(prompt) + len(prev_output_token) + k + 1
            for prompt, prev_output_token in zip(prompts, prev_output_tokens)
        ]

        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size, final_prompt_lens,
            prev_output_tokens, seq_ids)
    return seq_group_metadata_list, prompts, prev_output_tokens
279
280
281
282
283
284
285
286
287
288
289
290


def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
    """Set the chunked-prefill options on ``llm_kwargs`` in place.

    A positive ``prefill_chunk_size`` turns chunked prefill on and uses
    that size for both ``max_num_batched_tokens`` and ``max_num_seqs``.
    Any other value only sets ``enable_chunked_prefill`` to False and
    leaves the remaining keys as they were.

    Args:
        prefill_chunk_size: Requested chunk size; must be comparable
            with ``0``.
        llm_kwargs: Mutable mapping of keyword arguments to update.

    Returns:
        None. ``llm_kwargs`` is modified in place.
    """
    if not prefill_chunk_size > 0:
        llm_kwargs["enable_chunked_prefill"] = False
        return

    llm_kwargs.update(
        enable_chunked_prefill=True,
        max_num_batched_tokens=prefill_chunk_size,
        max_num_seqs=prefill_chunk_size,
    )