# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


# module-level logger from the project's logging wrapper (provides the *_rank0 helpers used below)
logger = logging.get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
    r"""Inference engine that runs a HuggingFace ``transformers`` model in-process.

    When ``finetuning_args.stage == "sft"`` the model is loaded for generation and ``chat`` /
    ``stream_chat`` are available; otherwise it is loaded with a value head and only ``get_scores``
    is available. Blocking model calls are dispatched to worker threads, and the number of requests
    served concurrently is bounded by the ``MAX_CONCURRENT`` environment variable (default ``1``).
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        # right padding is required by `_get_scores`, which locates the last real token of each row
        # via `attention_mask.sum() - 1`
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must after fixing tokenizer to resize vocab
        self.generating_args = generating_args.to_dict()
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning_rank0_once("There is no current event loop, creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], int]:
        r"""Encode a chat request into keyword arguments for ``model.generate``.

        Args:
            model: model whose ``device`` / ``dtype`` the inputs are moved / cast to.
            tokenizer: tokenizer used to encode the prompt.
            processor: optional multimodal processor forwarded to the template's mm plugin.
            template: chat template used to format and encode the conversation.
            generating_args: default generation settings; copied, never modified in place.
            messages: conversation so far; an empty assistant turn is appended internally.
                The caller's message dicts are not modified.
            system: system prompt; falls back to ``generating_args["default_system"]``.
            tools: tool description passed to the template.
            images, videos, audios: optional media. If no message already contains the matching
                placeholder, one placeholder per item is prepended to the first message.
            input_kwargs: per-request overrides (``do_sample``, ``temperature``, ``top_p``, ``top_k``,
                ``num_return_sequences``, ``repetition_penalty``, ``length_penalty``,
                ``skip_special_tokens``, ``max_length``, ``max_new_tokens``, ``stop``). Recognized keys
                are popped from the dict passed in; ``stop`` is unsupported and only logs a warning.

        Returns:
            ``(gen_kwargs, prompt_length)`` where ``prompt_length`` is the number of prompt token ids.
        """
        if input_kwargs is None:
            input_kwargs = {}

        # shallow-copy each message so placeholder insertion below does not modify the caller's dicts
        messages = [dict(message) for message in messages]
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None:
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
        )
        # the empty assistant turn makes the template emit the generation prompt
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        prompt_ids, _ = template.mm_plugin.process_token_ids(
            prompt_ids,
            None,
            mm_input_dict["images"],
            mm_input_dict["videos"],
            mm_input_dict["audios"],
            tokenizer,
            processor,
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[int] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                skip_special_tokens=skip_special_tokens
                if skip_special_tokens is not None
                else generating_args["skip_special_tokens"],
                eos_token_id=template.get_stop_token_ids(tokenizer),
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

        if not generating_args["do_sample"]:
            # sampling-only parameters are dropped when decoding greedily
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

        # when both are given, `max_new_tokens` wins because it is applied last
        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
        for key, value in mm_inputs.items():
            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
                value = torch.stack(value)  # assume they have same sizes
            elif (
                isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
            ):  # for minicpmv inputs
                value = torch.stack([torch.stack(v) for v in value])
            elif not isinstance(value, torch.Tensor):
                value = torch.tensor(value)

            if torch.is_floating_point(value):  # cast data dtype for paligemma
                value = value.to(model.dtype)

            if key == "second_per_grid_ts":  # qwen2.5vl special case
                gen_kwargs[key] = value.tolist()
            else:
                gen_kwargs[key] = value.to(model.device)

        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
            gen_kwargs["input_ids"] = inputs
            gen_kwargs["tokenizer"] = tokenizer
            if "audio_feature_lens" in mm_inputs:
                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]

            gen_kwargs.pop("image_sizes", None)

        return gen_kwargs, prompt_length

    @staticmethod
    def _collect_stop_token_ids(
        generation_config: "GenerationConfig", tokenizer: "PreTrainedTokenizer"
    ) -> List[int]:
        r"""Return the de-duplicated token ids that end a generated sequence.

        Combines ``generation_config.eos_token_id`` (an int or a list of ints; `_process_args` sets it
        from the template's stop tokens) with ``tokenizer.eos_token_id``. ``None`` entries are skipped.
        """
        configured = getattr(generation_config, "eos_token_id", None)
        if isinstance(configured, torch.Tensor):
            configured = configured.tolist()

        if configured is None:
            configured = []
        elif isinstance(configured, int):
            configured = [configured]

        candidates = list(configured) + [tokenizer.eos_token_id]
        return list(dict.fromkeys(token_id for token_id in candidates if token_id is not None))

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List["Response"]:
        r"""Run blocking generation and return one ``Response`` per returned sequence.

        A sequence is reported with ``finish_reason="stop"`` when it contains any stop token used for
        generation (the generation config's ``eos_token_id`` plus ``tokenizer.eos_token_id``). Its
        ``response_length`` then counts tokens up to and including the first such token. Otherwise the
        whole generated span is counted and ``finish_reason="length"``.
        """
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        generate_output = model.generate(**gen_kwargs)
        if isinstance(generate_output, tuple):
            generate_output = generate_output[1][0]  # post-process the minicpm_o output

        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
            clean_up_tokenization_spaces=True,
        )

        # generation may stop on any template stop token, not only on `tokenizer.eos_token_id`
        stop_token_ids = HuggingfaceEngine._collect_stop_token_ids(gen_kwargs["generation_config"], tokenizer)
        stop_tensor = (
            torch.tensor(stop_token_ids, dtype=response_ids.dtype, device=response_ids.device)
            if stop_token_ids
            else None
        )

        results = []
        for text, ids in zip(response, response_ids):
            stop_positions = torch.isin(ids, stop_tensor).nonzero() if stop_tensor is not None else []
            if len(stop_positions):
                response_length = stop_positions[0].item() + 1
                finish_reason = "stop"
            else:
                response_length = len(ids)
                finish_reason = "length"

            results.append(
                Response(
                    response_text=text,
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason=finish_reason,
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Callable[[], str]:
        r"""Start generation on a daemon thread and return a pull function for streamed text.

        Each call of the returned function blocks until the next decoded chunk is available and
        returns it. It raises ``StopAsyncIteration`` once the streamer is exhausted.
        """
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
        )
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                # StopIteration cannot propagate out of a Future; use the async-iteration sentinel instead
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[float]:
        r"""Score each text with the value head, reading the value at its last non-padding token.

        ``max_length`` may be given in ``input_kwargs`` (it is popped). Otherwise inputs are truncated
        to ``model.config.max_position_embeddings``, or to 1024 when that attribute is missing. The
        method relies on right padding, which ``__init__`` configures for non-generative models.

        Returns:
            One float per element of ``batch_input``.
        """
        if input_kwargs is None:
            input_kwargs = {}

        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs: Dict[str, "torch.Tensor"] = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
        # index of the last real token per row (valid because padding is on the right)
        last_index = inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1
        scores = values.gather(dim=-1, index=last_index)
        # (batch, 1) tensor -> plain floats, matching the declared return type
        return scores.squeeze(-1).float().tolist()

    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""Generate complete responses on a worker thread, bounded by the engine semaphore.

        Raises:
            ValueError: if the engine was loaded for scoring rather than generation.
        """
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""Yield generated text chunks as they are produced, bounded by the engine semaphore.

        Raises:
            ValueError: if the engine was loaded for scoring rather than generation.
        """
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            audios,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                stream = self._stream_chat(*input_args)
                while True:
                    try:
                        # each blocking pull runs off the event loop
                        yield await loop.run_in_executor(pool, stream)
                    except StopAsyncIteration:
                        break

    @override
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""Score a batch of texts with the value-head model on a worker thread.

        Raises:
            ValueError: if the engine was loaded for generation (no value head).
        """
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        loop = asyncio.get_running_loop()
        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._get_scores, *input_args)