conftest.py 19.5 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
5
6
7
from collections import UserList
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
8
9
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Literal,
                    Optional, Tuple, TypedDict, TypeVar)
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
12

import pytest
import torch
13
import torch.nn as nn
14
import torch.nn.functional as F
15
from PIL import Image
16
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
17
                          AutoTokenizer, BatchEncoding)
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19

from vllm import LLM, SamplingParams
20
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
21
22
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
23
from vllm.inputs import TextPrompt
24
from vllm.logger import init_logger
25
26
27
28
29
30

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalData
else:
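    # Importing vllm.multimodal at runtime would call torch.cuda.device_count(),
    # so it is imported only for type checking; at runtime this is a placeholder.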
    MultiModalData = None
31
from vllm.sequence import SampleLogprobs
32
from vllm.utils import cuda_device_count_stateless, is_cpu
33
34

# Module-level logger; used e.g. to warn when an HF processor cannot be loaded.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
35

36
37
38
# Directory containing this conftest; test assets are resolved relative to it.
_TEST_DIR = os.path.dirname(__file__)

# Prompt files consumed by the ``example_prompts`` and
# ``example_long_prompts`` fixtures; each line is one prompt.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Directory holding the test images and their precomputed tensors.
# You can use `.buildkite/download-images.sh` to download the assets.
_IMAGE_DIR = Path(_TEST_DIR).joinpath("images")
42

43

44
def _read_prompts(filename: str) -> List[str]:
    """Read ``filename`` and return its lines, one prompt per line.

    Lines are returned exactly as ``readlines`` yields them, i.e. each one
    keeps its trailing newline.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
48
49


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@dataclass(frozen=True)
class ImageAsset:
    """A named test image together with its precomputed tensors.

    Each representation is loaded lazily from ``_IMAGE_DIR`` on first access
    and cached on the instance (``cached_property``).
    """

    name: Literal["stop_sign", "cherry_blossom"]

    @cached_property
    def pixel_values(self) -> torch.Tensor:
        """Tensor loaded from ``<name>_pixel_values.pt``."""
        return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")

    @cached_property
    def image_features(self) -> torch.Tensor:
        """Tensor loaded from ``<name>_image_features.pt``."""
        return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")

    @cached_property
    def pil_image(self) -> Image.Image:
        """The image opened from ``<name>.jpg``."""
        return Image.open(_IMAGE_DIR / f"{self.name}.jpg")

    def for_hf(self) -> Image.Image:
        """Return the PIL image (the form passed to the HF runner)."""
        return self.pil_image

    def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
        """Wrap this image as vLLM multi-modal data for ``vision_config``.

        Returns ``ImageFeatureData`` for ``IMAGE_FEATURES`` input and
        ``ImagePixelData`` for ``PIXEL_VALUES`` input.

        Raises:
            NotImplementedError: for any other ``image_input_type``.
        """
        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from vllm.multimodal.image import ImageFeatureData, ImagePixelData

        image_input_type = vision_config.image_input_type
        ImageInputType = VisionLanguageConfig.ImageInputType

        if image_input_type == ImageInputType.IMAGE_FEATURES:
            return ImageFeatureData(self.image_features)
        if image_input_type == ImageInputType.PIXEL_VALUES:
            return ImagePixelData(self.pil_image)

        # Name the offending input type so failures are diagnosable.
        raise NotImplementedError(
            f"Unsupported image_input_type {image_input_type!r} "
            f"for image asset {self.name!r}")


# Mapping from each test image name to the prompt used with it; consumed by
# ``_ImageAssets.prompts``.
_ImageAssetPrompts = TypedDict("_ImageAssetPrompts", {
    "stop_sign": str,
    "cherry_blossom": str,
})


class _ImageAssets(UserList[ImageAsset]):
    """Ordered, list-like collection of the test :class:`ImageAsset` objects."""

    def __init__(self,
                 initlist: Optional[Iterable[ImageAsset]] = None) -> None:
        # ``UserList`` builds new instances via ``self.__class__(data)`` for
        # slicing, ``copy()``, ``+`` and ``*``, so the constructor must accept
        # an optional initial list. Without an argument it holds the default
        # test assets, as before.
        if initlist is None:
            initlist = [ImageAsset("stop_sign"), ImageAsset("cherry_blossom")]
        super().__init__(initlist)

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        # Look each prompt up by asset name so the result follows this
        # object's actual contents and order (e.g. for sliced copies).
        return [prompts[asset.name] for asset in self]


# Shared module-level collection of the test images; the ``image_assets``
# fixture returns this same object.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""


111
112
def cleanup():
    """Tear down vLLM's distributed state and release cached memory.

    An ``AssertionError`` from ``torch.distributed.destroy_process_group``
    is ignored (presumably raised when no default process group exists).
    The CUDA cache is only emptied when not running on CPU.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
119
120


121
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether :func:`cleanup` runs after the current test.

    Tests opt out with the ``skip_global_cleanup`` marker, and
    subdirectories can override this fixture entirely. Skipping the global
    cleanup can give a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


134
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse fixture: after each test, call :func:`cleanup` unless
    ``should_do_global_cleanup_after_test`` is false."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
139
140


Woosuk Kwon's avatar
Woosuk Kwon committed
141
142
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of every file in ``_TEST_PROMPTS``, in file order."""
    return [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]


@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of every file in ``_LONG_PROMPTS``, in file order."""
    return [
        prompt for filename in _LONG_PROMPTS
        for prompt in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
155
156


157
158
159
160
161
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped fixture returning the shared ``IMAGE_ASSETS`` object."""
    return IMAGE_ASSETS


Woosuk Kwon's avatar
Woosuk Kwon committed
162
163
164
165
166
167
# dtype names accepted by ``HfRunner`` mapped to the torch dtype used when
# loading the model.
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    bfloat16=torch.bfloat16,
    float=torch.float,
)

168
# Types accepted by ``HfRunner.wrap_device``; each is moved with ``.to(...)``.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
169

Woosuk Kwon's avatar
Woosuk Kwon committed
170
171
172

class HfRunner:
    """Test helper that runs a HuggingFace model for comparison with vLLM.

    Depending on the constructor flags the model is loaded as a
    ``SentenceTransformer`` (embedding models), via
    ``AutoModelForVision2Seq`` (vision models), via SparseML's
    ``SparseAutoModelForCausalLM``, or via ``AutoModelForCausalLM``.
    Use it as a context manager: ``__exit__`` drops the model and calls
    :func:`cleanup`.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to CUDA, or to CPU when ``is_cpu()`` is true."""
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_sparseml_model: bool = False,
    ) -> None:
        """Load the model, its tokenizer and, if available, its processor.

        Args:
            model_name: HF model name or path.
            dtype: A key of ``_STR_DTYPE_TO_TORCH_DTYPE``.
            model_kwargs: Extra keyword arguments for ``from_pretrained``
                (not used for embedding models).
            is_embedding_model: Load with ``sentence_transformers``.
            is_vision_model: Load with ``AutoModelForVision2Seq``.
            is_sparseml_model: Load with ``SparseAutoModelForCausalLM``.
        """
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_sparseml_model:
                from sparseml.transformers import SparseAutoModelForCausalLM
                auto_cls = SparseAutoModelForCausalLM
            else:
                auto_cls = AutoModelForCausalLM

            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            # Best effort: fall back to the tokenizer when no processor can
            # be loaded (presumably text-only models).
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``model.generate`` on each prompt, optionally with an image.

        ``kwargs`` are forwarded to ``model.generate``. For each prompt,
        returns the token id sequences produced by ``model.generate`` and
        their decoded strings (special tokens skipped).
        """
        # Compare with None rather than testing truthiness: an empty
        # ``images`` list with non-empty ``prompts`` must fail this check
        # instead of raising IndexError inside the loop below.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode up to ``max_tokens`` new tokens per prompt.

        Returns the first (only) sequence's ids and string for each prompt.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning all ``beam_width`` beams per prompt.

        Pad tokens are removed from every returned id sequence.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        # Strip pad tokens (presumably added by HF to align beams of
        # different lengths).
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in ids if tok != pad_token_id]
                  for ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def _generate_greedy_with_hidden_states(self, prompt: str,
                                            max_tokens: int):
        """Greedy-generate from one prompt, keeping per-step hidden states.

        Returns ``(input_ids, output)``: the tokenized prompt (on CPU) and
        the dict-style result of ``model.generate``.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        output = self.model.generate(
            self.wrap_device(input_ids),
            use_cache=True,
            do_sample=False,
            max_new_tokens=max_tokens,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )
        return input_ids, output

    def _hidden_states_to_logprobs(self, hidden_states) -> List[torch.Tensor]:
        """Project each step's hidden states to float32 log-probabilities.

        Uses the model's output embedding weight (plus its bias, when it has
        one) followed by ``log_softmax`` over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()
        # Not every output embedding module has a ``bias`` attribute.
        bias = getattr(output_embeddings, "bias", None)
        seq_logprobs: List[torch.Tensor] = []
        for step_hidden_states in hidden_states:
            # [-1]: last entry (presumably the final layer);
            # [0]: the single batch element (one prompt per call).
            last_hidden_states = step_hidden_states[-1][0]
            logits = torch.matmul(
                last_hidden_states,
                output_embeddings.weight.t(),
            )
            if bias is not None:
                logits += bias.unsqueeze(0)
            seq_logprobs.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32))
        return seq_logprobs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode each prompt and return per-step logprob tensors.

        The first step's tensor covers every prompt position; later steps
        cover one position each.
        """
        all_logprobs = []
        for prompt in prompts:
            _, output = self._generate_greedy_with_hidden_states(
                prompt, max_tokens)
            all_logprobs.append(
                self._hidden_states_to_logprobs(output.hidden_states))
        return all_logprobs

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy-decode each prompt, keeping the top ``num_logprobs``
        logprobs per generated token.

        Returns, per prompt, the generated ids (prompt excluded), their
        decoded string, and one ``{token_id: logprob}`` dict per step.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for prompt in prompts:
            input_ids, output = self._generate_greedy_with_hidden_states(
                prompt, max_tokens)
            seq_logprobs = self._hidden_states_to_logprobs(
                output.hidden_states)

            # convert to dict
            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                # drop prompt logprobs: keep only the last prompt position,
                # which scores the first generated token
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)
                seq_logprobs_lst.append({
                    token_id.item(): logprob.item()
                    for token_id, logprob in zip(topk.indices[0],
                                                 topk.values[0])
                })
            all_logprobs.append(seq_logprobs_lst)

            # Slice from the prompt length instead of using a negative
            # offset: ``seq_ids[-0:]`` would return the whole sequence,
            # prompt included, when no new tokens were produced.
            seq_ids = output.sequences[0]
            output_ids = seq_ids[input_ids.shape[1]:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` with the model's ``encode`` (presumably only
        meaningful for ``is_embedding_model=True``)."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
413

Cyrus Leung's avatar
Cyrus Leung committed
414
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the :class:`HfRunner` class itself (not an instance); tests
    construct it with their own model name and options."""
    return HfRunner


class VllmRunner:
    """Test helper wrapping :class:`vllm.LLM`.

    Its generation methods return results shaped like the corresponding
    :class:`HfRunner` methods so the two can be compared. Use it as a
    context manager: ``__exit__`` drops the engine and calls
    :func:`cleanup`.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: bool = False,
        **kwargs,
    ) -> None:
        """Create the ``LLM`` engine; extra ``kwargs`` are forwarded to it."""
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt, optionally attaching multi-modal data.

        Returns, per prompt, one id list and one string per sample, each
        with the prompt's ids / text prepended.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = image

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate and return ``(token_ids, text, logprobs)`` per prompt.

        ``sampling_params.logprobs`` must be set. One tuple is returned per
        prompt, taken from that request's last sample (the only one when
        ``n=1``).
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            # Index the sample explicitly. Overwriting loop variables inside
            # a sample loop would silently reuse the previous request's values
            # if a request ever came back with no samples.
            sample = req_output.outputs[-1]
            outputs.append((sample.token_ids, sample.text, sample.logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode (temperature 0) up to ``max_tokens`` per prompt."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy-decode, returning the top ``num_logprobs`` logprobs."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_w_logprobs(prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search with ``beam_width`` beams, returning every beam."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding produced by ``LLM.encode`` for each prompt."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
554

555
@pytest.fixture(scope="session")
def vllm_runner():
    """Provide the :class:`VllmRunner` class itself (not an instance); tests
    construct it with their own model name and options."""
    return VllmRunner
558
559
560
561
562
563
564
565
566
567


def get_tokenizer_pool_config(tokenizer_group_type):
    """Map a tokenizer-group type name to a :class:`TokenizerPoolConfig`.

    ``None`` yields ``None`` (no pool); ``"ray"`` yields a single-worker Ray
    pool. Any other value raises ``ValueError``.
    """
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type != "ray":
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={})
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583


@pytest.fixture()
def temporary_enable_log_propagate():
    """Let the ``vllm`` logger propagate to the root logger during a test.

    On teardown the previous ``propagate`` value is restored rather than
    unconditionally forced to ``False``.
    """
    import logging
    # Local name avoids shadowing this module's ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """pytest's ``caplog`` with the ``vllm`` logger's records included."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
584
585
586
587
588
589
590


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of GPUs, counted with ``cuda_device_count_stateless`` so that
    no CUDA context is initialized in the current process."""
    return cuda_device_count_stateless()