conftest.py 20.2 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
import sys
5
6
7
8
from collections import UserList
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
9
10
from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
                    TypeVar)
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13

import pytest
import torch
14
import torch.nn as nn
15
import torch.nn.functional as F
16
from PIL import Image
17
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
18
                          AutoTokenizer, BatchEncoding)
Woosuk Kwon's avatar
Woosuk Kwon committed
19
20

from vllm import LLM, SamplingParams
21
from vllm.config import TokenizerPoolConfig
22
23
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
24
from vllm.inputs import TextPrompt
25
from vllm.logger import init_logger
26
from vllm.multimodal.utils import fetch_image
27
28
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
29

30
# Module-level logger created through vLLM's logging helper.
logger = init_logger(__name__)

# Directory that contains this conftest; test assets are resolved under it.
_TEST_DIR = os.path.dirname(__file__)
# Prompt files loaded by the `example_prompts` fixture.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
# Prompt files loaded by the `example_long_prompts` fixture.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Directory from which `ImageAsset` opens local `<name>.jpg` files.
_IMAGE_DIR = Path(_TEST_DIR) / "images"
"""You can use `.buildkite/download-images.sh` to download the assets."""
38

39

40
def _read_prompts(filename: str) -> List[str]:
41
    with open(filename, "r") as f:
42
43
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
44
45


46
47
@dataclass(frozen=True)
class ImageAsset:
    """A named test image whose PIL object is loaded on first access."""

    name: Literal["stop_sign", "cherry_blossom", "boardwalk"]

    @cached_property
    def pil_image(self) -> Image.Image:
        """Return the image for this asset.

        ``boardwalk`` is retrieved through ``fetch_image`` from its URL;
        any other asset is opened from ``<name>.jpg`` inside ``_IMAGE_DIR``.
        """
        if self.name != "boardwalk":
            local_path = _IMAGE_DIR / f"{self.name}.jpg"
            return Image.open(local_path)

        remote_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
        return fetch_image(remote_url)
58
59
60
61
62


class _ImageAssetPrompts(TypedDict):
    """Prompt text keyed by image asset name, one entry per test image."""
    stop_sign: str
    cherry_blossom: str
    boardwalk: str


# Choose the base class of `_ImageAssets` according to the running
# interpreter: the subscripted form is only used on Python 3.9 and newer.
if sys.version_info < (3, 9):
    # UserList cannot be subscripted
    class _ImageAssetsBase(UserList):
        pass
else:

    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
74

75
76

class _ImageAssets(_ImageAssetsBase):
    """List-like collection of the image assets used by the tests."""

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
            ImageAsset("boardwalk")
        ])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        # Derive the order from the stored assets instead of repeating the
        # asset names, so the result cannot drift from the iteration order.
        return [prompts[asset.name] for asset in self]
96
97
98
99
100
101


# Built once at import time; the `image_assets` fixture returns this object.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""


102
103
def cleanup():
    """Tear down distributed state and release cached memory.

    Calls vLLM's model-parallel and distributed-environment destructors,
    destroys the torch process group, runs the garbage collector and, when
    not running on CPU, empties the CUDA cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # An AssertionError raised while destroying the process group is ignored.
    # NOTE(review): presumably raised when no group was initialized - confirm.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    # Only touch the CUDA allocator on non-CPU runs.
    if not is_cpu():
        torch.cuda.empty_cache()
110
111


112
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.

    Returns False when the requesting test carries the
    ``skip_global_cleanup`` marker, True otherwise.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


125
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run ``cleanup()`` after each test unless the test opted out."""
    yield
    # Teardown phase: nothing to do when global cleanup is disabled.
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
130
131


Woosuk Kwon's avatar
Woosuk Kwon committed
132
133
@pytest.fixture
def example_prompts() -> List[str]:
    """Return every line of the files listed in ``_TEST_PROMPTS``."""
    return [
        line for prompt_file in _TEST_PROMPTS
        for line in _read_prompts(prompt_file)
    ]


