# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
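    """Inference engine backed by a local HuggingFace Transformers model.

    Supports chat completion, streaming chat, and sequence scoring with reward models.
    """
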
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.HF
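        # only the SFT stage yields a generative model; other stages attach a value head for scoring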
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
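        # decoder-only models require left padding for batched generation; scoring pads on the right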
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must be loaded after fixing the tokenizer so the vocabulary is resized
        self.generating_args = generating_args.to_dict()
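        # make sure the current thread has an event loop (the engine may be built outside of asyncio)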
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning_rank0_once("There is no current event loop, creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))  # cap concurrent requests

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> tuple[dict[str, Any], int]:
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
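        # when the prompt lacks media placeholders, prepend one per input so the plugin can expand them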
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None:
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
        )
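        # add an empty assistant turn so encode_oneturn produces the complete generation prompt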
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        prompt_ids, _ = template.mm_plugin.process_token_ids(
            prompt_ids,
            None,
            mm_input_dict["images"],
            mm_input_dict["videos"],
            mm_input_dict["audios"],
            tokenizer,
            processor,
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

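        # pop per-request sampling overrides; any unrecognized kwargs are dropped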
        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                skip_special_tokens=skip_special_tokens
                if skip_special_tokens is not None
                else generating_args["skip_special_tokens"],
                eos_token_id=template.get_stop_token_ids(tokenizer),
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

        if not generating_args["do_sample"]:
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

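        # max_length and max_new_tokens are mutually exclusive; an explicit max_new_tokens wins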
        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
        )

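        # batch the multimodal features into tensors and move them onto the model device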
        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
        for key, value in mm_inputs.items():
            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
                value = torch.stack(value)  # assume all tensors share the same shape
            elif (
                isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
            ):  # for minicpmv inputs
                value = torch.stack([torch.stack(v) for v in value])
            elif not isinstance(value, torch.Tensor):
                value = torch.tensor(value)

            if torch.is_floating_point(value):  # cast data dtype for paligemma
                value = value.to(model.dtype)

            if key == "second_per_grid_ts":  # qwen2.5vl special case
                gen_kwargs[key] = value.tolist()
            else:
                gen_kwargs[key] = value.to(model.device)

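        # minicpmv/minicpmo models expect input_ids and the tokenizer to be passed explicitly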
        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
            gen_kwargs["input_ids"] = inputs
            gen_kwargs["tokenizer"] = tokenizer
            if "audio_feature_lens" in mm_inputs:
                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]

            gen_kwargs.pop("image_sizes", None)

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        generate_output = model.generate(**gen_kwargs)
        if isinstance(generate_output, tuple):
            generate_output = generate_output[1][0]  # post-process the minicpm_o output

        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
            clean_up_tokenization_spaces=True,
        )
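
        # locate the first EOS token in each sample to derive response length and finish reason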
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: dict[str, Any],
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
        )
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

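        # generation runs in a daemon thread; each stream() call pops the next decoded chunk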
        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: list[str],
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list[float]:
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs: dict[str, torch.Tensor] = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
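        # gather the value-head output at the last non-padding position of each sequence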
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
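        # the semaphore bounds concurrency; blocking generation runs in a worker thread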
        async with self.semaphore:
            return await asyncio.to_thread(self._chat, *input_args)

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
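        # stream chunks from the worker thread one at a time without blocking the event loop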
        async with self.semaphore:
            stream = self._stream_chat(*input_args)
            while True:
                try:
                    yield await asyncio.to_thread(stream)
                except StopAsyncIteration:
                    break

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            return await asyncio.to_thread(self._get_scores, *input_args)