# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART.

Select the model to run with the ``--model`` / ``-m`` command-line option
(``bart`` or ``mbart``).
"""

import argparse
from typing import NamedTuple, Optional

13
from vllm import LLM, SamplingParams
14
15
16
17
18
19
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
    TextPrompt,
    TokensPrompt,
    zip_enc_dec_prompts,
)


class ModelRequestData(NamedTuple):
    """Demo configuration for one encoder/decoder model.

    Bundles the HuggingFace model ID with the raw encoder-side and
    decoder-side prompt strings used by the demo, plus optional
    HuggingFace config overrides.
    """

    # HuggingFace model identifier, e.g. "facebook/bart-large-cnn".
    model_id: str
    # Raw prompt strings for the encoder side.
    encoder_prompts: list[str]
    # Raw prompt strings for the decoder side (may be empty strings).
    decoder_prompts: list[str]
    # Optional HF config overrides; None means "no overrides".
    hf_overrides: Optional[dict] = None

def get_bart_config() -> ModelRequestData:
    """Return the demo configuration for ``facebook/bart-large-cnn``.

    Supplies four encoder prompts and two decoder prompts. BART needs no
    HuggingFace config overrides, so ``hf_overrides`` keeps its default.
    """
    encoder_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "An encoder prompt",
    ]
    decoder_prompts = [
        "A decoder prompt",
        "Another decoder prompt",
    ]
    return ModelRequestData(
        model_id="facebook/bart-large-cnn",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
    )


def get_mbart_config() -> ModelRequestData:
    """Return the demo configuration for ``facebook/mbart-large-en-ro``.

    The encoder prompts are English sentences, suitable for the checkpoint's
    English-to-Romanian translation task. Both decoder prompts are empty
    strings, so the model gets no user-written decoder prefix.
    """
    encoder_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "How are you today?",
    ]
    # One (empty) decoder prompt per encoder prompt.
    decoder_prompts = ["", ""]
    # Force the model architecture used to load this checkpoint.
    # NOTE(review): presumably required because the checkpoint's own HF
    # config does not select MBartForConditionalGeneration — confirm.
    hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
    return ModelRequestData(
        model_id="facebook/mbart-large-en-ro",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
        hf_overrides=hf_overrides,
    )


# Registry mapping each short model name accepted by ``--model`` to the
# zero-argument function that builds its ModelRequestData.
MODEL_GETTERS = dict(
    bart=get_bart_config,
    mbart=get_mbart_config,
)


def create_all_prompt_types(
    encoder_prompts_raw: list,
    decoder_prompts_raw: list,
    tokenizer,
) -> list:
    """Build a mixed list of vLLM prompt objects for the demo.

    The result exercises several prompt forms, in this order:

    1. Single (encoder-only) prompts: a raw string, a ``TextPrompt`` and a
       ``TokensPrompt``.
    2. Three ``ExplicitEncoderDecoderPrompt`` pairs that mix those forms on
       the encoder and decoder sides.
    3. The output of ``zip_enc_dec_prompts`` over the two raw lists.

    Indexes into the raw lists wrap around with ``%``, so lists shorter than
    the number of examples are reused instead of raising ``IndexError``.

    Args:
        encoder_prompts_raw: Non-empty list of encoder prompt strings.
        decoder_prompts_raw: Non-empty list of decoder prompt strings.
        tokenizer: Object whose ``encode(text)`` returns token IDs; used to
            build the ``TokensPrompt`` examples.

    Returns:
        A list of prompts to pass to ``LLM.generate``.

    Raises:
        ValueError: If either prompt list is empty (the ``[0]`` and
            ``% len(...)`` lookups below would otherwise fail with an
            unhelpful IndexError / ZeroDivisionError).
    """
    if not encoder_prompts_raw or not decoder_prompts_raw:
        raise ValueError(
            "encoder_prompts_raw and decoder_prompts_raw must both be non-empty"
        )

    num_enc = len(encoder_prompts_raw)
    num_dec = len(decoder_prompts_raw)

    text_prompt_raw = encoder_prompts_raw[0]
    text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % num_enc])
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(encoder_prompts_raw[2 % num_enc])
    )
    decoder_tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
    )

    # Encoder-only inputs: no decoder prompt is given explicitly.
    single_prompt_examples = [
        text_prompt_raw,
        text_prompt,
        tokens_prompt,
    ]
    # Explicit encoder/decoder pairs mixing str, TextPrompt and TokensPrompt.
    explicit_pair_examples = [
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt_raw,
            decoder_prompt=decoder_tokens_prompt,
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt,
            decoder_prompt=decoder_prompts_raw[1 % num_dec],
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=tokens_prompt,
            decoder_prompt=text_prompt,
        ),
    ]
    # NOTE(review): if zip_enc_dec_prompts pairs the lists with zip()
    # semantics, encoder prompts beyond len(decoder_prompts_raw) are
    # silently dropped here — confirm that is intended.
    zipped_prompt_list = zip_enc_dec_prompts(
        encoder_prompts_raw,
        decoder_prompts_raw,
    )

    return single_prompt_examples + explicit_pair_examples + zipped_prompt_list


def create_sampling_params() -> SamplingParams:
    """Create the sampling parameters shared by every demo prompt.

    ``temperature=0`` requests greedy decoding (per vLLM's SamplingParams
    documentation), ``top_p=1.0`` disables nucleus truncation, there is no
    minimum length, and each completion is capped at 30 generated tokens.
    """
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=30,
    )


def print_outputs(outputs: list):
    """Print each generation result with its encoder and decoder prompts.

    Args:
        outputs: Results from ``LLM.generate``. For each item this reads
            ``prompt`` (printed as the decoder prompt), ``encoder_prompt``,
            and ``outputs[0].text`` (the first completion's text).
    """
    separator = "-" * 80
    print(separator)
    for index, output in enumerate(outputs, start=1):
        # Read every field before printing so a malformed result fails
        # before any of its lines are emitted.
        decoder_prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        generated_text = output.outputs[0].text
        print(f"Output {index}:")
        print(f"Encoder Prompt: {encoder_prompt!r}")
        print(f"Decoder Prompt: {decoder_prompt!r}")
        print(f"Generated Text: {generated_text!r}")
        print(separator)


def main(args):
    """Run the encoder/decoder demo for the model named in ``args.model``.

    Args:
        args: Parsed command-line namespace; only ``args.model`` is read and
            it must be a key of ``MODEL_GETTERS``.

    Raises:
        ValueError: If ``args.model`` is not a known model key. The CLI's
            argparse ``choices`` already rejects unknown names; this guard
            covers programmatic callers.
    """
    model_key = args.model
    try:
        config_getter = MODEL_GETTERS[model_key]
    except KeyError:
        raise ValueError(
            f"Unknown model: {model_key}. "
            f"Available models: {list(MODEL_GETTERS.keys())}"
        ) from None
    model_config = config_getter()

    print(f"🚀 Running demo for model: {model_config.model_id}")
    llm = LLM(
        model=model_config.model_id,
        # NOTE: "float" is passed straight to vLLM's dtype option
        # (presumably full FP32 precision — see vLLM's dtype docs).
        dtype="float",
        hf_overrides=model_config.hf_overrides,
    )
    # The engine's tokenizer is used to build the TokensPrompt examples.
    tokenizer = llm.llm_engine.get_tokenizer_group()

    prompts = create_all_prompt_types(
        encoder_prompts_raw=model_config.encoder_prompts,
        decoder_prompts_raw=model_config.decoder_prompts,
        tokenizer=tokenizer,
    )

    sampling_params = create_sampling_params()
    outputs = llm.generate(prompts, sampling_params)
    print_outputs(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A flexible demo for vLLM encoder-decoder models."
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="bart",
        # Accept exactly the names registered in MODEL_GETTERS.
        choices=list(MODEL_GETTERS),
        help="The short name of the model to run.",
    )
    args = parser.parse_args()
    main(args)