conftest.py 17.9 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
from typing import Any, Dict, List, Optional, Tuple, TypeVar
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7

import pytest
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
from PIL import Image
11
12
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
                          AutoProcessor, AutoTokenizer, BatchEncoding)
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14

from vllm import LLM, SamplingParams
15
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
16
17
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
18
from vllm.inputs import TextPrompt
19
from vllm.logger import init_logger
20
21
22
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs
23
from vllm.utils import cuda_device_count_stateless, is_cpu
24
25

logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
26

27
28
29
_TEST_DIR = os.path.dirname(__file__)


def _asset_paths(subdir: str, filenames: List[str]) -> List[str]:
    """Join each name in ``filenames`` onto ``<this directory>/<subdir>``."""
    return [os.path.join(_TEST_DIR, subdir, name) for name in filenames]


# Prompt files that the prompt fixtures read line by line.
_TEST_PROMPTS = _asset_paths("prompts", ["example.txt"])
_LONG_PROMPTS = _asset_paths("prompts", ["summary.txt"])

# Multi modal related
# You can use `.buildkite/download-images.sh` to download the assets
PIXEL_VALUES_FILES = _asset_paths(
    "images",
    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"],
)
IMAGE_FEATURES_FILES = _asset_paths(
    "images",
    ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"],
)
IMAGE_FILES = _asset_paths(
    "images",
    ["stop_sign.jpg", "cherry_blossom.jpg"],
)
# The three lists are parallel: one entry per image in each of them.
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
46

47

48
def _read_prompts(filename: str) -> List[str]:
    """Return the lines of ``filename`` as a list, line endings included."""
    with open(filename, "r") as prompt_file:
        return list(prompt_file)
Woosuk Kwon's avatar
Woosuk Kwon committed
52
53


54
55
def cleanup():
    """Tear down distributed state and release cached memory.

    Destroys the model-parallel and distributed environments, destroys the
    torch process group (an ``AssertionError`` from that call is ignored),
    runs the garbage collector and, unless ``is_cpu()`` is true, empties the
    CUDA cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()

    # Best effort: an AssertionError raised here is deliberately swallowed.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    gc.collect()

    if is_cpu():
        return
    torch.cuda.empty_cache()
62
63


64
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the global cleanup should run after the current test.

    Subdirectories can override this fixture to skip the global cleanup,
    which can give a ~10x speedup for non-GPU unit tests because they do not
    need to initialize torch. A test carrying the ``skip_global_cleanup``
    marker opts out as well.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


77
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse fixture that calls ``cleanup()`` after each test when enabled."""
    yield
    # Everything below is teardown, executed once the test has finished.
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
82
83


84
85
@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
    """Open every file in ``IMAGE_FILES`` with PIL, once per test session."""
    opened_images: List[Image.Image] = []
    for image_path in IMAGE_FILES:
        opened_images.append(Image.open(image_path))
    return opened_images
87
88
89


@pytest.fixture()
def vllm_images(request) -> List[MultiModalData]:
    """Build the multi-modal image inputs matching the model's input type.

    The vision-language config is the second element of the
    ``model_and_config`` fixture value. When it asks for image features, the
    saved feature tensors are loaded; otherwise the image files are opened
    and wrapped as pixel data.
    """
    _, vision_language_config = request.getfixturevalue("model_and_config")[:2]
    wants_features = (vision_language_config.image_input_type ==
                      VisionLanguageConfig.ImageInputType.IMAGE_FEATURES)

    if not wants_features:
        return [ImagePixelData(Image.open(path)) for path in IMAGE_FILES]

    return [
        ImageFeatureData(torch.load(path)) for path in IMAGE_FEATURES_FILES
    ]


@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
    """Load each file in ``PIXEL_VALUES_FILES`` with ``torch.load``."""
    # `request` is accepted but not used by this fixture.
    loaded_tensors: List[torch.Tensor] = []
    for tensor_path in PIXEL_VALUES_FILES:
        loaded_tensors.append(torch.load(tensor_path))
    return loaded_tensors
107
108


Woosuk Kwon's avatar
Woosuk Kwon committed
109
110
@pytest.fixture
def example_prompts() -> List[str]:
    """Return every line of every file listed in ``_TEST_PROMPTS``."""
    return [
        prompt_line for prompt_path in _TEST_PROMPTS
        for prompt_line in _read_prompts(prompt_path)
    ]


