# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
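    """Synchronous/asynchronous inference engine built on a HuggingFace model.

    Supports multi-turn chat, token streaming, and reward-model scoring, with
    optional multimodal (image/video) inputs handled by the template's mm plugin.
    """
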
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
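        # Only the SFT stage yields a generative chat model; other stages (e.g. a
        # reward model) are loaded with a value head and can only score inputs.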
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must be called after fixing the tokenizer, so the vocabulary is resized first
        self.generating_args = generating_args.to_dict()
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning_once("There is no current event loop; creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

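        # Cap concurrent generation requests; MAX_CONCURRENT defaults to 1.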
        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], int]:
        input_kwargs = input_kwargs or {}  # avoid a shared mutable default argument
        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
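        # If the prompt contains no explicit placeholders, prepend one per image/video
        # so that the template's mm plugin knows where to splice in the features.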
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        messages = template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
        )
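        # Append an empty assistant turn so encode_oneturn returns only the prompt-side token ids.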
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        prompt_ids, _ = template.mm_plugin.process_token_ids(
            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

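        # Per-request sampling overrides; anything left unset falls back to the engine defaults below.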
        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning_rank0("Stop parameter is not supported by the HuggingFace engine yet.")

        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

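        # temperature == 0 means greedy decoding, and greedy search rejects sampling-only arguments.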
        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

        if not generating_args["do_sample"]:
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

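        # max_length and max_new_tokens are mutually exclusive in HF generate; the later
        # max_new_tokens branch wins if a caller supplies both.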
        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
        for key, value in mm_inputs.items():
            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                value = torch.stack(value)  # assume they have same sizes
            elif not isinstance(value, torch.Tensor):
                value = torch.tensor(value)

            gen_kwargs[key] = value.to(model.device)

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        results = []
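        # Count each response up to and including its first EOS token; a missing EOS
        # means generation stopped at the length limit.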
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
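        # Generation runs in a daemon thread; the streamer yields decoded text chunks as they arrive.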
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return next(streamer)
            except StopIteration:
                # Translate the sync iterator's end of stream into the async sentinel awaited by stream_chat.
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[float]:
        input_kwargs = input_kwargs or {}  # avoid a shared mutable default argument
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs: Dict[str, "torch.Tensor"] = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
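        # The value head emits one scalar per token; take each sequence's score at its
        # last real (non-padding) token, located via the attention-mask sum.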
        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        async with self.semaphore:
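            # Keep the blocking generate call off the event loop by running it in a worker thread.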
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        async with self.semaphore:
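            # Pull one chunk per executor call so the event loop stays responsive between tokens.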
            with concurrent.futures.ThreadPoolExecutor() as pool:
                stream = self._stream_chat(*input_args)
                while True:
                    try:
                        yield await loop.run_in_executor(pool, stream)
                    except StopAsyncIteration:
                        break

    @override
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        loop = asyncio.get_running_loop()
        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._get_scores, *input_args)
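

# Minimal usage sketch (hypothetical argument objects; in LlamaFactory they would
# typically be built elsewhere, e.g. via something like `get_infer_args`):
#
#     engine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
#     responses = await engine.chat([{"role": "user", "content": "Hello!"}])
#     print(responses[0].response_text)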