"vscode:/vscode.git/clone" did not exist on "27c6c2f98c68971fcf4e763938efd2350f517dc5"
conftest.py 18.3 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
5
import subprocess
import sys
6
from typing import Any, Dict, List, Optional, Tuple, TypeVar
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9

import pytest
import torch
10
import torch.nn as nn
11
import torch.nn.functional as F
12
from PIL import Image
13
14
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
                          AutoProcessor, AutoTokenizer, BatchEncoding)
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16

from vllm import LLM, SamplingParams
17
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
18
19
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
20
from vllm.inputs import TextPrompt
21
from vllm.logger import init_logger
22
23
24
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs
25
from vllm.utils import is_cpu
26
27

# Module-level logger (created via vLLM's ``init_logger``) used by the
# helpers and runners in this file.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
28

29
30
31
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]


def _image_asset_paths(*names: str) -> List[str]:
    """Return the paths of ``names`` inside the test ``images`` directory."""
    return [os.path.join(_TEST_DIR, "images", name) for name in names]


# Multi-modal test assets. Run `.buildkite/download-images.sh` to fetch them.
# The three lists are index-aligned: entry i of each refers to the same image.
PIXEL_VALUES_FILES = _image_asset_paths("stop_sign_pixel_values.pt",
                                        "cherry_blossom_pixel_values.pt")
IMAGE_FEATURES_FILES = _image_asset_paths("stop_sign_image_features.pt",
                                          "cherry_blossom_image_features.pt")
IMAGE_FILES = _image_asset_paths("stop_sign.jpg", "cherry_blossom.jpg")
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
48

49

50
def _read_prompts(filename: str) -> List[str]:
    """Read a prompt file and return its lines.

    Each line (trailing newline included, as produced by ``readlines``) is one
    prompt. The file is decoded as UTF-8 explicitly so the result does not
    depend on the platform's locale encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
54
55


56
57
def cleanup():
    """Tear down distributed state and release memory between tests.

    Destroys vLLM's model-parallel groups and distributed environment, then
    the default torch process group, runs the garbage collector, and finally
    empties the CUDA cache unless running on CPU.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # Deliberately ignored; presumably raised when no default process
        # group exists.
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
64
65


66
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup_fixture`` should run ``cleanup()`` after a test.

    Returns ``False`` when the test carries the ``skip_global_cleanup`` marker.
    Subdirectories may also override this fixture. Skipping cleanup lets
    non-GPU unit tests avoid initializing torch (roughly a 10x speedup).
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


79
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse fixture that runs ``cleanup()`` after each test unless opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
84
85


86
87
@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
    """Session-wide PIL images, one per path in ``IMAGE_FILES``."""
    return list(map(Image.open, IMAGE_FILES))
89
90
91


@pytest.fixture()
def vllm_images(request) -> List[MultiModalData]:
    """Image inputs for vLLM, matching the test's vision-language config.

    Reads element ``[1]`` of the ``model_and_config`` fixture as the
    ``VisionLanguageConfig``. For the ``IMAGE_FEATURES`` input type, the
    precomputed feature tensors are loaded; otherwise the raw image files are
    opened as pixel data.
    """
    config = request.getfixturevalue("model_and_config")[1]
    wants_features = (config.image_input_type ==
                      VisionLanguageConfig.ImageInputType.IMAGE_FEATURES)
    if wants_features:
        return [
            ImageFeatureData(torch.load(path)) for path in IMAGE_FEATURES_FILES
        ]
    return [ImagePixelData(Image.open(path)) for path in IMAGE_FILES]


@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
    """Pixel-value tensors loaded from ``PIXEL_VALUES_FILES``."""
    # ``request`` is unused; it is kept so the fixture signature is unchanged.
    return list(map(torch.load, PIXEL_VALUES_FILES))
109
110


Woosuk Kwon's avatar
Woosuk Kwon committed
111
112
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of every file in ``_TEST_PROMPTS``, concatenated in order."""
    return [line for path in _TEST_PROMPTS for line in _read_prompts(path)]


@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of every file in ``_LONG_PROMPTS``, concatenated in order."""
    return [line for path in _LONG_PROMPTS for line in _read_prompts(path)]
