"""Shared pytest fixtures and HuggingFace / vLLM runner helpers for tests."""
import contextlib
import gc
import os
from typing import Any, Dict, List, Optional, Tuple

import pytest
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                          LlavaConfig, LlavaForConditionalGeneration)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import MultiModalData, SampleLogprobs

logger = init_logger(__name__)

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi modal related.
# The four lists below are parallel: entry i of each refers to the same test
# image (pre-computed pixel values, pre-computed image features, the raw image
# file, and the prompt asking about it).
_PIXEL_VALUES_FILES = [
    os.path.join(_TEST_DIR, "images", "stop_sign_pixel_values.pt"),
    os.path.join(_TEST_DIR, "images", "cherry_blossom_pixel_values.pt"),
]
_IMAGE_FEATURES_FILES = [
    os.path.join(_TEST_DIR, "images", "stop_sign_image_features.pt"),
    os.path.join(_TEST_DIR, "images", "cherry_blossom_image_features.pt"),
]
_IMAGE_FILES = [
    os.path.join(_TEST_DIR, "images", "stop_sign.jpg"),
    os.path.join(_TEST_DIR, "images", "cherry_blossom.jpg"),
]
_IMAGE_PROMPTS = [
    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
    "<image>\nUSER: What is the season?\nASSISTANT:"
]
# Parallel lists must agree in length.
assert len({
    len(_PIXEL_VALUES_FILES),
    len(_IMAGE_FEATURES_FILES),
    len(_IMAGE_FILES),
    len(_IMAGE_PROMPTS),
}) == 1


def _read_prompts(filename: str) -> List[str]:
    """Return the lines of ``filename``; each line keeps its trailing newline.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()


def cleanup():
    """Tear down distributed state and free cached CUDA memory.

    Destroys vLLM's model-parallel state, then the default torch process
    group (an ``AssertionError`` from that call is swallowed, presumably for
    the case where no group was initialised), runs the garbage collector and
    finally asks the CUDA caching allocator to release unused memory.
    """
    destroy_model_parallel()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    gc.collect()
    torch.cuda.empty_cache()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.

    Individual tests can also opt out with the ``skip_global_cleanup``
    marker. Returns ``True`` (run cleanup) otherwise.
    """
    if request.node.get_closest_marker("skip_global_cleanup"):
        return False

    return True


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Call ``cleanup()`` after each test in scope (autouse), unless
    ``should_do_global_cleanup_after_test`` resolves to ``False``."""
    yield
    if should_do_global_cleanup_after_test:
        cleanup()


@pytest.fixture(scope="session")
def hf_image_prompts() -> List[str]:
    """Session-wide image prompts for the HuggingFace runner.

    Each prompt carries a single leading ``<image>`` placeholder. The
    module-level list object itself is returned, not a copy.
    """
    prompts: List[str] = _IMAGE_PROMPTS
    return prompts


@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
    """Open every file in ``_IMAGE_FILES`` with PIL, preserving order."""
    return list(map(Image.open, _IMAGE_FILES))


@pytest.fixture()
def vllm_images(request) -> "torch.Tensor":
    """Load the saved image tensors matching the model's image input type.

    Element ``[1]`` of the ``model_and_config`` fixture is used as the
    vision-language config. That fixture is not defined in this file
    (presumably provided by the requesting test module — confirm). Image
    features are loaded when the config asks for ``IMAGE_FEATURES``, pixel
    values otherwise. The per-image tensors are concatenated along dim 0.
    """
    vlm_config = request.getfixturevalue("model_and_config")[1]
    wants_features = (vlm_config.image_input_type ==
                      VisionLanguageConfig.ImageInputType.IMAGE_FEATURES)
    paths = _IMAGE_FEATURES_FILES if wants_features else _PIXEL_VALUES_FILES
    return torch.cat([torch.load(path) for path in paths], dim=0)


@pytest.fixture()
def vllm_image_prompts(request) -> List[str]:
    """Image prompts expanded to one ``<image>`` token per image feature.

    Element ``[1]`` of the ``model_and_config`` fixture is used as the
    vision-language config (that fixture is not defined in this file —
    presumably provided by the requesting test module).
    """
    vision_language_config = request.getfixturevalue("model_and_config")[1]
    # Every prompt already starts with one "<image>", so prepend
    # ``image_feature_size - 1`` more to reach image_feature_size in total.
    return [
        "<image>" * (vision_language_config.image_feature_size - 1) + p
        for p in _IMAGE_PROMPTS
    ]


@pytest.fixture
def example_prompts() -> List[str]:
    """All lines (trailing newlines kept) of the files in ``_TEST_PROMPTS``."""
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines (trailing newlines kept) of the files in ``_LONG_PROMPTS``."""
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


