# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
7

8
import os
9
import time
10
11
from collections.abc import Sequence
from typing import NamedTuple
12

13
from vllm import LLM, EngineArgs, PromptType, SamplingParams
14
from vllm.assets.audio import AudioAsset
15
from vllm.utils.argparse_utils import FlexibleArgumentParser
16
17


18
19
20
21
22
class ModelRequestData(NamedTuple):
    """Engine configuration and prompts for one example model.

    Instances are returned by the setup functions registered in
    ``model_example_map`` and consumed by ``main``.
    """

    # Arguments from which ``main`` builds the ``LLM`` engine.
    engine_args: EngineArgs
    # Prompts that ``main`` passes to ``LLM.generate``.
    prompts: Sequence[PromptType]


23
def run_whisper() -> ModelRequestData:
    """Build the engine arguments and sample prompts for the Whisper example.

    Two prompts are produced, one per supported prompt format:

    * an implicit prompt, where the text prompt and the audio are given
      together in a single dict;
    * an explicit encoder/decoder prompt, where the audio is attached to the
      encoder prompt and the text is given as the decoder prompt.

    Side effect: sets ``VLLM_WORKER_MULTIPROC_METHOD`` to ``"spawn"`` in this
    process's environment.

    Returns:
        The ``EngineArgs`` and prompt list bundled as ``ModelRequestData``.
    """
    # Must be set before the engine is created by the caller.
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    prompts = [
        {  # Test implicit prompt
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            },
        },
        {  # Test explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": AudioAsset("winning_call").audio_and_sample_rate,
                },
            },
            "decoder_prompt": "<|startoftranscript|>",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
56
57
58
59
60
61
62


# Registry of supported ``--model-type`` values, each mapped to the
# zero-argument function that builds that model's request data.
model_example_map = {"whisper": run_whisper}


63
64
def parse_args():
    """Parse the command-line options for this example.

    Options:
        --model-type / -m: Key into ``model_example_map`` selecting which
            example model to run (default ``"whisper"``).
        --seed: Integer seed applied to the engine arguments (default ``0``).

    Returns:
        The namespace produced by ``FlexibleArgumentParser.parse_args()``.
    """
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "encoder-decoder multimodal models for text generation"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="whisper",
        # A concrete list keeps the valid choices independent of the dict view.
        choices=list(model_example_map),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()


85
86
87
88
89
def main(args):
    """Run offline inference for the selected example model and print results.

    Args:
        args: Parsed command-line namespace; ``model_type`` and ``seed`` are
            read from it.

    Raises:
        ValueError: If ``args.model_type`` has no entry in
            ``model_example_map``.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory: every modality defaults to a
    # limit of 0 unless the model's own engine args override it.
    engine_args = req_data.engine_args
    default_limits = {"image": 0, "video": 0, "audio": 0}
    engine_args.limit_mm_per_prompt = default_limits | (
        engine_args.limit_mm_per_prompt or {}
    )
    engine_args.seed = args.seed
    llm = LLM.from_engine_args(engine_args)

    prompts = req_data.prompts

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    # Time only the generation call, using a monotonic clock, so the reported
    # duration and RPS are not skewed by printing or system clock adjustments.
    start = time.perf_counter()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    duration = time.perf_counter() - start

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the example.
    args = parse_args()
    main(args)