conftest.py 17.1 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7

import pytest
import torch
8
from PIL import Image
9
10
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                          LlavaConfig, LlavaForConditionalGeneration)
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12

from vllm import LLM, SamplingParams
13
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
14
from vllm.distributed import destroy_model_parallel
15
from vllm.logger import init_logger
16
from vllm.sequence import MultiModalData
17
18

logger = init_logger(__name__)

# Every test asset is located relative to the directory holding this file.
_TEST_DIR = os.path.dirname(__file__)
_PROMPTS_DIR = os.path.join(_TEST_DIR, "prompts")
_TEST_PROMPTS = [os.path.join(_PROMPTS_DIR, "example.txt")]
_LONG_PROMPTS = [os.path.join(_PROMPTS_DIR, "summary.txt")]

# Multi modal related. The four lists below must have the same length (checked
# by the assert at the end); files are ordered stop sign, then cherry blossom.
_IMAGES_DIR = os.path.join(_TEST_DIR, "images")
_PIXEL_VALUES_FILES = [
    os.path.join(_IMAGES_DIR, name) for name in (
        "stop_sign_pixel_values.pt",
        "cherry_blossom_pixel_values.pt",
    )
]
_IMAGE_FEATURES_FILES = [
    os.path.join(_IMAGES_DIR, name) for name in (
        "stop_sign_image_features.pt",
        "cherry_blossom_image_features.pt",
    )
]
_IMAGE_FILES = [
    os.path.join(_IMAGES_DIR, name) for name in (
        "stop_sign.jpg",
        "cherry_blossom.jpg",
    )
]
_IMAGE_PROMPTS = [
    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
    "<image>\nUSER: What is the season?\nASSISTANT:"
]
assert len({
    len(_PIXEL_VALUES_FILES),
    len(_IMAGE_FEATURES_FILES),
    len(_IMAGE_FILES),
    len(_IMAGE_PROMPTS),
}) == 1

44

45
def _read_prompts(filename: str) -> List[str]:
    """Return the lines of ``filename``, one prompt per line.

    Lines keep their trailing newline (``readlines`` semantics). The file is
    decoded as UTF-8 explicitly so the result does not depend on the
    platform's locale encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
49
50


51
52
53
54
55
56
57
58
def cleanup():
    """Tear down model-parallel / process-group state and free memory.

    Called after tests and from the runners' ``__del__``. An
    ``AssertionError`` from ``destroy_process_group`` is ignored so teardown
    continues; garbage is then collected and the CUDA cache emptied.
    """
    destroy_model_parallel()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    gc.collect()
    torch.cuda.empty_cache()


59
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup()`` runs after the current test.

    Subdirectories can override this fixture to skip global cleanup, which
    gives roughly a 10x speedup for non-GPU unit tests because they then
    never initialize torch. An individual test can also opt out with the
    ``skip_global_cleanup`` marker.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


72
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run ``cleanup()`` after every test unless the test opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
77
78


79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
@pytest.fixture(scope="session")
def hf_image_prompts() -> List[str]:
    """Image prompts for the HuggingFace side, index-aligned with the images.

    A copy is returned so a test that mutates this list cannot alter the
    module-level ``_IMAGE_PROMPTS``, which ``vllm_image_prompts`` also reads.
    """
    return list(_IMAGE_PROMPTS)


@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
    """Load the example images as PIL images, ordered like ``_IMAGE_FILES``.

    ``Image.open`` is lazy: it keeps the file open until the pixel data is
    read. Each image is therefore decoded inside a ``with`` block and copied,
    so no file handle stays open for the rest of the session.
    """
    images = []
    for filename in _IMAGE_FILES:
        with Image.open(filename) as image:
            images.append(image.copy())
    return images


@pytest.fixture()
def vllm_images(request) -> "torch.Tensor":
    """Concatenate the pre-computed image tensors used on the vLLM side.

    The vision-language config is element 1 of the ``model_and_config``
    fixture, requested dynamically. Its ``image_input_type`` selects between
    the image-feature files and the pixel-value files; the loaded tensors
    are concatenated along dim 0.
    """
    vlm_config = request.getfixturevalue("model_and_config")[1]
    wants_features = (vlm_config.image_input_type ==
                      VisionLanguageConfig.ImageInputType.IMAGE_FEATURES)
    tensor_files = (_IMAGE_FEATURES_FILES
                    if wants_features else _PIXEL_VALUES_FILES)
    loaded = [torch.load(path) for path in tensor_files]
    return torch.concat(loaded, dim=0)


@pytest.fixture()
def vllm_image_prompts(request) -> List[str]:
    """Expand each image prompt to ``image_feature_size`` image tokens.

    Every entry of ``_IMAGE_PROMPTS`` already starts with one ``<image>``,
    so ``image_feature_size - 1`` more are prepended.
    """
    vlm_config = request.getfixturevalue("model_and_config")[1]
    padding = "<image>" * (vlm_config.image_feature_size - 1)
    return [padding + prompt for prompt in _IMAGE_PROMPTS]


Woosuk Kwon's avatar
Woosuk Kwon committed
112
113
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines from the short example prompt files, in file order."""
    return [
        line for path in _TEST_PROMPTS for line in _read_prompts(path)
    ]


