# conftest.py
1
2
from typing import List, Tuple

3
4
5
6
7
8
9
10
import pytest

from tests.conftest import cleanup
from vllm import LLM
from vllm.model_executor.utils import set_random_seed


@pytest.fixture
def baseline_llm_generator(request, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           seed):
    """Fixture returning the LLM generator factory for the baseline run.

    Delegates to ``create_llm_generator`` with the label ``"baseline"``.
    The three kwargs dicts are merged there, with ``baseline_llm_kwargs``
    applied last so its entries win on key collisions.

    Args:
        request: The pytest ``request`` object; its node name is used for
            logging.
        common_llm_kwargs: LLM kwargs shared by baseline and test runs
            (presumably supplied by test parametrization -- confirm).
        per_test_common_llm_kwargs: Per-test kwargs shared by both runs.
        baseline_llm_kwargs: Kwargs specific to the baseline LLM.
        seed: Value passed to ``set_random_seed`` after the LLM is built.

    Returns:
        A zero-argument callable; calling it returns a generator that
        yields the constructed LLM exactly once.
    """
    return create_llm_generator("baseline", request, common_llm_kwargs,
                                per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Fixture returning the LLM generator factory for the run under test.

    Delegates to ``create_llm_generator`` with the label ``"test"``.
    The three kwargs dicts are merged there, with ``test_llm_kwargs``
    applied last so its entries win on key collisions.

    Args:
        request: The pytest ``request`` object; its node name is used for
            logging.
        common_llm_kwargs: LLM kwargs shared by baseline and test runs
            (presumably supplied by test parametrization -- confirm).
        per_test_common_llm_kwargs: Per-test kwargs shared by both runs.
        test_llm_kwargs: Kwargs specific to the LLM under test.
        seed: Value passed to ``set_random_seed`` after the LLM is built.

    Returns:
        A zero-argument callable; calling it returns a generator that
        yields the constructed LLM exactly once.
    """
    return create_llm_generator("test", request, common_llm_kwargs,
                                per_test_common_llm_kwargs, test_llm_kwargs,
                                seed)
25
26


27
28
29
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                         per_test_common_llm_kwargs, distinct_llm_kwargs,
                         seed):
    """Build a lazy factory that creates, yields, and then tears down an LLM.

    Nothing is constructed when this function is called; the LLM is only
    created once the returned callable is invoked and its generator is
    advanced.

    Args:
        baseline_or_test: Label used only in the creation log line.
        request: The pytest ``request`` object; ``request.node.name`` is
            read for the log line.
        common_llm_kwargs: Lowest-priority LLM keyword arguments.
        per_test_common_llm_kwargs: Overrides ``common_llm_kwargs``.
        distinct_llm_kwargs: Highest-priority keyword arguments.
        seed: Passed to ``set_random_seed`` right after the LLM is built.

    Returns:
        A zero-argument callable returning a generator that yields the LLM
        exactly once. When the consumer advances the generator past that
        single item, the LLM reference is dropped and ``cleanup()`` is
        called.
    """
    # Later dicts win on duplicate keys.
    # NOTE: `kwargs` and `test_name` must keep these exact names -- the
    # f-string below uses the `{name=}` form, which prints the names.
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }
    test_name = request.node.name

    def generator_inner():
        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = LLM(**kwargs)

        set_random_seed(seed)

        yield llm
        # Drop this frame's reference before cleanup() runs (presumably so
        # the LLM's resources can actually be reclaimed -- confirm against
        # tests.conftest.cleanup).
        del llm
        cleanup()

    def generator_outer():
        for llm in generator_inner():
            yield llm
            # Release the outer frame's reference before resuming the inner
            # generator, so it does not keep the LLM alive during teardown.
            del llm

    return generator_outer


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]]]:
    """Run ``prompts`` through every LLM the generator yields.

    Args:
        llm_generator: Zero-argument callable returning an iterable of
            objects exposing ``generate(prompts, sampling_params,
            use_tqdm=True)``.
        prompts: Forwarded unchanged to ``generate``.
        sampling_params: Forwarded unchanged to ``generate``.

    Returns:
        A ``(tokens, token_ids)`` pair holding the ``text`` and
        ``token_ids`` of the first completion of each request output. If
        the generator yields more than one LLM, only the last one's results
        are kept; if it yields none, both lists are empty.
    """
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        # Only the first completion of each request is inspected.
        first_completions = [output.outputs[0] for output in outputs]
        token_ids = [completion.token_ids for completion in first_completions]
        tokens = [completion.text for completion in first_completions]
        # Drop the local reference before the generator resumes so the
        # producer's teardown is not blocked by this frame holding the LLM.
        del llm

    return tokens, token_ids