"docs/vscode:/vscode.git/clone" did not exist on "ad39106b16fee0074e814f06ec7a517399ea154d"
conftest.py 13.4 KB
Newer Older
1
import asyncio
2
import os
3
from itertools import cycle
4
from typing import Dict, List, Optional, Sequence, Tuple, Union
5

6
import pytest
7
import ray
8
import torch
9

10
from vllm import LLM
11
12
13
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
14
from vllm.model_executor.utils import set_random_seed
15
from vllm.multimodal import MultiModalDataDict
16
from vllm.outputs import RequestOutput
17
from vllm.prompt_adapter.request import PromptAdapterRequest
18
from vllm.sampling_params import SamplingParams
19
from vllm.sequence import Logprob
20
21
22
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid

23
from ...conftest import cleanup
24
from ...utils import wait_for_gpu_memory_to_clear
25

class AsyncLLM:
    """Async counterpart of ``vllm.LLM`` used only by these tests.

    The ``LLM`` class in vLLM does not support async mode, so for test
    purposes an async variant is implemented here. It may be moved to
    vllm/entrypoints/llm.py in the future.

    It is adapted directly from vllm/entrypoints/llm.py, with the changes
    needed to make it work in async mode: requests are served by an
    ``AsyncLLMEngine`` built with ``engine_use_ray=True`` and the Ray
    distributed executor backend.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        """Build ``AsyncEngineArgs`` and start an ``AsyncLLMEngine``.

        The named arguments mirror ``vllm.LLM``; any extra ``kwargs`` are
        forwarded unchanged to ``AsyncEngineArgs``. Stats logging is
        disabled unless the caller passes ``disable_log_stats`` explicitly.
        """
        kwargs.setdefault("disable_log_stats", True)

        # engine_use_ray is a deprecated feature and only works when this
        # env var is set; otherwise the constructor below raises. The
        # variable is not restored afterwards.
        os.environ["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"

        engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            # For now use ray for the distributed back-end, since
            # we rely on the use of engine_use_ray=True to avoid
            # reinitializing CUDA in the same process (driver worker)
            engine_use_ray=True,
            distributed_executor_backend="ray",
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.request_counter = Counter()
        self.llm_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalDataDict] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> List[RequestOutput]:
        """Generate completions for ``prompts``, one request at a time.

        Each prompt is submitted to the async engine under a fresh request
        id and driven with its own ``asyncio.run`` call. Only the last
        ``RequestOutput`` yielded for each request is kept.

        Only ``prompts`` and ``sampling_params`` are used. The arguments
        ``prompt_token_ids``, ``use_tqdm``, ``lora_request``,
        ``multi_modal_data`` and ``prompt_adapter_request`` are accepted for
        signature compatibility with ``vllm.LLM.generate`` and are ignored.

        Args:
            prompts: A single prompt or a list of prompts (required).
            sampling_params: One ``SamplingParams`` shared by every prompt,
                or a sequence with exactly one entry per prompt. Defaults
                to ``SamplingParams()``.

        Returns:
            One ``RequestOutput`` per prompt, in input order.

        Raises:
            ValueError: If ``prompts`` is missing, or if a per-prompt
                sequence of sampling params has a different length.

        Note:
            ``ray.shutdown()`` is always called once the loop finishes or
            fails, so presumably this instance cannot serve further
            ``generate`` calls afterwards.
        """
        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        num_requests = len(prompts)

        # Use the same test for validation and for indexing below, so that a
        # tuple of params is length-checked just like a list.
        per_prompt_params = isinstance(sampling_params, Sequence)
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
        elif per_prompt_params and len(sampling_params) != num_requests:
            raise ValueError("The lengths of prompts and "
                             "sampling_params must be the same.")

        async def get_output(prompt, sampling_param) -> RequestOutput:
            request_id = random_uuid()
            results_generator = self.llm_engine.generate(
                prompt, sampling_param, request_id)
            # Drain the stream and keep only the last yielded output.
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            assert final_output is not None
            return final_output

        outputs: List[RequestOutput] = []
        try:
            for i, prompt in enumerate(prompts):
                params = (sampling_params[i]
                          if per_prompt_params else sampling_params)
                outputs.append(asyncio.run(get_output(prompt, params)))
        finally:
            ray.shutdown()
        return outputs


@pytest.fixture
def baseline_llm_generator(request, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           seed):
    """Fixture providing a generator function that yields the baseline LLM.

    The LLM kwargs are ``common_llm_kwargs``, ``per_test_common_llm_kwargs``
    and ``baseline_llm_kwargs`` merged in that order (later entries win);
    see ``create_llm_generator``.
    """
    return create_llm_generator("baseline", request, common_llm_kwargs,
                                per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Fixture providing a generator function that yields the test LLM.

    The LLM kwargs are ``common_llm_kwargs``, ``per_test_common_llm_kwargs``
    and ``test_llm_kwargs`` merged in that order (later entries win); see
    ``create_llm_generator``.
    """
    return create_llm_generator("test", request, common_llm_kwargs,
                                per_test_common_llm_kwargs, test_llm_kwargs,
                                seed)