Woosuk Kwon's avatar
Woosuk Kwon committed
125
126
127
128
129
130
131
132


# Maps the dtype strings accepted by ``HfRunner`` to torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    bfloat16=torch.bfloat16,
    float=torch.float,
)

133
# Objects that ``HfRunner.wrap_device`` can move between devices.
_T = TypeVar(
    "_T",
    nn.Module,
    torch.Tensor,
    BatchEncoding,
)
134

Woosuk Kwon's avatar
Woosuk Kwon committed
135
136
137

class HfRunner:
    """Run a HuggingFace model to produce reference outputs for vLLM tests.

    Depending on the constructor flags, the wrapped model is a
    ``SentenceTransformer`` (embedding models), an ``AutoModelForVision2Seq``
    (vision models) or an ``AutoModelForCausalLM``. A tokenizer is always
    loaded. A processor is loaded when possible; otherwise the tokenizer is
    used in its place. Instances are context managers: leaving the ``with``
    block deletes the model and calls ``cleanup()``.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to ``"cuda"``, or to ``"cpu"`` when ``is_cpu()``."""
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
    ) -> None:
        """Load the model, the tokenizer and, if possible, the processor.

        Args:
            model_name: HuggingFace model name or path.
            dtype: A key of ``_STR_DTYPE_TO_TORCH_DTYPE``.
            is_embedding_model: Load the model with ``SentenceTransformer``.
            is_vision_model: Use ``AutoModelForVision2Seq`` instead of
                ``AutoModelForCausalLM``.
        """
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            else:
                auto_cls = AutoModelForCausalLM

            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            # Best effort: fall back to the tokenizer when no processor loads.
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``model.generate`` on each prompt separately.

        Args:
            prompts: Text prompts.
            images: Optional per-prompt images; ``None`` entries mean "no
                image". When non-empty, it must match ``prompts`` in length.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            One ``(token_id_lists, decoded_strings)`` tuple per prompt, with an
            entry per sequence returned by ``model.generate``.
        """
        if images:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode each prompt and return its single (ids, text) pair."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search each prompt, returning all ``beam_width`` beams.

        Padding tokens (``tokenizer.pad_token_id``) are removed from every
        returned id sequence.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in ids if tok != pad_token_id]
                  for ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def _generate_with_hidden_states(self, prompt: str,
                                     max_tokens: int) -> Tuple[torch.Tensor, Any]:
        """Greedy-generate from ``prompt`` and keep per-step hidden states.

        Returns the (CPU) prompt ``input_ids`` and the dict-style output of
        ``model.generate``.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        output = self.model.generate(
            self.wrap_device(input_ids),
            use_cache=True,
            do_sample=False,
            max_new_tokens=max_tokens,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )
        return input_ids, output

    def _logprobs_from_hidden_states(
            self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Project one generation step's hidden states to float32 log-probs.

        ``hidden_states[-1][0]`` selects the last element and its first batch
        row. Presumably this is the final layer's states for the single
        prompt; confirm against the HF ``generate`` output format.
        """
        last_hidden_states = hidden_states[-1][0]
        output_embeddings = self.model.get_output_embeddings()
        logits = torch.matmul(
            last_hidden_states,
            output_embeddings.weight.t(),
        )
        # Not every output-embedding module defines ``bias``.
        bias = getattr(output_embeddings, "bias", None)
        if bias is not None:
            logits += bias.unsqueeze(0)
        return F.log_softmax(logits, dim=-1, dtype=torch.float32)

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        """Greedy-generate and return per-step log-prob tensors per prompt."""
        all_logprobs: List[List[torch.Tensor]] = []
        for prompt in prompts:
            _, output = self._generate_with_hidden_states(prompt, max_tokens)
            all_logprobs.append([
                self._logprobs_from_hidden_states(hidden_states)
                for hidden_states in output.hidden_states
            ])
        return all_logprobs

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy-generate, returning output ids, text and top-k log-probs.

        For each prompt, returns ``(output_ids, output_str, logprobs)``.
        ``output_ids`` excludes the prompt tokens. ``logprobs`` holds one
        ``{token_id: logprob}`` dict per generation step with the
        ``num_logprobs`` most likely tokens.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for prompt in prompts:
            input_ids, output = self._generate_with_hidden_states(
                prompt, max_tokens)

            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, hidden_states in enumerate(output.hidden_states):
                tok_logprobs = self._logprobs_from_hidden_states(hidden_states)
                # Step 0 has a row per prompt position; drop prompt logprobs
                # and keep only the last row (the first generated token).
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)
                seq_logprobs_lst.append({
                    token_id.item(): logprob.item()
                    for token_id, logprob in zip(topk.indices[0],
                                                 topk.values[0])
                })
            all_logprobs.append(seq_logprobs_lst)

            seq_ids = output.sequences[0]
            # Slice from the prompt length. The former ``seq_ids[-output_len:]``
            # returned the whole sequence when ``output_len`` was 0 (-0 == 0).
            output_ids = seq_ids[input_ids.shape[1]:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` with the wrapped model's ``encode``."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference before global cleanup so memory can be freed.
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
366
367
368
369
370
371
372
373
374
375
376
377

@pytest.fixture
def hf_runner():
    """Provide the ``HfRunner`` class itself; tests instantiate it as needed."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Thin test wrapper around ``vllm.LLM`` with HfRunner-compatible outputs.

    Instances are context managers: leaving the ``with`` block deletes the
    engine and calls ``cleanup()``.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        **kwargs,
    ) -> None:
        """Build the ``LLM`` engine; extra ``kwargs`` are forwarded to it."""
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt, optionally with per-prompt image data.

        Returns:
            One ``(token_id_lists, strings)`` tuple per prompt, with an entry
            per sample. Both ids and strings are the prompt followed by the
            sample's output.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = image

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate and return ``(output_ids, text, logprobs)`` per prompt.

        ``sampling_params.logprobs`` must be set. When a request produces
        several samples, only the last one is reported.
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            # Index explicitly instead of reusing loop variables, which would
            # silently repeat the previous prompt's sample if a request came
            # back with no outputs.
            sample = req_output.outputs[-1]
            outputs.append((sample.token_ids, sample.text, sample.logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[MultiModalData]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode (temperature 0) and return one (ids, text) per prompt."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy-decode, requesting ``num_logprobs`` log-probs per token."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_w_logprobs(prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search with ``beam_width`` beams; returns all beams per prompt."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding produced by ``LLM.encode`` for each prompt."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the engine reference before global cleanup so memory is freed.
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
505

506
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture returning the ``VllmRunner`` class itself."""
    runner_cls = VllmRunner
    return runner_cls
509
510
511
512
513
514
515
516
517
518


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a ``TokenizerPoolConfig`` for a tokenizer-group type name.

    ``None`` yields ``None``; ``"ray"`` yields a single-worker Ray pool.
    Any other value raises ``ValueError``.
    """
    if tokenizer_group_type == "ray":
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type="ray",
                                   extra_config={})
    if tokenizer_group_type is not None:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return None
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534


@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the ``vllm`` logger for the duration of a test.

    The logger's previous ``propagate`` value is restored on teardown, so the
    fixture does not leak logger state into other tests.
    """
    import logging
    # Local name chosen so the module-level ``logger`` is not shadowed.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """``caplog`` that also captures records from the ``vllm`` logger.

    ``caplog`` depends on records propagating to the root logger, so this
    fixture requests ``temporary_enable_log_propagate`` to turn propagation
    on for the test.
    """
    return caplog
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553


@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the number of CUDA devices, or 0 if it cannot be determined.

    The count is computed in a child interpreter so that the CUDA context is
    never initialized in the current process. Failures (a non-zero exit, or
    output that is not an integer) are logged and reported as 0 GPUs.
    """
    try:
        out = subprocess.run([
            sys.executable, "-c",
            "import torch; print(torch.cuda.device_count())"
        ],
                             capture_output=True,
                             check=True,
                             text=True)
    except subprocess.CalledProcessError as e:
        logger.warning("Failed to get number of GPUs.", exc_info=e)
        return 0
    # The count is the last thing printed; parse only that line so any
    # earlier stray stdout output cannot break the conversion.
    lines = out.stdout.strip().splitlines()
    try:
        return int(lines[-1])
    except (IndexError, ValueError) as e:
        logger.warning("Failed to get number of GPUs.", exc_info=e)
        return 0