# SPDX-License-Identifier: Apache-2.0
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART.
"""

from vllm import LLM, SamplingParams
8
9
10
11
12
13
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
    TextPrompt,
    TokensPrompt,
    zip_enc_dec_prompts,
)
def create_prompts(tokenizer):
    """Build example prompts covering each encoder/decoder prompt form.

    Args:
        tokenizer: Object used only to pre-tokenize one example prompt via
            ``tokenizer.encode(prompt=...)``; the result must be a sequence
            of token ids acceptable to ``TokensPrompt``.

    Returns:
        A list that mixes raw strings, ``TextPrompt``, ``TokensPrompt`` and
        ``ExplicitEncoderDecoderPrompt`` instances, ready to be passed to
        ``LLM.generate``.
    """
    # Test prompts
    #
    # This section shows all of the valid ways to prompt an
    # encoder/decoder model.
    #
    # - Helpers for building prompts
    text_prompt_raw = "Hello, my name is"
    text_prompt = TextPrompt(prompt="The president of the United States is")
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
    )

    # - Pass a single prompt to encoder/decoder model
    #   (implicitly encoder input prompt);
    #   decoder input prompt is assumed to be None

    single_text_prompt_raw = text_prompt_raw  # Pass a string directly
    single_text_prompt = text_prompt  # Pass a TextPrompt
    single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt

    # ruff: noqa: E501
    # - Pass explicit encoder and decoder input prompts within one data structure.
    #   Encoder and decoder prompts can both independently be text or tokens, with
    #   no requirement that they be the same prompt type. Some example prompt-type
    #   combinations are shown below; note that these are not exhaustive.

    enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt string directly, &
        # pass decoder prompt tokens
        encoder_prompt=single_text_prompt_raw,
        decoder_prompt=single_tokens_prompt,
    )
    enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
        # Pass TextPrompt to encoder, and
        # pass decoder prompt string directly
        encoder_prompt=single_text_prompt,
        decoder_prompt=single_text_prompt_raw,
    )
    enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt tokens directly, and
        # pass TextPrompt to decoder
        encoder_prompt=single_tokens_prompt,
        decoder_prompt=single_text_prompt,
    )

    # - Finally, here's a useful helper function for zipping encoder and
    #   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
    #   instances
    zipped_prompt_list = zip_enc_dec_prompts(
        ["An encoder prompt", "Another encoder prompt"],
        ["A decoder prompt", "Another decoder prompt"],
    )

    # - Let's put all of the above example prompts together into one list
    #   which we will pass to the encoder/decoder LLM.
    return [
        single_text_prompt_raw,
        single_text_prompt,
        single_tokens_prompt,
        enc_dec_prompt1,
        enc_dec_prompt2,
        enc_dec_prompt3,
    ] + zipped_prompt_list

def create_sampling_params():
    """Return the sampling parameters shared by every request in this demo.

    ``temperature=0`` makes decoding greedy, so the example output is
    deterministic. ``top_p=1.0`` leaves the full distribution in play.
    Each completion may stop on its own (``min_tokens=0``) and is capped
    at 20 new tokens.
    """
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=20,
    )

def print_outputs(outputs):
    """Pretty-print the results of an encoder/decoder generation run.

    Args:
        outputs: Iterable of request outputs (e.g. as returned by
            ``LLM.generate``). Each item must expose ``prompt`` (printed as
            the decoder prompt), ``encoder_prompt``, and ``outputs``, whose
            first element carries the generated ``text``.
    """
    print("-" * 50)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        # Only the first completion of each request is shown.
        generated_text = output.outputs[0].text
        print(f"Output {i + 1}:")
        print(
            f"Encoder prompt: {encoder_prompt!r}\n"
            f"Decoder prompt: {prompt!r}\n"
            f"Generated text: {generated_text!r}"
        )
        print("-" * 50)


def main():
    """Run the BART encoder/decoder prompting example end to end.

    Loads ``facebook/bart-large-cnn``, builds the example prompts and the
    sampling parameters, generates completions, and prints them.
    """
    # "float" asks vLLM to run the model in FP32 precision.
    bart = LLM(model="facebook/bart-large-cnn", dtype="float")

    # create_prompts() uses the tokenizer group to pre-tokenize one of
    # its example prompts.
    example_prompts = create_prompts(bart.llm_engine.get_tokenizer_group())

    # generate() yields one RequestOutput per prompt, holding the
    # prompt(s), the generated text and further metadata.
    results = bart.generate(example_prompts, create_sampling_params())
    print_outputs(results)


if __name__ == "__main__":
    main()