def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                         per_test_common_llm_kwargs, distinct_llm_kwargs,
                         seed):
    """Build a zero-argument generator function that yields one LLM.

    The LLM kwargs are ``common_llm_kwargs``, ``per_test_common_llm_kwargs``
    and ``distinct_llm_kwargs`` merged in that order (later entries win).
    An optional ``use_async`` key selects ``AsyncLLM`` instead of ``LLM``.
    It is not forwarded to the constructor.

    Nothing is constructed until the returned function is called and
    iterated. Each such iteration does the following:

    - calls ``wait_for_gpu_memory_to_clear`` on all visible CUDA devices
      (2 GiB threshold, 60 s timeout);
    - builds a fresh LLM;
    - seeds the RNG when ``seed`` is not None;
    - yields the LLM once, and runs ``cleanup()`` when resumed.

    Args:
        baseline_or_test: ``"baseline"`` or ``"test"``. It is used in log
            output. For ``"test"`` it also zeroes each stat logger's
            ``local_interval`` on non-async LLMs whose engine has
            ``log_stats`` enabled.
        request: The pytest ``request`` object; its node name is logged.
        common_llm_kwargs: Base LLM kwargs; must contain ``"model"``.
        per_test_common_llm_kwargs: Overrides applied on top of the base.
        distinct_llm_kwargs: Baseline- or test-specific overrides.
        seed: Passed to ``set_random_seed`` after the LLM is built, unless
            it is None.

    Returns:
        The generator function. It carries a ``same_draft_target_model``
        attribute that is True when ``speculative_model`` equals ``model``.
    """
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }
    test_name = request.node.name

    model = kwargs["model"]
    draft_model = kwargs.get("speculative_model", None)
    same_draft_target_model = (draft_model is not None
                               and draft_model == model)

    def generator_inner():
        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        # Work on a per-call copy. Popping ``use_async`` from the shared
        # closure dict would make any later invocation of the returned
        # generator silently build a synchronous LLM instead.
        llm_kwargs = dict(kwargs)
        use_async = llm_kwargs.pop("use_async", False)
        print(f'{use_async=}')

        # ``kwargs={llm_kwargs!r}`` reproduces the original ``{kwargs=}``
        # output exactly.
        print(f'Creating {baseline_or_test=} LLM for {test_name=}. '
              f'kwargs={llm_kwargs!r}')
        llm = AsyncLLM(**llm_kwargs) if use_async else LLM(**llm_kwargs)

        # Override logging interval to 0 for spec decode test run to
        # log all metrics in time.
        if (baseline_or_test == "test" and not use_async
                and llm.llm_engine.log_stats):
            for stat_logger in llm.llm_engine.stat_loggers.values():
                stat_logger.local_interval = 0

        if seed is not None:
            set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    def generator_outer():
        for llm in generator_inner():
            yield llm
            del llm

    # Set an attribute to the generator_outer function to allow us to
    # determine whether to further check the acceptance rate in tests.
    generator_outer.same_draft_target_model = same_draft_target_model  # type: ignore

    return generator_outer


