from typing import Any, List, Optional, Set
from transformers import (
    LlamaTokenizer,
    AutoTokenizer,
    AutoConfig,
    LlamaForCausalLM,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

from ktransformers.server.config.config import Config
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler
import torch
import sys, os
from ..base import ThreadContext, BackendInterfaceBase
from ktransformers.server.config.log import logger
from ..args import ConfigArgs, default_args


# This TextStreamer is a modified version of https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
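# Unlike the upstream streamer, put() and end() here return the decoded text
# fragments instead of writing them to stdout, so callers can stream them onward.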
class TextStreamer:

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def reset(self):
        self.token_cache = []
        self.print_len = 0

    def put(self, value) -> Optional[str]:
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
43
44
        """
        if not isinstance(value, int):
            raise ValueError("TextStreamer only supports batch size 1 and int token inputs")

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return None

        # Add the new token to the cache and decode the accumulated sequence.
        self.token_cache.append(value)
        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.reset()
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)
        return printable_text

    def end(self) -> Optional[str]:
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.reset()
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        return printable_text

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False


class TransformersThreadContext(ThreadContext):
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})

        return local_messages


class TransformersInterface(BackendInterfaceBase):
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread-related state
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args):
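        # Load the tokenizer and model onto the configured device, allocate a
        # StaticCache sized for args.cache_lens tokens, and set up the streamer.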
        self.args = args

        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")

        self.cache = StaticCache(
            config=self.model.config,
            max_batch_size=args.batch_size,
            max_cache_len=args.cache_lens,
            device=args.device,
            dtype=self.model.dtype,
        )
        # logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")

        self.streamer = TextStreamer(self.tokenizer)

    @property
    def current_ids(self):
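        # The most recently generated token id, shaped (batch, 1) for the next forward pass.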
        return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)

    @property
    def active_cache_position(self):
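        # Cache slot of the token currently being decoded.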
        return torch.tensor([self.seq_length - 1], device=self.args.device)

    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
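        # Normalize the conversation so it fits the chat template: downgrade system
        # messages to user messages and merge adjacent user turns, then tokenize.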
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"

        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += m["content"]
            else:
                new_messages.append(m)
        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
        #     input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
        # else:
        #     input_ids = self.tokenizer.apply_chat_template(
        #         new_messages, return_tensors="pt", add_generation_prompt=True
        #     ).to(self.args.device)
        input_ids = self.tokenizer.apply_chat_template(new_messages, return_tensors="pt", add_generation_prompt=True).to(self.args.device)
        if (self.last_request_id is not None) and self.last_request_id == thread_id:
            x = self.generated_ids[:, :self.seq_length]
            y = input_ids[:, :self.seq_length]
            # We can only hope that the input_ids are the same
            unequal_mask = torch.ne(x, y)
            unequal_positions = torch.nonzero(unequal_mask)
            num_unequal_elements = unequal_mask.sum().item()
            logger.warning(f'num_unequal_elements: {num_unequal_elements}')

            input_ids = input_ids[:, self.seq_length:]
        logger.debug(f"get input ids of shape {input_ids.shape}")
        return input_ids

    def append_new_tokens(self, new_tokens: int) -> Optional[str]:
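        # Store the token in the preallocated buffer and return any newly printable text.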
        self.generated_ids[0, self.seq_length] = new_tokens
        self.seq_length += 1
        return self.streamer.put(new_tokens)

    def logits_to_token(self, logits: torch.Tensor):
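        # Temperature-scale the logits, penalize every token generated so far
        # (a simple repetition penalty), then sample from the softmax distribution.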
        logits = logits / self.args.temperature if self.args.temperature != 0 else logits

        for token_idx in self.ever_generated_ids:
            if logits[token_idx] < 0:
                logits[token_idx] *= self.args.repetition_penalty
            else:
                logits[token_idx] /= self.args.repetition_penalty

        probs = torch.nn.functional.softmax(logits, dim=-1)

        sample = True
        if sample:
            last = torch.multinomial(probs, num_samples=1)
        else:
            _, last = torch.topk(probs, k=1, dim=-1)

        last = last.item()
        self.ever_generated_ids.add(last)
        return last

    def decode_one_tokens(self):
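        # Single decoding step: feed only the latest token and, when the static cache
        # is enabled, attend over the cached keys/values instead of the full sequence.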
        if self.use_static_cache:
            mask = torch.ones((1, self.seq_length)).to(self.args.device)
            logits = self.model(
                self.current_ids,
                cache_position=self.active_cache_position,
                past_key_values=self.cache,
                attention_mask=mask,
                return_dict=False,
                use_cache=True,
            )[0]
        else:
            logits = self.model(self.current_ids, return_dict=False)[0]
        logits = logits[0, -1, :]

        return self.logits_to_token(logits)

    @torch.no_grad
    def prefill(self, input_ids: torch.Tensor, is_new: bool):
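        # Prefill the KV cache with the prompt. A new thread resets the cache and
        # allocates a fresh generated_ids buffer; a continued thread extends the
        # existing buffer and appends the new tokens after the previous sequence.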
        input_ids_length = input_ids.shape[-1]
        self.profiler.set_counter("prefill", input_ids_length)
        logger.debug(f"input_ids: {input_ids.shape}")

        if is_new:
            self.cache.reset()
            self.ever_generated_ids.clear()
            former_seq_length = 0
            self.seq_length = input_ids_length
            self.generated_ids = torch.zeros(
                self.args.batch_size,
                self.seq_length + self.args.max_new_tokens + 1,
                dtype=torch.int,
                device=self.args.device,
            )
        else:
            logger.debug(f"generate_ids: {self.generated_ids.shape}")
            former_seq_length = self.seq_length
            self.seq_length += input_ids_length
            expected_length = self.seq_length + self.args.max_new_tokens + 1
            delta_length = expected_length - self.generated_ids.shape[-1]
            if delta_length > 0:
                new_generate_ids = torch.zeros(
                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
                )
                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

        mask = torch.ones((1, self.seq_length)).to(self.args.device)
        device = input_ids.device
        if type(self) is not TransformersInterface:
            input_ids = input_ids.to("cpu")
        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
        if self.use_static_cache:
            logits = self.model(
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                past_key_values=self.cache,
                return_dict=False,
                use_cache=True,
                attention_mask=mask,
            )[0]
        else:
            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

        next_token = self.logits_to_token(logits[0, -1, :])
        yield self.append_new_tokens(next_token)

    @torch.no_grad
    def generate(self):
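        # Decode token by token (math SDPA kernel only), stopping at max_new_tokens
        # or at the tokenizer's EOS token, then flush the streamer.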
        self.profiler.set_counter("decode", 0)
        for _ in range(1, self.args.max_new_tokens):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                next_token = self.decode_one_tokens()
                self.profiler.inc("decode")
                if next_token == self.tokenizer.eos_token_id:
                    assert self.args.batch_size == 1
                    break
                yield self.append_new_tokens(next_token)
        yield self.streamer.end()

    def check_is_new(self, thread_id: str):
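        # A request is treated as new unless the static cache is enabled and the
        # thread id matches the previous request, in which case the cache is reused.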
        if not self.use_static_cache:
            return True
        if self.last_request_id is None:
            self.last_request_id = thread_id
            return True
        else:
            if self.last_request_id == thread_id:
                return False
            else:
                self.last_request_id = thread_id
                return True

    async def inference(self, local_messages, thread_id: str):
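        # Full request pipeline: tokenize the input, prefill the cache, then decode,
        # with each phase timed by the profiler and tokens streamed back as they appear.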
        self.profiler.create_and_start_timer("tokenize")
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
            #input_ids = torch.tensor([[6366]], device=input_ids.device)
        else:
            raise ValueError("local_messages should be List or str")
        if Config().user_force_think:
            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_ids.device)
            input_ids = torch.cat(
                [input_ids, token_thinks], dim=1
            )

        self.profiler.pause_timer("tokenize")

        self.profiler.create_and_start_timer("prefill")
        if Config().user_force_think:
            t = "<think>\n"
            print(t,end="",flush=True)
            yield t
        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
            if t is not None:
                print(t, end="", flush=True)
                yield t
        self.profiler.pause_timer("prefill")

        self.profiler.create_and_start_timer("decode")
        for t in self.generate():
            if t is not None:
                print(t, end="", flush=True)
                yield t
        print("")
        self.profiler.pause_timer("decode")
        self.report_last_time_performance()
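

# Usage sketch (illustrative only, not executed by the server): the API layer is
# expected to drive this backend roughly like
#
#     interface = TransformersInterface(default_args)
#     async for chunk in interface.inference(local_messages, thread_id):
#         ...  # stream each text chunk back to the client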