conftest.py 17.4 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
from typing import Any, Dict, List, Optional, Tuple, TypeVar
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7

import pytest
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
from PIL import Image
11
12
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
                          AutoProcessor, AutoTokenizer, BatchEncoding)
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14

from vllm import LLM, SamplingParams
15
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
16
from vllm.distributed import destroy_model_parallel
17
from vllm.inputs import TextPrompt
18
from vllm.logger import init_logger
19
20
21
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs
22
from vllm.utils import is_cpu
23
24

# Module-level logger for this conftest (HfRunner uses it to warn when it
# falls back from AutoProcessor to the tokenizer).
logger = init_logger(__name__)

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi-modal test assets, all stored under `<tests>/images`.
# Use `.buildkite/download-images.sh` to download them.
# The three lists below enumerate the images in the same order
# (stop sign first, cherry blossom second).
PIXEL_VALUES_FILES = [
    os.path.join(_TEST_DIR, "images", name)
    for name in ("stop_sign_pixel_values.pt",
                 "cherry_blossom_pixel_values.pt")
]
IMAGE_FEATURES_FILES = [
    os.path.join(_TEST_DIR, "images", name)
    for name in ("stop_sign_image_features.pt",
                 "cherry_blossom_image_features.pt")
]
IMAGE_FILES = [
    os.path.join(_TEST_DIR, "images", name)
    for name in ("stop_sign.jpg", "cherry_blossom.jpg")
]
# Keep the three asset lists aligned one-to-one.
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)


def _read_prompts(filename: str) -> List[str]:
48
    with open(filename, "r") as f:
49
50
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
51
52


53
54
55
56
57
def cleanup():
    """Release distributed and device state left behind by a test.

    Destroys vLLM's model-parallel state and the torch.distributed process
    group, runs the garbage collector, and empties the CUDA cache unless
    running on CPU.
    """
    destroy_model_parallel()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # Deliberately ignored; presumably raised when no default process
        # group was ever initialized — TODO confirm for the pinned torch.
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the autouse cleanup runs after the current test.

    Tests carrying the ``skip_global_cleanup`` marker opt out. Subdirectories
    can also override this fixture to skip global cleanup; this can provide a
    ~10x speedup for non-GPU unit tests since they don't need to initialize
    torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker

