# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text encoder/decoder models,
specifically BART and mBART.

The model to run is selected with the ``--model`` (``-m``) command-line
argument.

NOTE: This example is not yet supported in V1.
"""

import argparse
from typing import NamedTuple, Optional

15
from vllm import LLM, SamplingParams
16
17
18
19
20
21
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
    TextPrompt,
    TokensPrompt,
    zip_enc_dec_prompts,
)
class ModelRequestData(NamedTuple):
    """Everything the demo needs in order to run against a single model.

    Bundles the HuggingFace model ID, the raw encoder and decoder prompt
    strings, and optional ``hf_overrides`` that ``main`` forwards to ``LLM``.
    """

    # HuggingFace identifier passed to ``LLM(model=...)``.
    model_id: str
    # Raw encoder-side prompt strings.
    encoder_prompts: list
    # Raw decoder-side prompt strings.
    decoder_prompts: list
    # Optional config overrides forwarded to ``LLM``; ``None`` means none.
    hf_overrides: Optional[dict] = None


def get_bart_config() -> ModelRequestData:
    """Build the demo configuration for ``facebook/bart-large-cnn``.

    Supplies four encoder prompts and two decoder prompts; ``hf_overrides``
    is left at its default.
    """
    return ModelRequestData(
        model_id="facebook/bart-large-cnn",
        encoder_prompts=[
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "An encoder prompt",
        ],
        decoder_prompts=["A decoder prompt", "Another decoder prompt"],
    )


def get_mbart_config() -> ModelRequestData:
    """Build the demo configuration for ``facebook/mbart-large-en-ro``.

    The encoder receives English sentences suited to an English-to-Romanian
    translation task, and every decoder prompt is an empty string.
    ``hf_overrides`` sets the config's ``architectures`` entry to
    ``MBartForConditionalGeneration`` (presumably so vLLM resolves its mBART
    implementation for this checkpoint -- confirm against the HF config).
    """
    english_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "How are you today?",
    ]
    return ModelRequestData(
        model_id="facebook/mbart-large-en-ro",
        encoder_prompts=english_sentences,
        # One empty decoder prompt per encoder prompt.
        decoder_prompts=[""] * len(english_sentences),
        hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
    )


# Maps each ``--model`` short name to a zero-argument factory returning that
# model's ModelRequestData. The insertion order (bart, mbart) is also the
# order in which the CLI lists its choices.
MODEL_GETTERS = dict(
    bart=get_bart_config,
    mbart=get_mbart_config,
)


def create_all_prompt_types(
    encoder_prompts_raw: list,
    decoder_prompts_raw: list,
    tokenizer,
) -> list:
    """Build a list showing every supported way to prompt an enc/dec model.

    The raw strings are turned into, in order:

    1. three prompts with no explicit decoder part: a plain string, a
       ``TextPrompt`` and a ``TokensPrompt``;
    2. three ``ExplicitEncoderDecoderPrompt`` pairs that mix those forms on
       the encoder and decoder side;
    3. the result of ``zip_enc_dec_prompts`` over the two raw lists.

    Indices are taken modulo the list length, so short lists are reused
    cyclically instead of raising ``IndexError``.

    Args:
        encoder_prompts_raw: Encoder prompt strings; must be non-empty.
        decoder_prompts_raw: Decoder prompt strings; must be non-empty.
        tokenizer: Object whose ``encode(text)`` returns token IDs (``main``
            passes the engine's tokenizer group).

    Returns:
        The concatenated list of prompts, as passed to ``LLM.generate``.

    Raises:
        ValueError: If either prompt list is empty.
    """
    # Fail fast with a clear message. Otherwise an empty list surfaces as an
    # opaque IndexError, or as a ZeroDivisionError from the modulo indexing.
    if not encoder_prompts_raw:
        raise ValueError("encoder_prompts_raw must contain at least one prompt")
    if not decoder_prompts_raw:
        raise ValueError("decoder_prompts_raw must contain at least one prompt")

    num_enc = len(encoder_prompts_raw)
    num_dec = len(decoder_prompts_raw)

    # Three forms of a prompt: raw string, TextPrompt, and pre-tokenized IDs.
    text_prompt_raw = encoder_prompts_raw[0]
    text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % num_enc])
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(encoder_prompts_raw[2 % num_enc])
    )
    decoder_tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
    )

    # Prompts given on their own, without an explicit decoder part.
    single_prompt_examples = [
        text_prompt_raw,
        text_prompt,
        tokens_prompt,
    ]

    # Explicit pairs mixing the prompt forms on each side.
    explicit_pair_examples = [
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt_raw,
            decoder_prompt=decoder_tokens_prompt,
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt,
            decoder_prompt=decoder_prompts_raw[1 % num_dec],
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=tokens_prompt,
            decoder_prompt=text_prompt,
        ),
    ]

    # NOTE(review): zip_enc_dec_prompts presumably pairs the lists
    # positionally. With unequal lengths (e.g. 4 encoder vs. 2 decoder
    # prompts for BART), the extra encoder prompts may be dropped -- confirm.
    zipped_prompt_list = zip_enc_dec_prompts(
        encoder_prompts_raw,
        decoder_prompts_raw,
    )

    return single_prompt_examples + explicit_pair_examples + zipped_prompt_list


def create_sampling_params() -> SamplingParams:
    """Return the sampling settings applied to every prompt in the demo.

    Uses ``temperature=0``, ``top_p=1.0``, ``min_tokens=0`` and a cap of
    ``max_tokens=30`` generated tokens.
    """
    settings = dict(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=30,
    )
    return SamplingParams(**settings)

def print_outputs(outputs: list):
    """Print the encoder prompt, decoder prompt and text for each result.

    Entries are numbered from 1 and framed by 80-character dashed rules.
    Only the first completion (``outputs[0]``) of each result is shown.
    """
    separator = "-" * 80
    print(separator)
    for number, request_output in enumerate(outputs, start=1):
        # Read all fields before printing, in the same order as before.
        decoder_prompt, encoder_prompt, generated = (
            request_output.prompt,
            request_output.encoder_prompt,
            request_output.outputs[0].text,
        )
        print(
            f"Output {number}:",
            f"Encoder Prompt: {encoder_prompt!r}",
            f"Decoder Prompt: {decoder_prompt!r}",
            f"Generated Text: {generated!r}",
            separator,
            sep="\n",
        )


def main(args):
    """Run the encoder/decoder demo for the model named by ``args.model``.

    Args:
        args: Parsed command-line namespace; only ``args.model`` is read.

    Raises:
        ValueError: If ``args.model`` is not a key of ``MODEL_GETTERS``.
    """
    model_key = args.model
    # Re-checked here so programmatic callers that bypass argparse's
    # ``choices`` validation still get a clear error.
    if model_key not in MODEL_GETTERS:
        known = list(MODEL_GETTERS.keys())
        raise ValueError(f"Unknown model: {model_key}. Available models: {known}")

    model_config = MODEL_GETTERS[model_key]()
    print(f"🚀 Running demo for model: {model_config.model_id}")

    llm = LLM(
        model=model_config.model_id,
        dtype="float",
        hf_overrides=model_config.hf_overrides,
    )
    # The TokensPrompt examples need token IDs, so reuse the engine's tokenizer.
    tokenizer = llm.llm_engine.get_tokenizer_group()

    prompts = create_all_prompt_types(
        encoder_prompts_raw=model_config.encoder_prompts,
        decoder_prompts_raw=model_config.decoder_prompts,
        tokenizer=tokenizer,
    )
    print_outputs(llm.generate(prompts, create_sampling_params()))


if __name__ == "__main__":
    # Command-line entry point: pick a model by its short name and run it.
    parser = argparse.ArgumentParser(
        description="A flexible demo for vLLM encoder-decoder models."
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="bart",
        choices=list(MODEL_GETTERS),
        help="The short name of the model to run.",
    )
    args = parser.parse_args()
    main(args)