@pytest.fixture
def example_long_prompts() -> List[str]:
    """Return every line of the files listed in ``_LONG_PROMPTS``."""
    return [
        line for prompt_file in _LONG_PROMPTS
        for line in _read_prompts(prompt_file)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
146
147


148
149
150
151
152
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped fixture returning the module-level ``IMAGE_ASSETS``."""
    return IMAGE_ASSETS


Woosuk Kwon's avatar
Woosuk Kwon committed
153
154
155
156
157
158
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}

159
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
160

Woosuk Kwon's avatar
Woosuk Kwon committed
161
162
163

class HfRunner:

164
    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to ``"cpu"`` when ``is_cpu()`` holds, else to
        ``"cuda"``, and return the result of ``input.to(...)``."""
        target_device = "cpu" if is_cpu() else "cuda"
        return input.to(target_device)

Woosuk Kwon's avatar
Woosuk Kwon committed
170
171
172
173
    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_sparseml_model: bool = False,
    ) -> None:
        """Load the reference model, its tokenizer and its processor.

        Args:
            model_name: Model identifier handed to every ``from_pretrained``
                call (or to ``SentenceTransformer``).
            dtype: Key of ``_STR_DTYPE_TO_TORCH_DTYPE`` selecting the torch
                dtype.
            model_kwargs: Extra keyword arguments for the model's
                ``from_pretrained``; unused for embedding models.
            is_embedding_model: Load through ``SentenceTransformer``.
            is_vision_model: Load through ``AutoModelForVision2Seq``.
            is_sparseml_model: Load through ``SparseAutoModelForCausalLM``.
        """
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            embedding_model = SentenceTransformer(model_name, device="cpu")
            self.model = self.wrap_device(
                embedding_model.to(dtype=torch_dtype))
        else:
            # Default loader; replaced below for the special model kinds.
            auto_cls = AutoModelForCausalLM
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_sparseml_model:
                from sparseml.transformers import SparseAutoModelForCausalLM
                auto_cls = SparseAutoModelForCausalLM

            extra_model_kwargs = {} if model_kwargs is None else model_kwargs
            loaded_model = auto_cls.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                **extra_model_kwargs,
            )
            self.model = self.wrap_device(loaded_model)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401

            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            # Best effort: any failure falls back to the tokenizer.
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer
Woosuk Kwon's avatar
Woosuk Kwon committed
231
232
233
234

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with the HF model, one prompt at a time.

        Args:
            prompts: Text prompts to run.
            images: Optional images, one entry per prompt; an entry may be
                ``None`` to run that prompt without an image.
            **kwargs: Extra keyword arguments forwarded to
                ``self.model.generate``.

        Returns:
            One ``(output_ids, output_strs)`` tuple per prompt:
            ``output_ids`` is the id tensor returned by the model converted
            to nested lists, ``output_strs`` its ``batch_decode`` result.
        """
        # Test against None rather than truthiness: an empty ``images`` list
        # paired with non-empty ``prompts`` must fail here instead of
        # surfacing later as an IndexError inside the loop.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy generation returning one sequence per prompt.

        Delegates to ``generate`` with ``do_sample=False`` and
        ``max_new_tokens=max_tokens``, then keeps the first id list and the
        first string of every result.
        """
        per_prompt = self.generate(prompts,
                                   do_sample=False,
                                   max_new_tokens=max_tokens,
                                   images=images,
                                   **kwargs)

        first_sequences: List[Tuple[List[int], str]] = []
        for ids_per_seq, strs_per_seq in per_prompt:
            first_sequences.append((ids_per_seq[0], strs_per_seq[0]))
        return first_sequences
281
282
283
284
285
286

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
287
    ) -> List[Tuple[List[List[int]], List[str]]]:
288
289
290
291
292
293
294
295
296
297
298
299
300
301
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
302

