# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.HF
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must be loaded after fixing the tokenizer so the resized vocab is used
        self.generating_args = generating_args.to_dict()
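
        # the engine may be constructed in a thread without a running event
        # loop; make sure one exists before any asyncio primitives are created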
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning_rank0_once("No current event loop found; creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

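        # gate concurrent requests: MAX_CONCURRENT (env var, default "1") caps
        # how many chat/score calls may use the model at the same time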
        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> tuple[dict[str, Any], int]:
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
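
        # if media inputs arrive without matching placeholder tokens in the
        # messages, prepend one placeholder per item to the first message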
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None:
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        prompt_ids, _ = template.mm_plugin.process_token_ids(
            prompt_ids,
            None,
            mm_input_dict["images"],
            mm_input_dict["videos"],
            mm_input_dict["audios"],
            tokenizer,
            processor,
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
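        # the batch holds a single sequence, so there is no padding to mask out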
        attention_mask = torch.ones_like(inputs, dtype=torch.long)

        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                skip_special_tokens=skip_special_tokens
                if skip_special_tokens is not None
                else generating_args["skip_special_tokens"],
                eos_token_id=template.get_stop_token_ids(tokenizer),
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

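        # greedy decoding: drop sampling-only knobs so transformers does not
        # warn about unused generation flags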
        if not generating_args["do_sample"]:
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
        )

        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
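        # normalize multimodal features into tensors on the model device;
        # pixtral, minicpmv, paligemma and qwen2.5vl each need special casing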
        for key, value in mm_inputs.items():
            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
                value = torch.stack(value)  # assumes all tensors share the same shape
            elif (
                isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
            ):  # for minicpmv inputs
                value = torch.stack([torch.stack(v) for v in value])
            elif not isinstance(value, torch.Tensor):
                value = torch.tensor(value)

            if torch.is_floating_point(value):  # cast data dtype for paligemma
                value = value.to(model.dtype)

            if key == "second_per_grid_ts":  # qwen2.5vl special case
                gen_kwargs[key] = value.tolist()
            else:
                gen_kwargs[key] = value.to(model.device)

        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
            gen_kwargs["input_ids"] = inputs
            gen_kwargs["tokenizer"] = tokenizer
            if "audio_feature_lens" in mm_inputs:
                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]

            gen_kwargs.pop("image_sizes", None)

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        generate_output = model.generate(**gen_kwargs)
        if isinstance(generate_output, tuple):
            generate_output = generate_output[1][0]  # post-process the minicpm_o output

        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
            clean_up_tokenization_spaces=True,
        )
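        # the first EOS token (if any) marks the true response length and
        # decides the finish reason: "stop" if found, otherwise "length"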
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
        )
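        # generation runs on a daemon thread while the streamer yields decoded
        # chunks; stream() maps the streamer's StopIteration onto
        # StopAsyncIteration so the async wrapper in stream_chat() knows to stop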
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: list[str],
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list[float]:
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs: dict[str, torch.Tensor] = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
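        # take the value-head output at the last non-padded position of each
        # sequence (index = attention_mask.sum() - 1, with right padding)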
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        async with self.semaphore:
            return await asyncio.to_thread(self._chat, *input_args)

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        async with self.semaphore:
            stream = self._stream_chat(*input_args)
            while True:
                try:
                    yield await asyncio.to_thread(stream)
                except StopAsyncIteration:
                    break

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
chenych's avatar
chenych committed
412
            return await asyncio.to_thread(self._get_scores, *input_args)
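

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes
# LlamaFactory's `get_infer_args` helper and a local model path; the argument
# values below are hypothetical placeholders.
#
#     import asyncio
#     from llamafactory.hparams import get_infer_args
#
#     engine = HuggingfaceEngine(
#         *get_infer_args({"model_name_or_path": "path/to/model", "template": "default"})
#     )
#
#     async def main() -> None:
#         responses = await engine.chat([{"role": "user", "content": "Hello!"}])
#         print(responses[0].response_text)
#
#     asyncio.run(main())
# ---------------------------------------------------------------------------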