@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines from the long (summary) prompt files, in file order."""
    return [
        line for path in _LONG_PROMPTS for line in _read_prompts(path)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
126
127
128
129
130
131
132
133


# ``dtype`` strings accepted by HfRunner and the torch dtypes they map to.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}

# Make AutoModelForCausalLM able to resolve LLaVA configs to the LLaVA model.
AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)

# Models HfRunner loads via sentence-transformers rather than
# AutoModelForCausalLM.
_EMBEDDING_MODELS = ["intfloat/e5-mistral-7b-instruct"]

Woosuk Kwon's avatar
Woosuk Kwon committed
140
141
142
143
144
145
146
147
148
149

class HfRunner:
    """Reference runner that executes a model through HuggingFace libraries.

    Models listed in ``_EMBEDDING_MODELS`` are loaded with
    sentence-transformers; every other model goes through
    ``AutoModelForCausalLM``. Results are shaped like ``VllmRunner``'s so
    tests can compare the two.
    """

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
    ) -> None:
        """Load the model (on CUDA), its tokenizer and, if available, a
        processor.

        Args:
            model_name: HuggingFace model id or local path.
            dtype: One of the keys of ``_STR_DTYPE_TO_TORCH_DTYPE``.
        """
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if model_name in _EMBEDDING_MODELS:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(
                model_name,
                device="cpu",
            ).to(dtype=torch_dtype).cuda()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            ).cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            # Best effort: fall back to the tokenizer, which ``generate``
            # calls the same way (``text=...``, ``return_tensors=...``).
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer

    def generate(
        self,
        prompts: List[str],
        images: Optional[List["Image.Image"]] = None,
        **kwargs,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``model.generate`` once per prompt.

        Args:
            prompts: Text prompts, processed one at a time.
            images: Optional per-prompt images aligned with ``prompts``; an
                entry may be ``None`` to run that prompt text-only.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            One ``(output_ids, output_strs)`` pair per prompt. Each element
            holds one entry per returned sequence (for decoder-only models
            these include the prompt tokens).
        """
        # ``is not None`` (not truthiness) so the check matches the per-item
        # lookup below and also works when ``images`` is a tensor.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)
            inputs = {
                key: value.cuda() if value is not None else None
                for key, value in inputs.items()
            }

            output_ids = self.model.generate(
                **inputs,
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((output_ids.cpu().tolist(), output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List["Image.Image"]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode each prompt; return its single ``(ids, text)``."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search each prompt, returning all ``beam_width`` beams.

        Pad tokens (presumably filling shorter beams up to a common length)
        are removed from every beam's ids.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in beam if tok != pad_token_id]
                  for beam in output_ids], output_str)
                for output_ids, output_str in outputs]

    @staticmethod
    def _hidden_states_to_logprobs(
        hidden_states: torch.Tensor,
        output_embeddings: torch.nn.Module,
    ) -> torch.Tensor:
        """Project hidden states through the LM head into float32 logprobs.

        ``logits = hidden_states @ W.T (+ bias)`` followed by a log-softmax
        over the last dimension.
        """
        logits = torch.matmul(hidden_states, output_embeddings.weight.t())
        # Not every output-embedding module has a ``bias`` attribute.
        bias = getattr(output_embeddings, "bias", None)
        if bias is not None:
            logits += bias.unsqueeze(0)
        return torch.nn.functional.log_softmax(logits,
                                               dim=-1,
                                               dtype=torch.float32)

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode each prompt and return full logprob tensors.

        Returns:
            Per prompt, one float32 logprob tensor per generation step,
            computed from the last layer's hidden states.
        """
        output_embeddings = self.model.get_output_embeddings()
        all_logprobs: List[List[torch.Tensor]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # Per the HF generate docs, ``hidden_states`` holds one entry per
            # step, indexed [layer][batch]; ``[-1][0]`` is the last layer for
            # the single prompt in the batch.
            all_logprobs.append([
                self._hidden_states_to_logprobs(step_states[-1][0],
                                                output_embeddings)
                for step_states in output.hidden_states
            ])
        return all_logprobs

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy-decode each prompt, keeping top-``num_logprobs`` per step.

        Returns:
            Per prompt, ``(generated_ids, generated_text, step_logprobs)``.
            ``step_logprobs`` has one ``{token_id: logprob}`` dict per
            generation step. Prompt tokens are excluded from both the ids
            and the logprobs.
        """
        output_embeddings = self.model.get_output_embeddings()
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

            seq_logprobs: List[Dict[int, float]] = []
            for step_states in output.hidden_states:
                logprobs = self._hidden_states_to_logprobs(
                    step_states[-1][0], output_embeddings)
                # The first step covers every prompt position; only its last
                # row predicts the first generated token, so the prompt
                # logprobs are dropped. With use_cache=True later steps have
                # a single row, which ``[-1:]`` keeps unchanged.
                topk = logprobs[-1:].topk(num_logprobs)
                seq_logprobs.append(
                    dict(zip(topk.indices[0].tolist(),
                             topk.values[0].tolist())))
            all_logprobs.append(seq_logprobs)

            # Slice off the prompt by its length. The previous
            # ``seq[-output_len:]`` returned the WHOLE sequence (prompt
            # included) whenever nothing was generated, since ``-0 == 0``.
            prompt_len = input_ids.shape[1]
            output_ids = output.sequences[0][prompt_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` with the underlying model's ``encode``."""
        return self.model.encode(prompts)

    def __del__(self):
        # If __init__ raised before ``model`` was assigned there is nothing
        # to release; returning avoids an AttributeError during GC.
        if not hasattr(self, "model"):
            return
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
360
361
362
363
364
365
366
367
368
369
370
371

@pytest.fixture
def hf_runner():
    """Hand tests the ``HfRunner`` class itself; tests instantiate it."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Runs a model through ``vllm.LLM``, shaping outputs like HfRunner."""

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len=1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space=4,
        **kwargs,
    ) -> None:
        """Build the ``LLM`` engine; extra ``kwargs`` are forwarded to it."""
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: "SamplingParams",
        images: Optional["torch.Tensor"] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for every prompt, returning prompt+completion pairs.

        Args:
            prompts: Text prompts.
            sampling_params: Passed through to ``LLM.generate``.
            images: Optional image tensor whose first dimension matches
                ``len(prompts)``; wrapped as ``MultiModalData`` of type IMAGE.

        Returns:
            One ``(ids_per_sample, texts_per_sample)`` pair per request. Each
            sample's ids/text are the prompt's ids/text followed by the
            sample's.
        """
        if images is not None:
            assert len(prompts) == images.shape[0]
        req_outputs = self.model.generate(
            prompts,
            sampling_params=sampling_params,
            multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                            data=images)
            if images is not None else None)

        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids = [
                prompt_ids + sample.token_ids for sample in req_output.outputs
            ]
            req_sample_output_strs = [
                prompt_str + sample.text for sample in req_output.outputs
            ]
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: "SamplingParams",
    ) -> List[Tuple[List[int], str, Any]]:
        """Generate and return ``(ids, text, logprobs)`` per request.

        Only the completion is returned (no prompt prefix), taken from each
        request's last sample; with ``n == 1`` that is its only sample.
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            # Index explicitly instead of reading variables leaked from an
            # inner loop: a request without samples now raises IndexError
            # rather than silently repeating the previous request's values.
            sample = req_output.outputs[-1]
            outputs.append((sample.token_ids, sample.text, sample.logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[torch.Tensor] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode each prompt; return its single ``(ids, text)``."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Any]]:
        """Greedy-decode with ``num_logprobs`` logprobs per generated token."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_w_logprobs(prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search each prompt, returning all ``beam_width`` beams."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return one embedding per prompt from ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __del__(self):
        # If __init__ raised before ``model`` was assigned there is nothing
        # to release; returning avoids an AttributeError during GC.
        if not hasattr(self, "model"):
            return
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
493

494
@pytest.fixture(scope="session")
def vllm_runner():
    """Hand tests the ``VllmRunner`` class itself; tests instantiate it."""
    runner_cls = VllmRunner
    return runner_cls
497
498
499
500
501
502
503
504
505
506


def get_tokenizer_pool_config(tokenizer_group_type):
    """Map a tokenizer group type name to a ``TokenizerPoolConfig``.

    ``None`` yields ``None`` (no pool); ``"ray"`` yields a single-worker Ray
    pool. Any other value raises ``ValueError``.
    """
    if tokenizer_group_type == "ray":
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type="ray",
                                   extra_config={})
    if tokenizer_group_type is not None:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return None
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522


@pytest.fixture()
def temporary_enable_log_propagate():
    """Let records from the ``vllm`` logger propagate for one test.

    The logger's previous ``propagate`` value is restored on teardown rather
    than forced to ``False``. The local name avoids shadowing the
    module-level ``logger``.
    """
    import logging
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Expose pytest's ``caplog`` with vllm records captured.

    ``caplog`` only sees records that reach the root logger, so this fixture
    depends on ``temporary_enable_log_propagate`` to switch propagation of
    the ``vllm`` logger on for the duration of the test.
    """
    captured = caplog
    yield captured