303
304
305
306
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy generation returning log-probabilities per step.

        For every entry of ``output.hidden_states`` the last layer's hidden
        states of the first sequence are projected with the model's output
        embeddings (plus bias, when there is one) and passed through a
        float32 ``log_softmax``.

        Args:
            prompts: Text prompts to run.
            max_tokens: Passed to the model as ``max_new_tokens``.
            images: Optional images, one entry per prompt; an entry may be
                ``None``.
            **kwargs: Extra keyword arguments forwarded to
                ``self.model.generate``.

        Returns:
            For each prompt, the list of log-probability tensors, one per
            entry of ``output.hidden_states``.
        """
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # Fetch the output head once per prompt. `getattr` tolerates
            # heads that have no `bias` attribute at all, matching
            # `generate_greedy_logprobs_limit`.
            output_embeddings = self.model.get_output_embeddings()
            output_bias = getattr(output_embeddings, "bias", None)

            seq_logprobs: List[torch.Tensor] = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    output_embeddings.weight.t(),
                )
                if output_bias is not None:
                    logits += output_bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

345
346
347
348
349
    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy generation returning the top log-probabilities per step.

        Args:
            prompts: Text prompts to run.
            max_tokens: Passed to the model as ``max_new_tokens``.
            num_logprobs: Number of top entries kept for each step.
            images: Optional images, one entry per prompt; an entry may be
                ``None``.
            **kwargs: Extra keyword arguments forwarded to
                ``self.model.generate``.

        Returns:
            One ``(output_ids, output_str, logprobs)`` tuple per prompt.
            ``output_ids`` are the ids of ``output.sequences[0]`` that follow
            the first ``input_ids.shape[1]`` positions, ``output_str`` is
            their ``tokenizer.decode`` result, and ``logprobs`` holds one
            ``{token_id: logprob}`` dict per entry of
            ``output.hidden_states``.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)
            input_ids = inputs.input_ids

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # Fetch the output head once per prompt instead of per step.
            output_embeddings = self.model.get_output_embeddings()
            output_bias = getattr(output_embeddings, "bias", None)

            seq_logprobs: List[torch.Tensor] = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    output_embeddings.weight.t(),
                )
                if output_bias is not None:
                    logits += output_bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)

            # convert to dict
            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                # drop prompt logprobs
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)

                seq_logprobs_lst.append({
                    token_id.item(): logprob.item()
                    for token_id, logprob in zip(topk.indices[0],
                                                 topk.values[0])
                })

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Slice from the prompt length forward. A negative-offset slice
            # (`seq_ids[-output_len:]`) would return the WHOLE sequence when
            # nothing was generated, because `-0` is `0`.
            prompt_len = input_ids.shape[1]
            output_ids = seq_ids[prompt_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

417
418
419
    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)`` unchanged.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_embedding_model=True`` - confirm against callers.
        """
        return self.model.encode(prompts)

420
421
422
423
    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the bound value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference and run the module-level ``cleanup()``.

        Returns ``None``, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
427

Cyrus Leung's avatar
Cyrus Leung committed
428
@pytest.fixture(scope="session")
def hf_runner():
    """Session-scoped fixture returning the ``HfRunner`` class itself
    (not an instance); tests construct their own runner from it."""
    return HfRunner


