# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import sys
from contextlib import nullcontext

import pytest
from vllm_test_utils import BlameResult, blame

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

from ...utils import models_path_prefix


@pytest.fixture(autouse=True, scope="function")
def use_v0_only(monkeypatch):
    """Run every test in this module on the V0 engine.

    The V1 engine only supports the xgrammar guided-decoding backend, so the
    checks in this module are irrelevant there; ``VLLM_USE_V1`` is therefore
    pinned to ``"0"`` for the duration of each test.
    """
    monkeypatch.setenv(name="VLLM_USE_V1", value="0")


def run_normal_opt125m():
    """Generate from facebook/opt-125m without guided decoding.

    This is a baseline run: four fixed prompts are sampled normally and each
    prompt/completion pair is printed. The engine is torn down in a
    ``finally`` block so its GPU memory is released even if generation fails.
    """
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Create an LLM without guided decoding as a baseline.
    llm = LLM(model=os.path.join(models_path_prefix, "facebook/opt-125m"),
              enforce_eager=True,
              gpu_memory_utilization=0.3)
    try:
        outputs = llm.generate(prompts, sampling_params)
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    finally:
        # Destroy the LLM object and free up the GPU memory. Doing it in
        # `finally` keeps a failed generation from leaking the engine.
        del llm
        cleanup_dist_env_and_memory()


def run_normal():
    """Generate from distilgpt2 without guided decoding.

    This is a baseline run: four fixed prompts are sampled normally and each
    prompt/completion pair is printed. The model is resolved through
    ``models_path_prefix``, the same way the other ``run_*`` helpers resolve
    theirs. The engine is torn down in a ``finally`` block so its GPU memory
    is released even if generation fails.
    """
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Create an LLM without guided decoding as a baseline.
    llm = LLM(model=os.path.join(models_path_prefix, "distilbert/distilgpt2"),
              enforce_eager=True,
              gpu_memory_utilization=0.3)
    try:
        outputs = llm.generate(prompts, sampling_params)
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    finally:
        # Destroy the LLM object and free up the GPU memory. Doing it in
        # `finally` keeps a failed generation from leaking the engine.
        del llm
        cleanup_dist_env_and_memory()


def run_lmfe(sample_regex):
    """Run regex-guided decoding with the lm-format-enforcer backend.

    Args:
        sample_regex: Regex that constrains the generated text. It is also
            embedded in the prompt; two identical prompts are generated.

    Like the other ``run_*`` helpers, this tears the engine down afterwards
    (in a ``finally`` block) so its GPU memory is not left allocated.
    """
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    prompts = [
        f"Give an example IPv4 address with this regex: {sample_regex}"
    ] * 2

    # Create an LLM with guided decoding enabled.
    llm = LLM(model=os.path.join(models_path_prefix, "distilbert/distilgpt2"),
              enforce_eager=True,
              guided_decoding_backend="lm-format-enforcer",
              gpu_memory_utilization=0.3)
    try:
        outputs = llm.generate(
            prompts=prompts,
            sampling_params=sampling_params,
            use_tqdm=True,
            guided_options_request=dict(guided_regex=sample_regex))

        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    finally:
        # Destroy the LLM object and free up the GPU memory; without this the
        # guided-decoding engine would stay alive after the helper returns.
        del llm
        cleanup_dist_env_and_memory()


def test_lazy_outlines(sample_regex):
    """Outlines must not be imported unless it is actually used.

    This runs plain (unguided) generation and then guided decoding with the
    lm-format-enforcer backend. Afterwards it asserts that neither run caused
    the ``outlines`` module to be imported.
    """
    # make sure outlines is not imported
    module_name = "outlines"
    # In CI we only check, at the end, whether the module was imported.
    # If it was, rerun the test with `use_blame=True`. That traces every
    # function call to find the first import location, which helps pin down
    # the root cause. It is off by default because tracing is slow.
    use_blame = False
    context = blame(
        lambda: module_name in sys.modules) if use_blame else nullcontext()
    with context as result:
        run_normal()
        run_lmfe(sample_regex)

    if use_blame:
        assert isinstance(result, BlameResult)
        print(f"the first import location is:\n{result.trace_stack}")
    assert module_name not in sys.modules, (
        f"Module {module_name} is imported. To see the first"
        f" import location, run the test with `use_blame=True`.")