vllm_engine.py 10.6 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

luopl's avatar
luopl committed
18
19
from typing_extensions import override

chenych's avatar
chenych committed
20
from ..data import get_template_and_fix_tokenizer
luopl's avatar
luopl committed
21
from ..extras import logging
luopl's avatar
luopl committed
22
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
chenych's avatar
chenych committed
23
from ..extras.misc import get_device_count
luopl's avatar
luopl committed
24
from ..extras.packages import is_pillow_available, is_vllm_available
chenych's avatar
chenych committed
25
26
27
28
29
30
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response


luopl's avatar
luopl committed
31
32
33
34
35
if is_pillow_available():
    from PIL import Image
    from PIL.Image import Image as ImageObject


chenych's avatar
chenych committed
36
37
38
39
40
41
if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest


if TYPE_CHECKING:
luopl's avatar
luopl committed
42
    from ..data.mm_plugin import ImageInput, VideoInput
chenych's avatar
chenych committed
43
44
45
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


luopl's avatar
luopl committed
46
# Module-level logger from the project's ``..extras.logging`` helper; the engine below
# uses its rank-0 variants (``info_rank0`` / ``warning_rank0``).
logger = logging.get_logger(__name__)
chenych's avatar
chenych committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68


class VllmEngine(BaseEngine):
    r"""Chat engine that serves generation through vLLM's ``AsyncLLMEngine``.

    Prompts are rendered with the LlamaFactory chat template, images are attached as vLLM
    multi-modal data, and when ``model_args.adapter_name_or_path`` is set the first adapter
    is applied to every request as a LoRA.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        r"""Load config, tokenizer and template, then build the vLLM engine.

        Args:
            model_args: model location, dtype and the ``vllm_*`` engine settings.
            data_args: used to select and patch the chat template.
            finetuning_args: only ``stage`` is read; ``can_generate`` is True for ``"sft"``.
            generating_args: default sampling parameters, stored as a dict.
        """
        config = load_config(model_args)  # may download model from ms hub
        quantization_config: Optional[Dict[str, Any]] = getattr(config, "quantization_config", None)
        if quantization_config:  # gptq models should use float16
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
        self.generating_args = generating_args.to_dict()

        engine_args = {
            "model": model_args.model_name_or_path,
            "trust_remote_code": model_args.trust_remote_code,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }
        # Templates with a non-base multi-modal plugin get a per-prompt image/video cap.
        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}

        # User-supplied vLLM options are applied last, so they override the defaults above.
        if isinstance(model_args.vllm_config, dict):
            engine_args.update(model_args.vllm_config)

        if getattr(config, "is_yi_vl_derived_model", None):
            import vllm.model_executor.models.llava

            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
        if model_args.adapter_name_or_path is not None:
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        r"""Encode the conversation and submit a single request to the vLLM engine.

        Args:
            messages: chat history; the caller's message dicts are not modified.
            system: system prompt, defaulting to ``generating_args["default_system"]``.
            tools: tool description passed to the template.
            images: image paths or ``PIL.Image`` objects.
            videos: video inputs (only used for prompt placeholders here).
            **input_kwargs: per-request overrides (temperature, top_p, top_k,
                num_return_sequences, repetition_penalty, length_penalty, max_length,
                max_new_tokens, stop).

        Returns:
            The async iterator of ``RequestOutput`` produced by ``AsyncLLMEngine.generate``.

        Raises:
            ValueError: if an image is neither a path string nor a ``PIL.Image``.
        """
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
        # Shallow-copy each message so that placeholder injection below never mutates caller data.
        messages = [dict(message) for message in messages]
        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[int] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if length_penalty is not None:
            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")

        # Token budget: configured defaults first, then per-request overrides (max_new_tokens wins).
        # Start from None so a config with neither limit no longer raises UnboundLocalError.
        max_tokens: Optional[int] = None
        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            # at least one new token even when the prompt already fills max_length
            max_tokens = max(self.generating_args["max_length"] - prompt_length, 1)

        if max_length:
            max_tokens = max(max_length - prompt_length, 1)

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            top_k=top_k if top_k is not None else self.generating_args["top_k"],
            stop=stop,
            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
            max_tokens=max_tokens,
            skip_special_tokens=self.generating_args["skip_special_tokens"],
        )

        # NOTE(review): videos receive prompt placeholders above but are not forwarded to vLLM
        # as multi-modal data here — confirm whether video input is meant to be supported.
        if images is not None:  # add image features
            multi_modal_data = {"image": []}
            for image in images:
                if not isinstance(image, (str, ImageObject)):
                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

                if isinstance(image, str):
                    # convert() returns a loaded copy, so the file handle can be closed right away
                    with Image.open(image) as opened:
                        image = opened.convert("RGB")

                multi_modal_data["image"].append(image)
        else:
            multi_modal_data = None

        result_generator = self.model.generate(
            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""Run generation to completion and return one ``Response`` per returned sequence."""
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
        # vLLM yields progressively updated outputs; only the last one is kept.
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""Yield the newly generated text of the first sequence as it grows."""
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
        async for result in generator:
            # Each result carries the full text so far; emit only the new suffix.
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""Not supported by the vLLM engine; always raises ``NotImplementedError``."""
        raise NotImplementedError("vLLM engine does not support get_scores.")