# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
from transformers import AutoTokenizer

from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                   NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN,
                                   TOKENIZER_NAME,
                                   DummyOutputProcessorTestVectors,
                                   generate_dummy_prompt_logprobs_tensors,
                                   generate_dummy_sample_logprobs)
from vllm.engine.arg_utils import EngineArgs

16
17
from ...distributed.conftest import publisher_config, random_port  # noqa: F401

18
19
from tests.v1.engine.utils import FULL_STRINGS  # isort: skip

20
21
# Type alias: a list of (Tensor, Tensor) pairs. Judging by the name, this is
# the shape of sample logprobs data — what each tensor of a pair holds is not
# visible in this file (TODO confirm against the code that uses this alias).
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
# Type alias: a single (Tensor, Tensor) pair. Judging by the name, this is the
# shape of prompt logprobs data — tensor contents likewise to be confirmed.
EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor]
22
23
24
25


def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, without logprobs.

    Each string in ``FULL_STRINGS`` is tokenized once. The first
    ``PROMPT_LEN`` token ids of every string form the prompt and the
    remaining ids form the dummy "generated" tokens. Prompt strings are
    obtained by decoding the prompt token ids (special tokens skipped), and
    each generation string is the tail of the full string that follows the
    decoded prompt.

    Returns:
      DummyOutputProcessorTestVectors instance with no logprobs
      (``prompt_logprobs`` and ``generation_logprobs`` are empty lists).
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()

    # Tokenize every full string exactly once, then split each token list
    # into its prompt part and its dummy generated part. Slicing produces
    # new lists, so the three token collections do not alias each other.
    full_tokens = [tokenizer(text).input_ids for text in FULL_STRINGS]
    prompt_tokens = [token_ids[:PROMPT_LEN] for token_ids in full_tokens]
    generation_tokens = [token_ids[PROMPT_LEN:] for token_ids in full_tokens]

    # Generate prompt strings
    prompt_strings = [
        tokenizer.decode(token_ids, skip_special_tokens=True)
        for token_ids in prompt_tokens
    ]
    prompt_strings_len = [
        len(prompt_string) for prompt_string in prompt_strings
    ]

    # The generation string is whatever remains of the full text after the
    # first len(decoded prompt) characters. NOTE: this relies on the decoded
    # prompt having the same length as the matching prefix of the full text.
    generation_strings = [
        text[prompt_len:]
        for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
    ]

    return DummyOutputProcessorTestVectors(
        tokenizer=tokenizer,
        vllm_config=vllm_config,
        full_tokens=full_tokens,
        prompt_tokens=prompt_tokens,
        generation_tokens=generation_tokens,
        prompt_strings=prompt_strings,
        prompt_strings_len=prompt_strings_len,
        generation_strings=generation_strings,
        prompt_logprobs=[],
        generation_logprobs=[])


@pytest.fixture
def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
    """Provide output processor dummy test vectors populated with logprobs.

    The vectors are first built without logprobs; dummy sample logprobs are
    then generated for every generated-token sequence, followed by dummy
    prompt logprobs for every prompt-token sequence.

    Returns:
      DummyOutputProcessorTestVectors instance with logprobs
    """
    # Start from the logprob-free vectors and fill in both logprob fields.
    vectors = _build_test_vectors_no_logprobs()

    # Sample logprobs first: one entry per generated-token sequence.
    sample_logprobs = []
    for generated_ids in vectors.generation_tokens:
        sample_logprobs.append(
            generate_dummy_sample_logprobs(
                sampled_tokens_list=generated_ids,
                num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
                tokenizer=vectors.tokenizer))
    vectors.generation_logprobs = sample_logprobs

    # Then prompt logprobs: one entry per prompt-token sequence.
    prompt_logprobs = []
    for prompt_ids in vectors.prompt_tokens:
        prompt_logprobs.append(
            generate_dummy_prompt_logprobs_tensors(
                prompt_tokens_list=prompt_ids,
                num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
                tokenizer=vectors.tokenizer))
    vectors.prompt_logprobs = prompt_logprobs

    return vectors