# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
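    """Hugging Face Transformers inference engine.

    Wraps a Transformers model behind an async interface for chat completion,
    token streaming, and reward-model scoring.

    Minimal usage sketch (assumes the four argument objects are already parsed
    and that the call happens inside an async context):

        engine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        responses = await engine.chat([{"role": "user", "content": "Hello!"}])
    """
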
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.HF
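        # only SFT checkpoints can generate; other stages attach a value head and only score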
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
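        # decoder-only generation requires left padding; scoring pads on the right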
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must be loaded after fixing the tokenizer, so the vocab is resized correctly
        self.generating_args = generating_args.to_dict()
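        # make sure the calling thread has a usable event loop for the async API below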
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning_rank0_once("There is no current event loop, creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

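        # cap concurrent requests; the MAX_CONCURRENT environment variable sets the limit (default: 1)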
        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> tuple[dict[str, Any], int]:
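        """Assemble the keyword arguments for `model.generate` and return them with the prompt length."""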
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
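        # if multimodal inputs were passed without placeholders in the messages,
        # prepend the placeholders so the mm plugin can expand them into features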
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None:
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        prompt_ids, _ = template.mm_plugin.process_token_ids(
            prompt_ids,
            None,
            mm_input_dict["images"],
            mm_input_dict["videos"],
            mm_input_dict["audios"],
            tokenizer,
            processor,
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[int] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

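        # per-request sampling parameters take precedence over the engine defaults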
        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                skip_special_tokens=skip_special_tokens
                if skip_special_tokens is not None
                else generating_args["skip_special_tokens"],
                eos_token_id=template.get_stop_token_ids(tokenizer),
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

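        # a falsy temperature implies greedy decoding, which ignores the sampling parameters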
        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

        if not generating_args["do_sample"]:
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

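        # max_length and max_new_tokens are mutually exclusive; keep whichever was requested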
        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

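        # collate multimodal features into batched tensors and move them to the model device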
        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
        for key, value in mm_inputs.items():
            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
                value = torch.stack(value)  # assumes all tensors share the same shape
            elif (
                isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
            ):  # for minicpmv inputs
                value = torch.stack([torch.stack(v) for v in value])
            elif not isinstance(value, torch.Tensor):
                value = torch.tensor(value)

            if torch.is_floating_point(value):  # cast data dtype for paligemma
                value = value.to(model.dtype)

            if key == "second_per_grid_ts":  # qwen2.5vl special case
                gen_kwargs[key] = value.tolist()
            else:
                gen_kwargs[key] = value.to(model.device)

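        # MiniCPM-V and MiniCPM-o expect `input_ids` and the tokenizer to be passed explicitly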
        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
            gen_kwargs["input_ids"] = inputs
            gen_kwargs["tokenizer"] = tokenizer
            if "audio_feature_lens" in mm_inputs:
                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]

            gen_kwargs.pop("image_sizes", None)

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list["Response"]:
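        """Generate a full completion synchronously and wrap each sequence in a `Response`."""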
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        generate_output = model.generate(**gen_kwargs)
        if isinstance(generate_output, tuple):
            generate_output = generate_output[1][0]  # post-process the minicpm_o output

        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
            clean_up_tokenization_spaces=True,
        )
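        # locate the first EOS token to derive the response length and finish reason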
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> Callable[[], str]:
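        """Run generation in a background thread; the returned callable yields one decoded chunk per call."""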
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
        )
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

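        # translate the streamer's StopIteration into StopAsyncIteration so the async
        # wrapper in `stream_chat` knows when the generator is exhausted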
        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: list[str],
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list[float]:
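        """Compute reward scores for a batch of texts using the model's value head."""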
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs: dict[str, torch.Tensor] = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
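        # pick the value head output at the last non-padded token of each sequence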
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
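        """Async chat entry point; offloads the blocking `_chat` call to a worker thread."""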
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        async with self.semaphore:
            return await asyncio.to_thread(self._chat, *input_args)

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
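        """Async streaming entry point; yields decoded text chunks as they are produced."""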
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        async with self.semaphore:
            stream = self._stream_chat(*input_args)
            while True:
                try:
                    yield await asyncio.to_thread(stream)
                except StopAsyncIteration:
                    break

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
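        """Async scoring entry point for reward models."""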
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            return await asyncio.to_thread(self._get_scores, *input_args)