@pytest.fixture
def example_long_prompts() -> List[str]:
    """Return every line of every file listed in ``_LONG_PROMPTS``."""
    return [
        prompt_line for prompt_path in _LONG_PROMPTS
        for prompt_line in _read_prompts(prompt_path)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
123
124
125
126
127
128
129
130


# Maps each dtype name accepted by `HfRunner` to the torch dtype of that name.
_STR_DTYPE_TO_TORCH_DTYPE = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("half", "bfloat16", "float")
}

131
# Constrained type variable for `HfRunner.wrap_device`: the value passed in
# and the value returned share one of these three types.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
132

Woosuk Kwon's avatar
Woosuk Kwon committed
133
134
135

class HfRunner:
    """Test helper that loads a HuggingFace model and generates with it.

    Depending on the constructor flags the model is a
    ``SentenceTransformer``, an ``AutoModelForVision2Seq`` or an
    ``AutoModelForCausalLM``. The runner can be used as a context manager;
    leaving the ``with`` block deletes the model and calls ``cleanup()``.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to CUDA, or to the CPU when ``is_cpu()`` is true."""
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: Name or path handed to the ``from_pretrained`` calls.
            dtype: Key of ``_STR_DTYPE_TO_TORCH_DTYPE`` ("half", "bfloat16"
                or "float"); any other value fails the assertion below.
            is_embedding_model: Load the model with ``SentenceTransformer``.
            is_vision_model: Load with ``AutoModelForVision2Seq`` instead of
                ``AutoModelForCausalLM`` (only consulted when
                ``is_embedding_model`` is false).
        """
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            else:
                auto_cls = AutoModelForCausalLM

            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # Fall back to the tokenizer when no processor can be loaded.
        try:
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt separately and decode the results.

        Args:
            prompts: Prompts, processed one at a time.
            images: Optional images, one per prompt; a ``None`` entry means
                the matching prompt gets no image.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            One ``(token id lists, decoded strings)`` tuple per prompt.
        """
        # `is not None` matches the per-prompt check below, so an empty
        # `images` list with prompts fails here instead of raising an
        # IndexError inside the loop.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Generate without sampling and keep the first sequence per prompt.

        Returns:
            One ``(token ids, decoded string)`` tuple per prompt.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run beam search and drop pad tokens from the returned ids.

        ``beam_width`` is used both as the number of beams and as the number
        of returned sequences. Only the token id lists are filtered; the
        decoded strings are returned as produced by ``generate``.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[token_id for token_id in sequence
                   if token_id != pad_token_id] for sequence in output_ids],
                 output_str) for output_ids, output_str in outputs]

    def _hidden_states_to_logprobs(self, hidden_states) -> torch.Tensor:
        """Turn one generation step's hidden states into log-probabilities.

        Takes the first batch entry of the last element of ``hidden_states``,
        multiplies it by the transposed output-embedding weight, adds the
        output-embedding bias when there is one, and applies a float32
        log-softmax over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()
        last_hidden_states = hidden_states[-1][0]
        logits = torch.matmul(
            last_hidden_states,
            output_embeddings.weight.t(),
        )
        # getattr: tolerate output embeddings that have no `bias` attribute.
        bias = getattr(output_embeddings, "bias", None)
        if bias is not None:
            logits += bias.unsqueeze(0)
        return F.log_softmax(logits, dim=-1, dtype=torch.float32)

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        """Generate without sampling and return log-probability tensors.

        Returns:
            For each prompt, a list with one log-probability tensor per entry
            of the generation output's ``hidden_states``.
        """
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                self.wrap_device(input_ids),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            all_logprobs.append([
                self._hidden_states_to_logprobs(hidden_states)
                for hidden_states in output.hidden_states
            ])
        return all_logprobs

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Generate without sampling and keep the top log-probs per step.

        Returns:
            One ``(output token ids, decoded output, logprobs)`` tuple per
            prompt, where ``logprobs`` holds one ``{token id: logprob}`` dict
            with ``num_logprobs`` entries for every generation step.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                self.wrap_device(input_ids),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

            seq_logprobs: List[torch.Tensor] = [
                self._hidden_states_to_logprobs(hidden_states)
                for hidden_states in output.hidden_states
            ]

            # convert to dict
            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                # drop prompt logprobs
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)
                seq_logprobs_lst.append(
                    dict(zip(topk.indices[0].tolist(),
                             topk.values[0].tolist())))

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Skip the prompt by slicing from its length. The previous
            # `seq_ids[-output_len:]` returned the whole sequence whenever
            # no new token had been generated (output_len == 0).
            output_ids = seq_ids[input_ids.shape[1]:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)``."""
        return self.model.encode(prompts)

    def __enter__(self):
        """Return the runner itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delete the model and run the global ``cleanup()``."""
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
364
365
366
367
368
369
370
371
372
373
374
375

