# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class VllmEngine(BaseEngine):
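    """Asynchronous inference engine backed by vLLM's AsyncLLMEngine.

    It reuses the LlamaFactory chat template and multimodal plugin to build prompts,
    optionally attaches a LoRA adapter, and serves chat / stream_chat requests.

    Minimal usage sketch (how the hparams objects are constructed depends on the caller):

        engine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        responses = await engine.chat([{"role": "user", "content": "Hello"}])
    """
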
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        config = load_config(model_args)  # may download the model from the ModelScope hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # vLLM expands multimodal tokens by itself
        self.generating_args = generating_args.to_dict()

        engine_args = {
            "model": model_args.model_name_or_path,
            "trust_remote_code": model_args.trust_remote_code,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }
        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}

        if isinstance(model_args.vllm_config, dict):
            engine_args.update(model_args.vllm_config)

        if getattr(config, "is_yi_vl_derived_model", None):
            import vllm.model_executor.models.llava

            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
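        # if adapters are provided, register the first one as the default LoRA request
        # that is applied to every generation call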
        if model_args.adapter_name_or_path is not None:
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
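        """Build the prompt with the chat template and submit it to the async vLLM engine.

        Returns the vLLM result generator, which yields incremental ``RequestOutput`` objects.
        """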
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
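        # gather multimodal inputs and make sure the prompt contains matching placeholder tokens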
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None:
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if length_penalty is not None:
            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")

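        # resolve the generation budget: start from the configured defaults, then let the
        # request-level max_length / max_new_tokens arguments override it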
        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must be > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must be > 0
            top_k=top_k if top_k is not None else self.generating_args["top_k"],
            stop=stop,
            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
            max_tokens=max_tokens,
            skip_special_tokens=skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        )

        if images is not None:  # add image features
            multi_modal_data = {
                "image": self.template.mm_plugin._regularize_images(
                    images,
                    image_max_pixels=self.model_args.image_max_pixels,
                    image_min_pixels=self.model_args.image_min_pixels,
                )
            }
        else:
            multi_modal_data = None

        result_generator = self.model.generate(
            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
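        """Generate the full response(s) for one chat request and return them once finished."""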
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
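        """Stream one chat request, yielding the newly generated text chunk at each step."""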
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        raise NotImplementedError("vLLM engine does not support get_scores.")