# conftest.py: shared pytest fixtures and helpers for the speculative decoding tests.
1
import asyncio
2
from itertools import cycle
3
from typing import Dict, List, Optional, Tuple, Union
4

5
import pytest
6
import ray
7
import torch
8

9
from vllm import LLM
10
11
12
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
13
from vllm.model_executor.utils import set_random_seed
14
from vllm.multimodal import MultiModalDataDict
15
from vllm.outputs import RequestOutput
16
from vllm.prompt_adapter.request import PromptAdapterRequest
17
from vllm.sampling_params import SamplingParams
18
from vllm.sequence import Logprob
19
20
21
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid

22
from ...conftest import cleanup
23
from ...utils import wait_for_gpu_memory_to_clear
24

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

class AsyncLLM:
    """AsyncLLM

    Note: the current LLM class in vllm doesn't support async mode, so for
    test purposes we implement an async one here. Maybe we could move it to
    vllm/entrypoints/llm.py in the future.

    The AsyncLLM below is borrowed directly from vllm/entrypoints/llm.py, with
    changes to make it work in async mode. Instances are effectively
    single-use: ``generate`` calls ``ray.shutdown()`` before returning.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        """Build an ``AsyncLLMEngine`` from the given engine arguments.

        Extra ``kwargs`` are forwarded to ``AsyncEngineArgs``. Stats logging
        is disabled unless the caller passes ``disable_log_stats``
        explicitly.
        """
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            # For now use ray for the distributed back-end, since
            # we rely on the use of engine_use_ray=True to avoid
            # reinitializing CUDA in the same process (driver worker)
            engine_use_ray=True,
            distributed_executor_backend="ray",
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        # Kept for parity with LLM; generate() uses random_uuid() for ids.
        self.request_counter = Counter()
        self.llm_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalDataDict] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> List[RequestOutput]:
        """Run every prompt to completion and return the final outputs.

        Args:
            prompts: A single prompt or a list of prompts. Required.
            sampling_params: Either one ``SamplingParams`` shared by every
                prompt, or a list with exactly one entry per prompt.
                Defaults to ``SamplingParams()``.
            prompt_token_ids, use_tqdm, lora_request, multi_modal_data,
            prompt_adapter_request: Accepted but currently ignored
                (presumably kept to mirror ``LLM.generate``).

        Returns:
            One ``RequestOutput`` per prompt, in prompt order.

        Raises:
            ValueError: If ``prompts`` is None, or if a list of sampling
                params does not have one entry per prompt.
        """
        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        num_requests = len(prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
        elif (isinstance(sampling_params, list)
              and len(sampling_params) != num_requests):
            raise ValueError("The lengths of prompts and "
                             "sampling_params must be the same.")

        # Pair each prompt with its own params. Previously a list of params
        # was handed, whole, to every single request.
        if isinstance(sampling_params, list):
            per_request_params = sampling_params
        else:
            per_request_params = [sampling_params] * num_requests

        async def get_output(prompt, sampling_param) -> RequestOutput:
            request_id = random_uuid()
            results_generator = self.llm_engine.generate(
                prompt, sampling_param, request_id)
            # Drain the stream, keeping only the last yielded output.
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            assert final_output is not None
            return final_output

        outputs: List[RequestOutput] = []
        try:
            # Requests run one at a time, each in a fresh event loop.
            for prompt, sampling_param in zip(prompts, per_request_params):
                outputs.append(
                    asyncio.run(get_output(prompt, sampling_param)))
        finally:
            # Shut ray down even if a request fails.
            ray.shutdown()
        return outputs


@pytest.fixture
def baseline_llm_generator(request, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           seed):
    """Fixture: factory for the baseline LLM that outputs are compared to.

    The engine kwargs combine ``common_llm_kwargs``,
    ``per_test_common_llm_kwargs`` and ``baseline_llm_kwargs``.
    """
    return create_llm_generator(
        baseline_or_test="baseline",
        request=request,
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=baseline_llm_kwargs,
        seed=seed,
    )


@pytest.fixture
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Fixture: factory for the LLM under test.

    The engine kwargs combine ``common_llm_kwargs``,
    ``per_test_common_llm_kwargs`` and ``test_llm_kwargs``.
    """
    return create_llm_generator(
        baseline_or_test="test",
        request=request,
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=test_llm_kwargs,
        seed=seed,
    )