def maybe_assert_ngram_worker(llm):
    """Check that an ngram-configured LLM really uses an ``NGramWorker``.

    Nothing is checked for ``AsyncLLM`` instances, for LLMs whose engine has
    no speculative config, or when ``ngram_prompt_lookup_max`` is not
    positive.
    """
    if isinstance(llm, AsyncLLM):
        return
    spec_config = llm.llm_engine.speculative_config
    if spec_config is None:
        return
    if not spec_config.ngram_prompt_lookup_max > 0:
        return
    # Imported lazily, only on the path where the check actually runs.
    from vllm.spec_decode.ngram_worker import NGramWorker
    proposer = llm.llm_engine.model_executor.driver_worker.proposer_worker
    assert isinstance(proposer, NGramWorker)


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    """Run ``prompts`` through each LLM yielded by ``llm_generator()``.

    Every yielded LLM first goes through ``maybe_assert_ngram_worker``.

    Returns:
        ``(tokens, token_ids, acceptance_rate)``, taken from the last LLM
        yielded:

        - ``tokens``: the text of the first completion of each prompt.
        - ``token_ids``: the token ids of the first completion of each
          prompt.
        - ``acceptance_rate``: read from the ``"prometheus"`` stat logger's
          ``gauge_spec_decode_draft_acceptance_rate`` gauge.

        ``acceptance_rate`` stays ``-1.0`` when ``llm.llm_engine`` has no
        (or an empty) ``stat_loggers`` attribute. All three values keep
        their defaults if the generator yields nothing.
    """
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        # Drop this frame's reference before resuming the generator, which
        # presumably lets its post-yield cleanup release the LLM.
        del llm

    return tokens, token_ids, acceptance_rate


def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Returns a dict of (token_id: Logprob) for each generated position, for
    each sequence in the batch.

    Only the first completion of each prompt is used. If several LLMs are
    yielded, the result comes from the last one. An empty list is returned
    when ``llm_generator()`` yields no LLM (previously this raised
    ``UnboundLocalError``).
    """
    logprobs: List[List[Dict[int, Logprob]]] = []
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        # Slice-copy so the result does not alias the engine's output lists.
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        del llm

    return logprobs


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False,
                                         ensure_all_accepted: bool = False):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
    the same when temperature is zero.

    This is a thin wrapper around ``run_equality_correctness_test`` with
    ``temperature=0.0`` and ``seeded=False``; the remaining arguments are
    forwarded unchanged.
    """
    run_equality_correctness_test(baseline_llm_generator,
                                  test_llm_generator,
                                  batch_size,
                                  max_output_len,
                                  force_output_len,
                                  temperature=0.0,
                                  seeded=False,
                                  print_tokens=print_tokens,
                                  ensure_all_accepted=ensure_all_accepted)


def run_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len,
        force_output_len: bool,
        temperature: float,
        seeded: bool,
        print_tokens: bool = False,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts that the generated token ids are exactly the
    same, which is expected when temperature is zero (greedy) or when
    temperature is > 0 and every request is seeded.

    The test LLM runs first, then the baseline LLM, on the same prompts and
    sampling params.

    Args:
        baseline_llm_generator: Generator function yielding the baseline LLM.
        test_llm_generator: Generator function yielding the LLM under test.
        batch_size: Number of prompts; a fixed prompt list is cycled to
            reach this size.
        max_output_len: ``max_tokens`` for every request.
        force_output_len: If True, set ``ignore_eos`` so that generation
            runs for the full ``max_output_len`` tokens.
        temperature: Sampling temperature for every request.
        seeded: If True, request ``i`` gets its own ``SamplingParams`` with
            ``seed=i``; otherwise a single unseeded ``SamplingParams`` is
            shared.
        print_tokens: Also print decoded text per sequence (token ids are
            always printed).
        ensure_all_accepted: Assert the test LLM's acceptance rate is 1.0.
        expected_acceptance_rate: If given, assert the test LLM's acceptance
            rate is at least this value minus 0.01.
    """
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generated max_output_len tokens, then set the
    # sampling params to ignore eos token.
    ignore_eos = force_output_len

    if seeded:
        sampling_params = [
            SamplingParams(
                max_tokens=max_output_len,
                ignore_eos=ignore_eos,
                temperature=temperature,
                seed=i,
            ) for i in range(len(prompts))
        ]
    else:
        sampling_params = SamplingParams(
            max_tokens=max_output_len,
            ignore_eos=ignore_eos,
            temperature=temperature,
        )

    (spec_batch_tokens, spec_batch_token_ids,
     acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    (baseline_batch_tokens, baseline_batch_token_ids,
     _) = get_output_from_llm_generator(baseline_llm_generator, prompts,
                                        sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)

    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
            spec_tokens) in enumerate(
                zip(baseline_batch_token_ids, baseline_batch_tokens,
                    spec_batch_token_ids, spec_batch_tokens)):
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=}     {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=}     {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids

    print(f'{acceptance_rate=}')

    if ensure_all_accepted:
        assert acceptance_rate == 1.0

    if expected_acceptance_rate is not None:
        # Allow a small absolute tolerance below the expected rate.
        assert acceptance_rate >= expected_acceptance_rate - 1e-2