# Maps the ``dtype`` strings accepted by HfRunner to torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}

# Register Llava with the causal-LM auto class so that
# ``AutoModelForCausalLM.from_pretrained`` (used by HfRunner) can load it.
AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)

# Models that HfRunner loads with sentence_transformers instead of
# AutoModelForCausalLM.
_EMBEDDING_MODELS = [
    "intfloat/e5-mistral-7b-instruct",
]

class HfRunner:
    """Load a HuggingFace model on CUDA and produce reference outputs.

    Models listed in ``_EMBEDDING_MODELS`` are loaded with
    ``sentence_transformers``; all others with ``AutoModelForCausalLM``.
    ``self.processor`` is the model's ``AutoProcessor`` when one can be
    loaded and falls back to the tokenizer otherwise.
    """

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
    ) -> None:
        """Load model, tokenizer and processor for ``model_name``.

        Args:
            model_name: HuggingFace model id or local path.
            dtype: A key of ``_STR_DTYPE_TO_TORCH_DTYPE``.
        """
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if model_name in _EMBEDDING_MODELS:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(
                model_name,
                device="cpu",
            ).to(dtype=torch_dtype).cuda()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            ).cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            # Presumably models without a processor config (e.g. text-only);
            # the tokenizer offers the ``__call__``/``batch_decode`` API used
            # by ``generate``.
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``self.model.generate`` on each prompt individually.

        Args:
            prompts: Text prompts.
            images: Optional per-prompt images; when given it must have the
                same length as ``prompts``. ``None`` entries mean text-only.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            One ``(output_ids, output_strs)`` pair per prompt. Each list has
            one entry per sequence returned by ``self.model.generate`` (the
            raw sequences, decoded with special tokens skipped).
        """
        # Compare against None rather than truthiness so an empty ``images``
        # list is rejected here instead of raising IndexError in the loop.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output_ids = self.model.generate(
                **inputs.to("cuda"),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode up to ``max_tokens`` new tokens per prompt and
        return the single resulting ``(ids, text)`` pair for each."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search decode, returning all ``beam_width`` beams per prompt
        with ``pad_token_id`` entries removed from the token ids."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        for i, (output_ids, output_str) in enumerate(outputs):
            # Beams presumably come back padded to a common length; strip the
            # padding so ids compare cleanly against other runners.
            stripped_ids = [[x for x in beam_ids if x != pad_token_id]
                            for beam_ids in output_ids]
            outputs[i] = (stripped_ids, output_str)
        return outputs

    def _greedy_generate_with_hidden_states(self, prompt: str,
                                            max_tokens: int):
        """Greedy-generate one prompt while keeping per-step hidden states.

        Returns ``(input_ids, output)`` where ``output`` is the dict-style
        result of ``self.model.generate``.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        output = self.model.generate(
            input_ids.cuda(),
            use_cache=True,
            do_sample=False,
            max_new_tokens=max_tokens,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )
        return input_ids, output

    def _hidden_states_to_logprobs(self, hidden_states) -> torch.Tensor:
        """Turn one generation step's hidden states into log-probabilities.

        Takes the last layer's states for batch element 0, projects them
        through the output embedding (adding its bias when present) and
        applies a float32 log-softmax over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()
        last_hidden_states = hidden_states[-1][0]
        logits = torch.matmul(last_hidden_states,
                              output_embeddings.weight.t())
        # Not every output-embedding module defines a ``bias`` attribute.
        bias = getattr(output_embeddings, "bias", None)
        if bias is not None:
            logits += bias.unsqueeze(0)
        return F.log_softmax(logits, dim=-1, dtype=torch.float32)

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode each prompt and return, per prompt, one float32
        log-prob tensor per generation step (the first step's tensor spans
        every prompt position)."""
        all_logprobs: List[List[torch.Tensor]] = []
        for prompt in prompts:
            _, output = self._greedy_generate_with_hidden_states(
                prompt, max_tokens)
            all_logprobs.append([
                self._hidden_states_to_logprobs(hidden_states)
                for hidden_states in output.hidden_states
            ])
        return all_logprobs

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy-decode and keep the top ``num_logprobs`` logprobs per step.

        Returns:
            One ``(output_ids, output_str, step_logprobs)`` tuple per prompt.
            ``output_ids``/``output_str`` cover only the generated tokens;
            ``step_logprobs`` holds one ``{token_id: logprob}`` dict per
            generated token.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for prompt in prompts:
            input_ids, output = self._greedy_generate_with_hidden_states(
                prompt, max_tokens)

            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, hidden_states in enumerate(output.hidden_states):
                tok_logprobs = self._hidden_states_to_logprobs(hidden_states)
                # The first step covers all prompt positions; keep only the
                # last one (the prediction of the first generated token).
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)
                seq_logprobs_lst.append({
                    token_id.item(): logprob.item()
                    for token_id, logprob in zip(topk.indices[0],
                                                 topk.values[0])
                })
            all_logprobs.append(seq_logprobs_lst)

            # Slice from the prompt length: a negative offset such as
            # ``seq[-output_len:]`` returns the whole sequence when
            # output_len is 0.
            output_ids = output.sequences[0][input_ids.shape[1]:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` via ``self.model.encode``; intended for models
        in ``_EMBEDDING_MODELS`` (loaded as ``SentenceTransformer``)."""
        return self.model.encode(prompts)

    def __del__(self):
        # ``model`` is absent if ``__init__`` raised before assigning it.
        if hasattr(self, "model"):
            del self.model
        cleanup()


@pytest.fixture
def hf_runner():
    """Hand tests the ``HfRunner`` class itself (not an instance)."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Wrap ``vllm.LLM`` with generation helpers for tests.

    ``generate_greedy`` and ``generate_beam_search`` return the same layout
    as the same-named ``HfRunner`` methods so results can be compared.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        **kwargs,
    ) -> None:
        """Create the underlying ``LLM`` with ``trust_remote_code=True``;
        any extra ``kwargs`` are forwarded to it unchanged."""
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[torch.Tensor] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt, optionally attaching one image each.

        Args:
            prompts: Text prompts.
            sampling_params: Passed to ``LLM.generate``.
            images: Optional tensor whose first dimension indexes prompts;
                slice ``i`` is attached to prompt ``i``.

        Returns:
            One ``(sample_ids, sample_strs)`` pair per request with one entry
            per sample; the prompt's token ids / text are prepended to each
            sample's output.
        """
        if images is not None:
            assert len(prompts) == len(images)

        prompt_inputs: List[TextPrompt] = []
        for i, prompt in enumerate(prompts):
            prompt_input = TextPrompt(prompt=prompt)
            if images is not None:
                # Slice rather than index to keep a leading dimension of 1.
                prompt_input["multi_modal_data"] = MultiModalData(
                    type=MultiModalData.Type.IMAGE,
                    data=images[i:i + 1],
                )
            prompt_inputs.append(prompt_input)

        req_outputs = self.model.generate(prompt_inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate and return ``(token_ids, text, logprobs)`` per request.

        Exactly one tuple is produced per request; when a request has several
        samples only the last one is reported. ``sampling_params.logprobs``
        must be set.
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            # Pick the last sample explicitly instead of relying on variables
            # leaked from an inner loop: a request without samples raises
            # IndexError rather than reusing a previous request's values.
            sample = req_output.outputs[-1]
            outputs.append((sample.token_ids, sample.text, sample.logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[torch.Tensor] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode (temperature 0) and return one ``(ids, text)`` pair
        per prompt, prompt included."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy-decode requesting ``num_logprobs`` logprobs per token."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_w_logprobs(prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search decode, returning all ``beam_width`` beams per prompt
        in the layout of ``generate``."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding of each prompt from ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __del__(self):
        # ``model`` is absent if ``__init__`` raised before assigning it.
        if hasattr(self, "model"):
            del self.model
        cleanup()


@pytest.fixture(scope="session")
def vllm_runner():
    """Hand tests the ``VllmRunner`` class itself (not an instance)."""
    return VllmRunner


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build the tokenizer pool config for ``tokenizer_group_type``.

    ``None`` means no pool and returns ``None``; ``"ray"`` returns a ray
    ``TokenizerPoolConfig`` with ``pool_size=1``. Anything else raises
    ``ValueError``.
    """
    if tokenizer_group_type == "ray":
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type="ray",
                                   extra_config={})
    if tokenizer_group_type is not None:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return None


@pytest.fixture()
def temporary_enable_log_propagate():
    """Let the ``vllm`` logger propagate to the root logger during a test.

    The logger's previous ``propagate`` value is restored on teardown
    (instead of being forced to ``False``).
    """
    import logging

    # Distinct name so the module-level ``logger`` is not shadowed.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """``caplog`` that also captures records from the ``vllm`` logger.

    caplog depends on logs propagated to the root logger, so this fixture
    requests ``temporary_enable_log_propagate`` to switch propagation on
    for the duration of the test.
    """
    return caplog