# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class VllmEngine(BaseEngine):
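    """Asynchronous inference engine built on top of vLLM's AsyncLLMEngine."""
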
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.VLLM
        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
        self.generating_args = generating_args.to_dict()

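        # Assemble the vLLM engine arguments; user-supplied model_args.vllm_config entries
        # (applied below) take precedence over these defaults.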
        engine_args = {
            "model": model_args.model_name_or_path,
            "trust_remote_code": model_args.trust_remote_code,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }
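        # For multimodal templates, cap how many media items vLLM accepts per prompt.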
        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}

        if isinstance(model_args.vllm_config, dict):
            engine_args.update(model_args.vllm_config)

        if getattr(config, "is_yi_vl_derived_model", None):
            import vllm.model_executor.models.llava

            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
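        # Serve the first adapter (if any) as a LoRA request alongside the base model.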
        if model_args.adapter_name_or_path is not None:
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
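        # When the prompt text lacks media placeholders, prepend one placeholder per media item
        # to the first message so every image/video/audio input is accounted for.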
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None:
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
        )
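        # Append an empty assistant turn; encode_oneturn then returns the prompt ids to generate from.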
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if length_penalty is not None:
            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")

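        # Resolve the generation budget: defaults come from generating_args, and per-request
        # max_length / max_new_tokens override them; always allow at least one new token.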
        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must be > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must be > 0
            top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must be > 0; -1 disables it
            stop=stop,
            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
            max_tokens=max_tokens,
            skip_special_tokens=skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        )

        if images is not None:  # add image features
            multi_modal_data = {
                "image": self.template.mm_plugin._regularize_images(
                    images,
                    image_max_pixels=self.model_args.image_max_pixels,
                    image_min_pixels=self.model_args.image_min_pixels,
                )
            }
        else:
            multi_modal_data = None

        result_generator = self.model.generate(
            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
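        # Drain the async generator; the final RequestOutput holds the completed candidates.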
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
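        # vLLM reports the cumulative generated text, so yield only the newly added suffix.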
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        raise NotImplementedError("vLLM engine does not support get_scores.")