@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run ``cleanup()`` after every test unless the test opted out."""
    yield
    # Teardown phase: everything below runs after the test body finishes.
    if not should_do_global_cleanup_after_test:
        return
    cleanup()


@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
    """Session-wide PIL images, in ``IMAGE_FILES`` order, for HF runners."""
    opened = list(map(Image.open, IMAGE_FILES))
    return opened


@pytest.fixture()
def vllm_images(request) -> List[MultiModalData]:
    """Per-test multi-modal inputs for vLLM matching the model's config.

    Reads the ``model_and_config`` fixture of the requesting test and uses
    its second element as the ``VisionLanguageConfig``. When the config asks
    for image features, the tensors in ``IMAGE_FEATURES_FILES`` are loaded;
    otherwise the PIL images in ``IMAGE_FILES`` are opened.
    """
    config = request.getfixturevalue("model_and_config")[1]
    features_type = VisionLanguageConfig.ImageInputType.IMAGE_FEATURES
    if config.image_input_type == features_type:
        return [
            ImageFeatureData(torch.load(path))
            for path in IMAGE_FEATURES_FILES
        ]
    return [ImagePixelData(Image.open(path)) for path in IMAGE_FILES]

@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
    """Pixel-value tensors loaded from ``PIXEL_VALUES_FILES``, in order.

    ``request`` is accepted for fixture-signature compatibility but unused.
    """
    return list(map(torch.load, PIXEL_VALUES_FILES))


@pytest.fixture
def example_prompts() -> List[str]:
    """Every line of every file in ``_TEST_PROMPTS``, concatenated in order."""
    return [
        line for path in _TEST_PROMPTS for line in _read_prompts(path)
    ]

@pytest.fixture
def example_long_prompts() -> List[str]:
    """Every line of every file in ``_LONG_PROMPTS``, concatenated in order."""
    collected: List[str] = []
    for path in _LONG_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected


# dtype strings accepted by HfRunner, mapped to the torch dtype they load as.
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    bfloat16=torch.bfloat16,
    float=torch.float,
)

# Objects HfRunner.wrap_device can move between devices via ``.to(...)``.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)


class HfRunner:
    """Reference runner that drives a HuggingFace model directly.

    Depending on the constructor flags it wraps a ``SentenceTransformer``
    (embedding models), an ``AutoModelForVision2Seq`` (vision models) or an
    ``AutoModelForCausalLM`` (all other models). Its ``generate*`` methods
    mirror the corresponding ``VllmRunner`` methods so tests can compare the
    two.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to CUDA, or to CPU when running on a CPU build."""
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
    ) -> None:
        """Load the model, its tokenizer and (if available) its processor.

        Args:
            model_name: HuggingFace model name or local path.
            dtype: A key of ``_STR_DTYPE_TO_TORCH_DTYPE``.
            is_embedding_model: Load through ``sentence_transformers``
                instead of a ``transformers`` auto class.
            is_vision_model: Use ``AutoModelForVision2Seq`` instead of
                ``AutoModelForCausalLM``.
        """
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            else:
                auto_cls = AutoModelForCausalLM

            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            # Deliberate best effort: generate() only needs ``__call__`` and
            # ``batch_decode`` from the processor, so the tokenizer is used
            # as a stand-in when no processor can be loaded.
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``model.generate`` on each prompt, one prompt at a time.

        Args:
            prompts: Text prompts.
            images: Optional per-prompt images. When given it must have the
                same length as ``prompts``; a ``None`` entry means the
                corresponding prompt gets no image.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            One ``(output_ids, output_strs)`` pair per prompt, each holding
            one entry per sequence returned by ``model.generate``.
        """
        # Compare against None rather than relying on truthiness, so an empty
        # list paired with non-empty prompts fails here instead of raising an
        # IndexError in the loop. This matches VllmRunner.generate.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode each prompt; return its single ``(ids, text)``."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search each prompt and return all ``beam_width`` beams.

        Pad tokens are stripped from the returned ids, since beams of
        different lengths come back padded to a common length.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [
            ([[tok for tok in seq if tok != pad_token_id]
              for seq in output_ids], output_str)
            for output_ids, output_str in outputs
        ]

    def _hidden_states_to_logprobs(
            self, hidden_states_per_step) -> List[torch.Tensor]:
        """Project each step's last-layer hidden states to float32 logprobs.

        ``hidden_states_per_step`` is the ``hidden_states`` field returned by
        ``model.generate(..., output_hidden_states=True,
        return_dict_in_generate=True)``. For every step, the last entry
        (``[-1]``) of batch element 0 (``[0]``) is multiplied by the output
        embedding matrix, plus its bias when the module has one.
        """
        # Loop-invariant: the output projection does not change per step.
        output_embeddings = self.model.get_output_embeddings()
        bias = getattr(output_embeddings, "bias", None)

        seq_logprobs: List[torch.Tensor] = []
        for hidden_states in hidden_states_per_step:
            last_hidden_states = hidden_states[-1][0]
            logits = torch.matmul(
                last_hidden_states,
                output_embeddings.weight.t(),
            )
            if bias is not None:
                logits += bias.unsqueeze(0)
            seq_logprobs.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32))
        return seq_logprobs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode each prompt and return per-step logprob tensors."""
        all_logprobs: List[List[torch.Tensor]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                self.wrap_device(input_ids),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            all_logprobs.append(
                self._hidden_states_to_logprobs(output.hidden_states))
        return all_logprobs

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy-decode each prompt, keeping the top-k logprobs per step.

        Returns:
            One ``(output_ids, output_str, logprobs)`` triple per prompt.
            ``output_ids``/``output_str`` cover only the generated tokens, and
            ``logprobs`` holds one ``{token_id: logprob}`` dict of the
            ``num_logprobs`` most likely tokens per generated position.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                self.wrap_device(input_ids),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            seq_logprobs = self._hidden_states_to_logprobs(
                output.hidden_states)

            # convert to dict
            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                # drop prompt logprobs: step 0 has one row per prompt
                # position, and only its last row predicts the first
                # generated token.
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)
                seq_logprobs_lst.append({
                    token_id.item(): logprob.item()
                    for token_id, logprob in zip(topk.indices[0],
                                                 topk.values[0])
                })
            all_logprobs.append(seq_logprobs_lst)

            # Slice off the prompt by its length. The former
            # ``seq_ids[-output_len:]`` returned the whole sequence (prompt
            # included) when nothing was generated, because ``[-0:]`` == ``[:]``.
            seq_ids = output.sequences[0]
            output_ids = seq_ids[input_ids.shape[1]:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` with the wrapped (sentence-transformers) model."""
        return self.model.encode(prompts)

    def __del__(self):
        # ``model`` is absent if ``__init__`` raised before assigning it;
        # guard so finalization does not emit a spurious AttributeError.
        if hasattr(self, "model"):
            del self.model
        cleanup()

