utils.py 7.42 KB
Newer Older
1
from itertools import count
2
3
4
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
5
from unittest.mock import MagicMock
6

7
8
import torch

9
from vllm.engine.arg_utils import EngineArgs
10
from vllm.model_executor.layers.sampler import SamplerOutput
11
from vllm.model_executor.utils import set_random_seed
12
from vllm.sampling_params import SamplingParams
13
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
14
                           SequenceData, SequenceGroupMetadata, SequenceOutput)
15
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
16
from vllm.worker.cache_engine import CacheEngine
17
from vllm.worker.model_runner import ModelRunner
18
from vllm.worker.worker import Worker
19

20
21
# Worker type returned by ``create_worker``; constrained to ``Worker`` subclasses.
T = TypeVar("T", bound=Worker)

22
23
24
25
26

def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many ``block_size``-token blocks are needed for ``seq_len`` tokens.

    This is integer ceiling division: a partially filled last block counts as
    a whole block. Despite the name, the result is a block *count*, not a
    rounded-up token length.
    """
    padded_len = seq_len + block_size - 1
    return padded_len // block_size


27
28
29
def mock_worker(cls=None,
                vocab_size: int = 30_000,
                max_model_len: int = 2048,
                rank: int = 0,
                use_spec: bool = True) -> MagicMock:
    """Create a ``MagicMock`` that stands in for a worker.

    Args:
        cls: Class used as the mock's ``spec``; ``None`` means ``Worker``.
        vocab_size: Stored on the mock as ``vocab_size``.
        max_model_len: Stored on the mock as ``max_model_len``.
        rank: Stored on the mock as ``rank``.
        use_spec: When False the mock is created without a spec, so it is
            not constrained to ``cls``'s attributes.

    Returns:
        The mock, with ``device`` set to ``'cuda:0'``.
    """
    worker_cls = Worker if cls is None else cls
    worker = MagicMock(spec=worker_cls if use_spec else None)

    # Attributes callers read directly off the worker.
    worker.vocab_size = vocab_size
    worker.max_model_len = max_model_len
    worker.rank = rank
    worker.device = 'cuda:0'
    return worker


45
46
47
48
49
50
51
52
53
54
55
56
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
    """Return a wrapper around ``worker.execute_model`` that reseeds after each call.

    ``worker.execute_model`` is captured at call time. Each invocation of the
    returned function forwards all arguments to it, then calls
    ``set_random_seed`` with the next value from ``rand_seeds``, and finally
    returns the wrapped call's result.

    The worker itself is not modified; the caller must install the returned
    function. Once ``rand_seeds`` is exhausted, the next invocation raises
    ``StopIteration`` after the wrapped call has already run.
    """
    remaining_seeds = iter(rand_seeds)
    wrapped_execute = worker.execute_model

    def _reseeding_execute_model(*args, **kwargs):
        output = wrapped_execute(*args, **kwargs)
        set_random_seed(next(remaining_seeds))
        return output

    return _reseeding_execute_model


57
58
59
def zero_kv_cache(cache_engine: List[CacheEngine]):
    """Zero, in place, the GPU KV cache of the first cache engine.

    Only ``cache_engine[0]`` is touched. Each entry of its ``gpu_cache`` is
    unpacked into a (key, value) pair, and both halves are zeroed with
    ``zero_()``. Asserts that the first engine's ``gpu_cache`` is non-empty.
    """
    gpu_cache = cache_engine[0].gpu_cache
    assert gpu_cache
    for layer_cache in gpu_cache:
        key_part, value_part = layer_cache
        key_part.zero_()
        value_part.zero_()


64
def create_worker(cls: Callable[..., T],
                  model_name: str,
                  block_size: int,
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True,
                  model_runner_cls: Optional[ModelRunner] = None) -> T:
    """Construct a worker of type ``cls`` with device, model and KV cache initialized.

    Steps, in order:
      1. Build an engine config from ``EngineArgs(model, seed, block_size,
         enforce_eager)``.
      2. Build a distributed init method from this host's IP and an open port.
      3. Instantiate ``cls`` as local rank 0 / rank 0.
      4. Call ``init_device()`` and ``load_model()``.
      5. Overwrite the cache config to ``num_gpu_blocks`` GPU blocks and 0 CPU
         blocks, and pass those counts to ``initialize_cache()``.

    Args:
        cls: Worker class (or factory) to instantiate.
        model_name: Model passed to ``EngineArgs``.
        block_size: KV-cache block size passed to ``EngineArgs``.
        num_gpu_blocks: Number of GPU blocks to initialize the cache with.
        seed: Seed passed to ``EngineArgs``.
        is_driver_worker: Forwarded to the worker constructor.
        enforce_eager: Passed to ``EngineArgs``.
        model_runner_cls: Forwarded unchanged to the worker constructor.
            NOTE(review): annotated as ``Optional[ModelRunner]`` but the name
            suggests a class rather than an instance — confirm.

    Returns:
        The initialized worker.
    """
    engine_config = EngineArgs(
        model=model_name,
        seed=seed,
        block_size=block_size,
        enforce_eager=enforce_eager,
    ).create_engine_config()

    init_method = get_distributed_init_method(get_ip(), get_open_port())

    worker = cls(
        vllm_config=engine_config,
        local_rank=0,
        rank=0,
        distributed_init_method=init_method,
        is_driver_worker=is_driver_worker,
        model_runner_cls=model_runner_cls,
    )
    worker.init_device()
    worker.load_model()

    # Fixed cache sizes are written into the config only after the model is loaded.
    cache_config = engine_config.cache_config
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0
    worker.initialize_cache(num_gpu_blocks=cache_config.num_gpu_blocks,
                            num_cpu_blocks=cache_config.num_cpu_blocks)
    return worker


def create_seq_group_metadata_from_prompts(
    prompts: List[List[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_prompt_lens: List[int],
    continuations: Optional[List[List[int]]] = None,
    seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
    """Build one single-sequence ``SequenceGroupMetadata`` per prompt.

    Args:
        prompts: Prompt token ids, one list per sequence.
        num_gpu_blocks: Size of the GPU block pool. Block numbers are handed
            out from the top of ``range(num_gpu_blocks)`` downwards.
        block_size: Tokens per block, used to size each allocation.
        final_prompt_lens: Per-sequence token count to reserve blocks for.
        continuations: Already-generated token ids per sequence. Defaults to
            empty lists, which marks every group as a prompt
            (``is_prompt=True``).
        seq_ids: Sequence id for each prompt, used as the key of both
            ``seq_data`` and ``block_tables``. Defaults to
            ``0 .. len(prompts) - 1``.

    Returns:
        Metadata in prompt order. ``request_id`` is the prompt's position as a
        string.

    Raises:
        IndexError: if the block pool is too small for ``final_prompt_lens``,
            or if ``seq_ids`` is shorter than ``prompts``.
    """
    if continuations is None:
        continuations = [[] for _ in prompts]

    if seq_ids is None:
        seq_ids = list(range(len(prompts)))

    free_gpu_blocks = list(range(num_gpu_blocks))

    # Blocks are keyed by prompt position and popped from the end of the pool.
    block_allocations = {
        i: [
            free_gpu_blocks.pop()
            for _ in range(round_up_to_next_block(final_len, block_size))
        ]
        for i, final_len in enumerate(final_prompt_lens)
    }

    seq_group_metadata_list = []
    for i, (prompt_token_ids,
            cont_token_ids) in enumerate(zip(prompts, continuations)):
        # Honour caller-supplied ids so they line up with the parent_seq_id
        # values produced by create_sampler_output_list for the same seq_ids.
        seq_id = seq_ids[i]
        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
        # Reports (total tokens - 1) as computed; presumably this leaves only
        # the final token to be processed — confirm against SequenceData.
        data.update_num_computed_tokens(
            len(prompt_token_ids) + len(cont_token_ids) - 1)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=str(i),
                is_prompt=len(cont_token_ids) == 0,
                seq_data={seq_id: data},
                sampling_params=SamplingParams(temperature=0.0),
                # Copy so callers can mutate one group's table independently.
                block_tables={seq_id: block_allocations[i][:]},
            ))
    return seq_group_metadata_list
145
146
147


def assert_logprobs_dict_allclose(
        actual_logprobs: List[Dict[int, Logprob]],
        expected_logprobs: List[Dict[int, Logprob]]) -> None:
    """Assert that two per-step logprob sequences match.

    For every step, both dicts must contain the same token ids. Each token's
    ``.logprob`` values are wrapped in scalar tensors and compared with
    ``torch.testing.assert_close``.

    Raises:
        AssertionError: if the number of steps differs, the token ids of any
            step differ, or any pair of logprobs is not close.
    """
    # zip() would silently drop trailing steps, letting an output that is
    # missing steps pass the comparison.
    assert len(actual_logprobs) == len(expected_logprobs), (
        f"step count mismatch: {len(actual_logprobs)} actual vs "
        f"{len(expected_logprobs)} expected")
    for step, (single_step_actual_logprobs,
               single_step_expected_logprobs) in enumerate(
                   zip(actual_logprobs, expected_logprobs)):
        assert set(single_step_actual_logprobs.keys()) == set(
            single_step_expected_logprobs.keys()), (
                f"token ids differ at step {step}")
        for token_id in single_step_actual_logprobs:
            actual = torch.tensor(
                single_step_actual_logprobs[token_id].logprob)
            expected = torch.tensor(
                single_step_expected_logprobs[token_id].logprob)
            torch.testing.assert_close(actual, expected)
160
161
162
163


def create_sampler_output_list(
        token_ids: torch.Tensor,
        probs: GenericSequence[Optional[torch.Tensor]],
        logprobs: GenericSequence[Optional[torch.Tensor]],
        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
    """Wrap a grid of sampled token ids in one ``SamplerOutput`` per step.

    ``token_ids`` is unpacked as ``(num_steps, batch_size)``. For each step
    (row), every column becomes a ``CompletionSequenceGroupOutput`` holding a
    single ``SequenceOutput``. That output carries the token, the column's
    parent sequence id and a logprobs dict of ``{token_id: Logprob(0)}``.
    ``probs[step]``, ``logprobs[step]`` and the tensor row ``token_ids[step]``
    are attached to that step's ``SamplerOutput``.

    Args:
        token_ids: 2-D tensor of sampled token ids.
        probs: Per-step value stored as ``sampled_token_probs``.
        logprobs: Per-step value stored as ``logprobs``.
        seq_ids: Parent sequence id for each column; defaults to
            ``0 .. batch_size - 1``.

    Returns:
        A list of ``num_steps`` sampler outputs.
    """
    num_steps, batch_size = token_ids.shape
    rows = token_ids.tolist()
    if seq_ids is None:
        seq_ids = list(range(batch_size))

    def _group_output(column, token_id):
        # One single-sample group output for sequence ``seq_ids[column]``.
        sample = SequenceOutput(
            output_token=token_id,
            parent_seq_id=seq_ids[column],
            logprobs={token_id: Logprob(0)},
        )
        return CompletionSequenceGroupOutput(samples=[sample],
                                             prompt_logprobs=None)

    sampler_outputs = []
    for step in range(num_steps):
        step_outputs = [
            _group_output(column, token_id)
            for column, token_id in enumerate(rows[step])
        ]
        sampler_outputs.append(
            SamplerOutput(
                outputs=step_outputs,
                sampled_token_probs=probs[step],
                logprobs=logprobs[step],
                sampled_token_ids=token_ids[step],
            ))
    return sampler_outputs


def create_batch(batch_size,
                 k,
                 prompt_len: Union[int, List[int]] = 10,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[List[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None):
    """Create a synthetic batch of sequence-group metadata.

    All token ids in the batch are distinct. They are drawn from a single
    counter: first for every prompt, then for every previous-output list.

    Args:
        batch_size: Number of previous-output lists to create, and the number
            of prompts when ``prompt_len`` is an int.
        k: Extra tokens to reserve blocks for. Each sequence reserves
            ``len(prompt) + len(prev_output) + k + 1`` tokens.
        prompt_len: One prompt length shared by all sequences, or a list of
            per-sequence lengths (used as-is).
        prev_output_token_len: Number of already-generated tokens per sequence.
        seq_ids: Forwarded to ``create_seq_group_metadata_from_prompts``.
        num_gpu_blocks: Block pool size; defaults to ``2048 // block_size``.
        block_size: Tokens per block; defaults to 8.

    Returns:
        ``(seq_group_metadata_list, prompts, prev_output_tokens)``.
    """
    if block_size is None:
        block_size = 8
    if num_gpu_blocks is None:
        num_gpu_blocks = 2048 // block_size

    token_counter = count()

    def _take(n):
        # The next ``n`` ids from the shared counter.
        return [next(token_counter) for _ in range(n)]

    prompt_lens = ([prompt_len] * batch_size
                   if isinstance(prompt_len, int) else prompt_len)
    prompts = [_take(p_len) for p_len in prompt_lens]
    prev_output_tokens = [
        _take(prev_output_token_len) for _ in range(batch_size)
    ]

    final_prompt_lens = [
        len(prompt) + len(prev_output) + k + 1
        for prompt, prev_output in zip(prompts, prev_output_tokens)
    ]

    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts, num_gpu_blocks, block_size, final_prompt_lens,
        prev_output_tokens, seq_ids)
    return seq_group_metadata_list, prompts, prev_output_tokens