# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from numpy.typing import NDArray
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
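    """Engine that serves a HuggingFace Transformers model for chat or scoring.

    When `finetuning_args.stage == "sft"` the model can generate text via
    `chat` / `stream_chat`; otherwise a value head is attached and only
    `get_scores` is available.
    """
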
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must be loaded after the tokenizer is fixed, so the vocabulary can be resized
        self.generating_args = generating_args.to_dict()
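
        # The engine may be constructed on a thread that has no asyncio event
        # loop yet (e.g. inside a web worker); make sure one exists before the
        # semaphore below is created. MAX_CONCURRENT caps concurrent requests.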
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning("There is no current event loop, creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

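    # Builds the kwargs for `model.generate` from chat messages and an optional
    # image, merging per-request sampling options over the engine defaults.
    # Returns the kwargs together with the prompt length in tokens.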
    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], int]:
        input_kwargs = input_kwargs or {}  # avoid sharing a mutable `{}` default across calls
        if (
            processor is not None
            and image is not None
            and not hasattr(processor, "image_seq_length")
            and template.image_token not in messages[0]["content"]
        ):  # llava-like models
            messages[0]["content"] = template.image_token + messages[0]["content"]

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        pixel_values = None
        prompt_ids, _ = template.encode_oneturn(
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
        )
        if processor is not None and image is not None:  # add image features
            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
            batch_feature = image_processor(image, return_tensors="pt")
            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
            if hasattr(processor, "image_seq_length"):  # paligemma models
                image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

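        # Per-request sampling overrides: any option popped here takes
        # precedence over the engine-level defaults merged in below.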
        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[int] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning("Stop parameter is not supported by the huggingface engine yet.")

        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

        if not generating_args["do_sample"]:
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

        if pixel_values is not None:
            gen_kwargs["pixel_values"] = pixel_values

        return gen_kwargs, prompt_length

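    # Runs one blocking generation pass and wraps every returned sequence in a
    # Response, counting tokens up to and including the first EOS.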
    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

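    # Starts generation on a daemon thread and returns a callable that yields
    # one decoded text chunk per call through TextIteratorStreamer.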
    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

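        # Each call pulls one decoded chunk from the streamer; StopIteration is
        # translated to StopAsyncIteration so the async caller can detect the end.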
        def stream():
            try:
                return next(streamer)
            except StopIteration:
                raise StopAsyncIteration()

        return stream

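    # Scores a batch with the value-head model, reading the value predicted at
    # each sequence's last non-padding position.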
    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[float]:
        input_kwargs = input_kwargs or {}  # avoid sharing a mutable `{}` default across calls
        max_length = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=True,
        ).to(device)

        input_ids: torch.Tensor = inputs["input_ids"]
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

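        # ChatGLM value heads return values transposed; flip to (batch, seq_len)
        # so the per-sample indexing below works uniformly.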
        if getattr(model.config, "model_type", None) == "chatglm":
            values = torch.transpose(values, 0, 1)

        scores = []
        for i in range(input_ids.size(0)):
            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            scores.append(values[i, end_index].nan_to_num().item())

        return scores

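    # The async API below bounds concurrency with the semaphore and offloads the
    # blocking static methods to a thread pool so the event loop stays free.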
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            image,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

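    # Re-yields chunks from the thread-backed `stream` callable until it
    # signals StopAsyncIteration.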
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            image,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                stream = self._stream_chat(*input_args)
                while True:
                    try:
                        yield await loop.run_in_executor(pool, stream)
                    except StopAsyncIteration:
                        break

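    # Scoring is only available when the engine was built with a value head
    # (i.e. `can_generate` is False).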
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        loop = asyncio.get_running_loop()
        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._get_scores, *input_args)
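

# Usage sketch (illustrative, not part of the original module): it assumes the
# argument dataclasses imported above have already been populated elsewhere, so
# `model_args`, `data_args`, `finetuning_args` and `generating_args` below are
# placeholders for that configuration.
#
#     import asyncio
#
#     async def main() -> None:
#         engine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
#         responses = await engine.chat([{"role": "user", "content": "Hello!"}])
#         print(responses[0].response_text)
#
#     asyncio.run(main())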