# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
"""

from vllm import LLM, SamplingParams
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
    TextPrompt,
    TokensPrompt,
    zip_enc_dec_prompts,
)


def create_prompts(tokenizer):
    # Test prompts
    #
    # This section shows all of the valid ways to prompt an
    # encoder/decoder model.
    #
    # - Helpers for building prompts
    text_prompt_raw = "Hello, my name is"
    text_prompt = TextPrompt(prompt="The president of the United States is")
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
    )
    # - Pass a single prompt to the encoder/decoder model
    #   (it is implicitly treated as the encoder input prompt);
    #   the decoder input prompt is assumed to be None
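    #   (when the decoder prompt is None, vLLM is expected to start
    #   decoding from the model's decoder start token)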

    single_text_prompt_raw = text_prompt_raw  # Pass a string directly
    single_text_prompt = text_prompt  # Pass a TextPrompt
    single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt

    # ruff: noqa: E501
    # - Pass explicit encoder and decoder input prompts within one data structure.
    #   Encoder and decoder prompts can both independently be text or tokens, with
    #   no requirement that they be the same prompt type. Some example prompt-type
    #   combinations are shown below; note that these are not exhaustive.

    enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt string directly, and
        # pass decoder prompt tokens
        encoder_prompt=single_text_prompt_raw,
        decoder_prompt=single_tokens_prompt,
    )
    enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
        # Pass TextPrompt to encoder, and
        # pass decoder prompt string directly
        encoder_prompt=single_text_prompt,
        decoder_prompt=single_text_prompt_raw,
    )
    enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt tokens directly, and
        # pass TextPrompt to decoder
        encoder_prompt=single_tokens_prompt,
        decoder_prompt=single_text_prompt,
    )
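
    # - A fourth combination (a sketch only, not added to the prompt list
    #   below): both the encoder and decoder prompts may be token prompts,
    #   reusing the tokenizer passed into this function, e.g.
    #
    #   enc_dec_prompt4 = ExplicitEncoderDecoderPrompt(
    #       encoder_prompt=TokensPrompt(
    #           prompt_token_ids=tokenizer.encode(prompt="An encoder prompt")
    #       ),
    #       decoder_prompt=tokens_prompt,
    #   )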

    # - Finally, here's a useful helper function for zipping encoder and
    #   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
    #   instances
    zipped_prompt_list = zip_enc_dec_prompts(
        ["An encoder prompt", "Another encoder prompt"],
        ["A decoder prompt", "Another decoder prompt"],
    )
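    #   For reference, zip_enc_dec_prompts(encs, decs) behaves roughly like
    #   [ExplicitEncoderDecoderPrompt(encoder_prompt=e, decoder_prompt=d)
    #    for e, d in zip(encs, decs)].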

    # - Let's put all of the above example prompts together into one list
    #   which we will pass to the encoder/decoder LLM.
    return [
        single_text_prompt_raw,
        single_text_prompt,
        single_tokens_prompt,
        enc_dec_prompt1,
        enc_dec_prompt2,
        enc_dec_prompt3,
    ] + zipped_prompt_list


# Create a sampling params object.
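# temperature=0 requests greedy decoding; max_tokens=20 keeps the demo
# outputs short.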
def create_sampling_params():
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=20,
    )


# Print the outputs.
def print_outputs(outputs):
    print("-" * 50)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        generated_text = output.outputs[0].text
        print(f"Output {i + 1}:")
        print(
            f"Encoder prompt: {encoder_prompt!r}\n"
            f"Decoder prompt: {prompt!r}\n"
            f"Generated text: {generated_text!r}"
        )
        print("-" * 50)


def main():
    dtype = "float"

    # Create a BART encoder/decoder model instance
    llm = LLM(
        model="facebook/bart-large-cnn",
        dtype=dtype,
    )

    # Get BART tokenizer
    tokenizer = llm.llm_engine.get_tokenizer_group()
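    # The tokenizer group wraps the model's Hugging Face tokenizer;
    # its encode() method returns a list of token IDs.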

    prompts = create_prompts(tokenizer)
    sampling_params = create_sampling_params()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    print_outputs(outputs)


if __name__ == "__main__":
    main()