test_lazy_outlines.py 3.64 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import sys
youkaichao's avatar
youkaichao committed
4
from contextlib import nullcontext
5

youkaichao's avatar
youkaichao committed
6
from vllm_test_utils import BlameResult, blame
youkaichao's avatar
youkaichao committed
7

8
from vllm import LLM, SamplingParams
9
from vllm.config import LoadFormat
10
from vllm.distributed import cleanup_dist_env_and_memory
11
12


13
def run_normal_opt125m():
    """Run plain generation with facebook/opt-125m and release the engine.

    Four fixed prompts are sampled with temperature 0.8 and top_p 0.95, each
    prompt/completion pair is printed, and the engine is torn down afterwards.
    """
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Baseline engine: no guided decoding backend is configured.
    engine = LLM(
        model="facebook/opt-125m",
        enforce_eager=True,
        gpu_memory_utilization=0.3,
    )
    completions = engine.generate(
        [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ],
        sampling_params,
    )
    for completion in completions:
        prompt = completion.prompt
        generated_text = completion.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Drop the engine reference first, then free the GPU memory.
    del engine
    cleanup_dist_env_and_memory()

youkaichao's avatar
youkaichao committed
36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def run_normal():
    """Run plain generation with distilgpt2 and release the engine.

    The model is loaded from the CI S3 bucket using the RunAI streamer load
    format. Four fixed prompts are sampled with temperature 0.8 and
    top_p 0.95, each prompt/completion pair is printed, and the engine is
    torn down afterwards.
    """
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Baseline engine: no guided decoding backend is configured.
    engine = LLM(
        model="s3://vllm-ci-model-weights/distilgpt2",
        load_format=LoadFormat.RUNAI_STREAMER,
        enforce_eager=True,
        gpu_memory_utilization=0.3,
    )
    completions = engine.generate(
        [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ],
        sampling_params,
    )
    for completion in completions:
        prompt = completion.prompt
        generated_text = completion.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Drop the engine reference first, then free the GPU memory.
    del engine
    cleanup_dist_env_and_memory()


youkaichao's avatar
youkaichao committed
62
def run_lmfe(sample_regex):
    """Run regex-guided generation with the lm-format-enforcer backend.

    Args:
        sample_regex: Regex pattern that is both embedded in the prompt text
            and passed as the ``guided_regex`` option for the request.

    The same prompt is submitted twice, each prompt/completion pair is
    printed, and the engine is torn down afterwards.
    """
    # Create an LLM with guided decoding enabled.
    llm = LLM(model="s3://vllm-ci-model-weights/distilgpt2",
              load_format=LoadFormat.RUNAI_STREAMER,
              enforce_eager=True,
              guided_decoding_backend="lm-format-enforcer",
              gpu_memory_utilization=0.3)
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_regex=sample_regex))

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Destroy the LLM object and free up the GPU memory, mirroring the
    # teardown done by the baseline runners in this file so that this engine
    # does not stay resident after the function returns.
    del llm
    cleanup_dist_env_and_memory()

youkaichao's avatar
youkaichao committed
83
84
85
86

def test_lazy_outlines(sample_regex):
    """Check that outlines stays unimported when it is not the backend in use.

    Runs the baseline generation and the lm-format-enforcer guided generation,
    then asserts that ``outlines`` is absent from ``sys.modules``.
    """
    module_name = "outlines"

    # CI only inspects `sys.modules` once, at the very end. If the module does
    # turn out to be imported, flip `use_blame` to True and rerun: every
    # function call is then traced to find the first import location, which
    # points at the root cause. Tracing is slow, so it is off by default.
    use_blame = False
    if use_blame:

        def _is_imported():
            return module_name in sys.modules

        tracer = blame(_is_imported)
    else:
        tracer = nullcontext()

    with tracer as result:
        run_normal()
        run_lmfe(sample_regex)

    if use_blame:
        assert isinstance(result, BlameResult)
        print(f"the first import location is:\n{result.trace_stack}")

    assert module_name not in sys.modules, (
        f"Module {module_name} is imported. To see the first"
        f" import location, run the test with `use_blame=True`.")