@pytest.fixture
def hf_runner():
    """Hand tests the ``HfRunner`` class itself; they instantiate it."""
    runner_cls = HfRunner
    return runner_cls

class VllmRunner:
    """Runner that drives a vLLM ``LLM`` engine with test-friendly defaults.

    Its ``generate*`` methods return data in the same shapes as the
    corresponding ``HfRunner`` methods so the two can be compared.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        **kwargs,
    ) -> None:
        """Build the engine; extra ``kwargs`` are forwarded to ``LLM``."""
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for every prompt, optionally with per-prompt image data.

        Returns:
            One ``(ids_per_sample, texts_per_sample)`` pair per request. Each
            sample's ids and text are the prompt followed by the completion.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = image

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                req_sample_output_ids.append(prompt_ids + sample.token_ids)
                req_sample_output_strs.append(prompt_str + sample.text)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate and return ``(ids, text, logprobs)`` per request.

        Only completion tokens are returned (no prompt). If a request yields
        several samples, only the last one is kept; the in-file caller
        (``generate_greedy_logprobs``) requests a single greedy sample.
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            # Index explicitly instead of reusing variables leaked from an
            # inner loop: a request without samples now fails loudly rather
            # than silently repeating the previous request's data.
            sample = req_output.outputs[-1]
            outputs.append((sample.token_ids, sample.text, sample.logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode each prompt; return its single ``(ids, text)``."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy-decode with ``num_logprobs`` top logprobs per token."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_w_logprobs(prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search each prompt, returning all ``beam_width`` beams."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return one embedding per prompt from ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __del__(self):
        # ``model`` is absent if ``__init__`` raised before assigning it;
        # guard so finalization does not emit a spurious AttributeError.
        if hasattr(self, "model"):
            del self.model
        cleanup()

@pytest.fixture(scope="session")
def vllm_runner():
    """Hand tests the ``VllmRunner`` class itself; they instantiate it."""
    runner_cls = VllmRunner
    return runner_cls


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a ``TokenizerPoolConfig`` for the given tokenizer group type.

    ``None`` means no pool (returns ``None``); ``"ray"`` yields a single-worker
    Ray pool. Any other value raises ``ValueError``.
    """
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type != "ray":
        raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={})


@pytest.fixture()
def temporary_enable_log_propagate():
    """Let the ``vllm`` logger propagate to the root logger during a test.

    The previous ``propagate`` value is saved and restored on teardown,
    rather than being forced to ``False``, so the fixture does not clobber
    whatever logging configuration was in place before the test.
    """
    import logging
    # Local name differs from the module-level ``logger`` to avoid shadowing.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """A ``caplog`` that also captures records from the ``vllm`` logger.

    ``caplog`` only sees records that reach the root logger, so this relies
    on ``temporary_enable_log_propagate`` switching propagation on.
    """
    yield caplog