def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                         per_test_common_llm_kwargs, distinct_llm_kwargs,
                         seed):
    """Build a factory that creates, yields and tears down one LLM per call.

    Args:
        baseline_or_test: "baseline" or "test". Used in log output and to
            decide whether to zero the stat loggers' logging interval.
        request: The pytest ``request`` fixture; its node name is logged.
        common_llm_kwargs, per_test_common_llm_kwargs, distinct_llm_kwargs:
            Engine kwargs, merged in that order (later dicts override earlier
            ones). An optional ``use_async`` key selects ``AsyncLLM`` instead
            of ``LLM`` and is not forwarded to the constructor.
        seed: Passed to ``set_random_seed`` after the LLM is constructed.

    Returns:
        A zero-argument generator function. Each call waits on GPU memory,
        builds one LLM, yields it, then deletes it and calls ``cleanup()``.
        The function carries a ``same_draft_target_model`` attribute that is
        True when ``speculative_model`` is set and equals ``model``.
    """
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }
    test_name = request.node.name

    # Extract the wrapper selector once, here. Popping it inside
    # generator_inner mutated this closed-over dict, so a second call of the
    # returned factory silently fell back to the synchronous LLM.
    use_async = kwargs.pop("use_async", False)

    model = kwargs["model"]
    draft_model = kwargs.get("speculative_model", None)
    same_draft_target_model = (draft_model is not None
                               and draft_model == model)

    def generator_inner():
        # Wait for GPU memory from earlier LLMs to drop below 2 GiB
        # (per the threshold_bytes argument) before building a new one.
        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        print(f'{use_async=}')

        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)

        # Override logging interval to 0 for spec decode test run to
        # log all metrics in time.
        if (baseline_or_test == "test" and not use_async
                and llm.llm_engine.log_stats):
            for stat_logger in llm.llm_engine.stat_loggers.values():
                stat_logger.local_interval = 0

        set_random_seed(seed)

        try:
            yield llm
        finally:
            # Tear down even if the consumer raised or abandoned the
            # generator, so the next LLM does not wait on leaked memory.
            del llm
            cleanup()

    def generator_outer():
        for llm in generator_inner():
            yield llm
            del llm

    # Set an attribute to the generator_outer function to allow us to
    # determine whether to further check the acceptance rate in tests.
    generator_outer.same_draft_target_model = same_draft_target_model  # type: ignore
    return generator_outer


def maybe_assert_ngram_worker(llm):
    """Assert that an n-gram speculative config is served by ``NGramWorker``.

    This is a no-op in three cases: ``llm`` is an ``AsyncLLM``, the engine
    has no speculative config, or ``ngram_prompt_lookup_max`` is not
    positive.
    """
    if isinstance(llm, AsyncLLM):
        return
    spec_config = llm.llm_engine.speculative_config
    if spec_config is None or spec_config.ngram_prompt_lookup_max <= 0:
        return

    from vllm.spec_decode.ngram_worker import NGramWorker
    proposer = llm.llm_engine.model_executor.driver_worker.proposer_worker
    assert isinstance(proposer, NGramWorker)


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    """Generate with each LLM the factory yields and return the last results.

    Returns:
        ``(tokens, token_ids, acceptance_rate)``:
        - ``tokens`` holds the text of each prompt's first completion.
        - ``token_ids`` holds the matching token ids.
        - ``acceptance_rate`` is read from the prometheus draft-acceptance
          gauge. It stays -1.0 when the engine has no (or empty)
          ``stat_loggers``.
    """
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0

    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        first_completions = [output.outputs[0] for output in outputs]
        token_ids = [completion.token_ids for completion in first_completions]
        tokens = [completion.text for completion in first_completions]

        # Fetch acceptance rate if logging is enabled.
        stat_loggers = getattr(llm.llm_engine, "stat_loggers", None)
        if stat_loggers:
            prometheus_logger = stat_loggers["prometheus"]
            gauge = (prometheus_logger.metrics.
                     gauge_spec_decode_draft_acceptance_rate)
            acceptance_rate = gauge.labels(
                **prometheus_logger.labels)._value.get()
        del llm

    return tokens, token_ids, acceptance_rate


def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Returns a dict of (token_id: Logprob) for each generated position, for
    each sequence in the batch.

    Only the first completion of each prompt is used. If the factory yields
    several LLMs, the result of the last one is returned. If it yields none,
    an empty list is returned.

    NOTE(review): assumes ``sampling_params`` requests logprobs. If a
    completion's ``logprobs`` is None, the slice below raises TypeError.
    """
    # Initialise up front so an empty factory returns [] rather than
    # raising UnboundLocalError at the return statement.
    logprobs: List[List[Dict[int, Logprob]]] = []
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        # [:] takes a shallow copy of each completion's logprobs sequence.
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        del llm

    return logprobs


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False,
                                         ensure_all_accepted: bool = False):
    """Assert the test LLM reproduces the baseline LLM's greedy output.

    ``batch_size`` prompts are drawn round-robin from a fixed list. Both LLMs
    generate with temperature 0, and their token ids must match exactly for
    every prompt.

    Args:
        baseline_llm_generator: Factory yielding the reference LLM.
        test_llm_generator: Factory yielding the LLM under test. It runs
            first.
        batch_size: Number of prompts to generate for.
        max_output_len: ``max_tokens`` used for both runs.
        force_output_len: If True, EOS is ignored so generation is not cut
            short by it.
        print_tokens: Also print the decoded text of every output.
        ensure_all_accepted: Additionally require the test run's reported
            acceptance rate to equal 1.0.
    """
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Repeat the fixed prompts round-robin until batch_size is reached.
    prompts = [base_prompts[i % len(base_prompts)] for i in range(batch_size)]

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        # Ignore EOS when the test needs max_output_len tokens generated.
        ignore_eos=force_output_len,
        temperature=0.0,
    )

    (spec_batch_tokens, spec_batch_token_ids,
     acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    (baseline_batch_tokens, baseline_batch_token_ids,
     _) = get_output_from_llm_generator(baseline_llm_generator, prompts,
                                        sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)

    rows = zip(baseline_batch_token_ids, baseline_batch_tokens,
               spec_batch_token_ids, spec_batch_tokens)
    for i, row in enumerate(rows):
        baseline_token_ids, baseline_tokens, spec_token_ids, spec_tokens = row
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=}     {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=}     {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids

    if ensure_all_accepted:
        assert acceptance_rate == 1.0