# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Any, Optional, Union

from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class VllmEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
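        # resolve dtype and tokenizer/template settings, then configure and launch the async vLLM engine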
        self.name = EngineName.VLLM
        self.model_args = model_args
        config = load_config(model_args)  # may download the model from the ModelScope hub
        if getattr(config, "quantization_config", None):  # GPTQ models should use float16
            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # leave multimodal token expansion to vLLM
        self.generating_args = generating_args.to_dict()

        engine_args = {
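            # baseline settings derived from ModelArguments; model_args.vllm_config (applied below) may override them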
            "model": model_args.model_name_or_path,
            "trust_remote_code": model_args.trust_remote_code,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }
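        # templates with a multimodal plugin must declare how many media items one prompt may contain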
        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

        if isinstance(model_args.vllm_config, dict):
            engine_args.update(model_args.vllm_config)

        if getattr(config, "is_yi_vl_derived_model", None):
            import vllm.model_executor.models.llava

            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
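        # register a single LoRA adapter (id 1) when an adapter path is given; otherwise run the base model only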
        if model_args.adapter_name_or_path is not None:
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, images or [], videos or [], audios or [], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

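        # per-request overrides popped from input_kwargs; missing values fall back to generating_args below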
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if length_penalty is not None:
            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")

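        # resolve max_tokens: generating_args provide the default, then request-level max_length / max_new_tokens take precedence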
        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must be > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must be > 0
            top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must be > 0
            stop=stop,
            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
            max_tokens=max_tokens,
            skip_special_tokens=skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        )

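        # regularize raw media inputs into the structures vLLM expects under multi_modal_data (one modality per request)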
        if images is not None:  # add image features
            multi_modal_data = {
                "image": self.template.mm_plugin._regularize_images(
                    images,
                    image_max_pixels=self.model_args.image_max_pixels,
                    image_min_pixels=self.model_args.image_min_pixels,
                )["images"]
            }
        elif videos is not None:
            multi_modal_data = {
                "video": self.template.mm_plugin._regularize_videos(
                    videos,
                    image_max_pixels=self.model_args.video_max_pixels,
                    image_min_pixels=self.model_args.video_min_pixels,
                    video_fps=self.model_args.video_fps,
                    video_maxlen=self.model_args.video_maxlen,
                )["videos"]
            }
        elif audios is not None:
            audio_data = self.template.mm_plugin._regularize_audios(
                audios,
                sampling_rate=self.model_args.audio_sampling_rate,
            )
            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
        else:
            multi_modal_data = None

        result_generator = self.model.generate(
            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
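        # each RequestOutput from vLLM is cumulative, so only the final one is kept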
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
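        # outputs carry the full text generated so far, so yield only the newly added suffix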
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        raise NotImplementedError("vLLM engine does not support `get_scores`.")
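

# Rough usage sketch (illustrative only, not part of the module API): the four hparams dataclasses are
# typically produced by LlamaFactory's argument parser and passed in by the chat wrapper, e.g.
#
#     import asyncio
#
#     engine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
#     responses = asyncio.run(engine.chat([{"role": "user", "content": "Hello!"}]))
#     print(responses[0].response_text)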