class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: bool = False,
        **kwargs,
    ) -> None:
        """Create the ``LLM`` instance under test and store it on
        ``self.model``.

        ``model_name`` and ``tokenizer_name`` are passed as ``model`` and
        ``tokenizer``; the remaining named parameters are forwarded under
        their own names, ``trust_remote_code`` is always ``True`` and
        ``**kwargs`` is passed through untouched.
        """
        engine_kwargs = dict(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        self.model = LLM(**engine_kwargs, **kwargs)

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with vLLM and return prompt-prefixed samples.

        Each prompt becomes a ``TextPrompt``; when ``images`` is given, the
        i-th image is attached as ``multi_modal_data``. For every request
        the result holds, per sample, the prompt token ids followed by the
        sample's token ids, and the prompt text followed by the sample text.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=text) for text in prompts]
        if images is not None:
            for idx, image in enumerate(images):
                inputs[idx]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        results: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            samples = req_output.outputs
            sample_ids = [
                prompt_ids + list(sample.token_ids) for sample in samples
            ]
            sample_strs = [prompt_str + sample.text for sample in samples]
            results.append((sample_ids, sample_strs))
        return results

497
498
499
500
    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate with vLLM and return ids, text and logprobs per request.

        Args:
            prompts: Text prompts to run.
            sampling_params: Sampling parameters; ``logprobs`` must be set.
            images: Optional images, one per prompt, attached as
                ``multi_modal_data``.

        Returns:
            One ``(token_ids, text, logprobs)`` tuple per request, taken
            from the LAST sample of that request's outputs.

        Raises:
            IndexError: If a request carries no samples.
        """
        assert sampling_params.logprobs is not None

        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            # Select the last sample explicitly. Relying on loop variables
            # that outlive an inner `for` would silently reuse the previous
            # request's values when a request has no samples.
            sample = req_output.outputs[-1]
            outputs.append((sample.token_ids, sample.text, sample.logprobs))
        return outputs

Woosuk Kwon's avatar
Woosuk Kwon committed
524
525
526
527
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy generation (``temperature=0.0``) returning, for each
        prompt, the first id list and first string from ``generate``."""
        params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        per_request = self.generate(prompts, params, images=images)

        first_samples: List[Tuple[List[int], str]] = []
        for sample_ids, sample_strs in per_request:
            first_samples.append((sample_ids[0], sample_strs[0]))
        return first_samples
534

535
536
537
538
539
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy generation (``temperature=0.0``) that also requests
        ``num_logprobs`` log-probabilities, via ``generate_w_logprobs``."""
        params = SamplingParams(temperature=0.0,
                                max_tokens=max_tokens,
                                logprobs=num_logprobs)
        per_request = self.generate_w_logprobs(prompts, params, images=images)

        results = []
        for token_ids, text, sample_logprobs in per_request:
            results.append((token_ids, text, sample_logprobs))
        return results

552
553
554
555
556
    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``generate`` with beam-search sampling parameters
        (``n=beam_width``, ``use_beam_search=True``, ``temperature=0.0``)."""
        params = SamplingParams(n=beam_width,
                                use_beam_search=True,
                                temperature=0.0,
                                max_tokens=max_tokens)
        return self.generate(prompts, params)
Woosuk Kwon's avatar
Woosuk Kwon committed
564

565
566
567
568
569
570
571
572
    def encode(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.encode(prompts)
        outputs = []
        for req_output in req_outputs:
            embedding = req_output.outputs.embedding
            outputs.append(embedding)
        return outputs

573
574
575
576
    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the bound value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the ``LLM`` reference and run the module-level ``cleanup()``.

        Returns ``None``, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
580

581
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture returning the ``VllmRunner`` class itself
    (not an instance); tests construct their own runner from it."""
    return VllmRunner
584
585
586
587
588
589
590
591
592
593


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build the tokenizer pool configuration for a tokenizer group type.

    Returns ``None`` for ``None``, a single-worker ``"ray"`` pool
    configuration for ``"ray"``, and raises ``ValueError`` for anything
    else.
    """
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type != "ray":
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    pool_settings = {"pool_size": 1, "pool_type": "ray", "extra_config": {}}
    return TokenizerPoolConfig(**pool_settings)
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609


@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the ``"vllm"`` logger for one test.

    Sets ``propagate = True`` before the test and, on teardown, restores
    whatever value the logger had beforehand rather than forcing ``False``.
    """
    import logging

    # Named `vllm_logger` so it does not shadow the module-level `logger`.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` while vLLM log propagation is enabled by
    the ``temporary_enable_log_propagate`` fixture."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
610
611
612
613
614
615
616


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""

    # Delegates to vLLM's `cuda_device_count_stateless` helper.
    return cuda_device_count_stateless()