@pytest.fixture
def hf_runner():
    """Return the ``HfRunner`` class itself (not an instance)."""
    return HfRunner


class VllmRunner:
    """Test helper that wraps a vLLM ``LLM`` instance.

    The runner can be used as a context manager; leaving the ``with`` block
    deletes the model and calls ``cleanup()``.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        **kwargs,
    ) -> None:
        """Create the underlying ``LLM`` with ``trust_remote_code=True``.

        All named arguments are forwarded to ``LLM`` under the same keyword
        (``model_name`` as ``model``, ``tokenizer_name`` as ``tokenizer``);
        ``**kwargs`` is forwarded unchanged.
        """
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for ``prompts`` and return prompt + completion per sample.

        Args:
            prompts: Text prompts.
            sampling_params: Sampling parameters passed to the model.
            images: Optional multi-modal data, one item per prompt.

        Returns:
            One tuple per request: the token id lists and the strings of all
            its samples, each with the prompt prepended to the completion.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = image

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                req_sample_output_ids.append(prompt_ids + sample.token_ids)
                req_sample_output_strs.append(prompt_str + sample.text)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate and return ids, text and logprobs for each request.

        Returns:
            One ``(token ids, text, logprobs)`` tuple per request, taken from
            the request's last sample. Unlike ``generate``, the prompt is not
            prepended.
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            # Without this check a request with no samples would silently
            # repeat the previous request's values (or hit a NameError).
            assert req_output.outputs, "request returned no samples"
            last_sample = req_output.outputs[-1]
            outputs.append((last_sample.token_ids, last_sample.text,
                            last_sample.logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Generate at temperature 0 and keep the first sample per request.

        Returns:
            One ``(token ids, string)`` tuple per request, prompt included.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate at temperature 0, requesting ``num_logprobs`` logprobs.

        Returns:
            The result of ``generate_w_logprobs`` for these parameters.
        """
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_w_logprobs(prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with beam search, returning ``beam_width`` samples.

        Returns:
            The result of ``generate`` for the beam-search parameters.
        """
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Encode ``prompts`` and return the embedding of each request."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        """Return the runner itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delete the model and run the global ``cleanup()``."""
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
503

504
@pytest.fixture(scope="session")
def vllm_runner():
    """Return the ``VllmRunner`` class itself (not an instance)."""
    runner_cls = VllmRunner
    return runner_cls
507
508
509
510
511
512
513
514
515
516


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build the tokenizer pool config for ``tokenizer_group_type``.

    Returns ``None`` for ``None`` and a single-worker "ray" pool config for
    ``"ray"``.

    Raises:
        ValueError: For any other value.
    """
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type != "ray":
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={})
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532


@pytest.fixture()
def temporary_enable_log_propagate():
    """Turn on propagation for the ``"vllm"`` logger while a test runs.

    The value ``propagate`` had before the test is restored on teardown, so
    a session in which the vllm logger already propagated is left unchanged
    (previously it was always forced back to ``False``).
    """
    import logging

    # Named `vllm_logger` so the module-level `logger` is not shadowed.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield ``caplog`` with vllm log propagation temporarily enabled.

    Capturing vllm logs needs ``propagate=True`` for the duration of the
    test, because caplog depends on logs propagated to the root logger.
    """
    yield caplog
533
534
535
536
537
538
539


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""
    gpu_count = cuda_device_count_stateless()
    return gpu_count