from typing import Any, List, Optional, Set
from transformers import (
    LlamaTokenizer,
    AutoTokenizer,
    AutoConfig,
    LlamaForCausalLM,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler
import torch
import sys, os
from ..base import ThreadContext, BackendInterfaceBase
from ktransformers.server.config.log import logger
from ..args import ConfigArgs, default_args


# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
class TextStreamer:

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def reset(self):
        self.token_cache = []
        self.print_len = 0

    def put(self, value) -> Optional[str]:
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
42
43
        """
        if not isinstance(value, int):
            raise ValueError("TextStreamer only supports batch size 1 and int type input")

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return None

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.append(value)
        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.reset()
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)
        return printable_text

    def end(self) -> Optional[str]:
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.reset()
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        return printable_text

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False


class TransformersThreadContext(ThreadContext):
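    # Adapts the server-side thread messages into plain {"role", "content"} dicts
    # that downstream code can pass to the tokenizer's chat template.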
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})

        return local_messages


class TransformersInterface(BackendInterfaceBase):
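    # Single-sequence HuggingFace backend. It keeps one StaticCache plus the id of the
    # last thread it served, so a follow-up request on the same thread can reuse the
    # prefilled KV cache instead of re-encoding the whole conversation.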
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args):
        self.args = args

        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")

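        # Pre-allocate a fixed-size KV cache (StaticCache) so decoding never has to grow
        # the cache; its max_cache_len (args.cache_lens) bounds the cached positions.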
        self.cache = StaticCache(
            config=self.model.config,
            max_batch_size=args.batch_size,
            max_cache_len=args.cache_lens,
            device=args.device,
            dtype=self.model.dtype,
        )
        # logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")

        self.streamer = TextStreamer(self.tokenizer)

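    # The model is advanced one token at a time: `current_ids` is the most recently
    # appended token (shape [batch, 1]) and `active_cache_position` is its index in
    # the static KV cache.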
    @property
    def current_ids(self):
        return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)

    @property
    def active_cache_position(self):
        return torch.tensor([self.seq_length - 1], device=self.args.device)

    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

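    # Build input ids from chat messages: system messages are demoted to user turns and
    # adjacent user messages are merged before the chat template is applied. If the
    # request continues the previously served thread, only the new token suffix is
    # returned so the existing KV cache can be reused.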
    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"

        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += m["content"]
            else:
                new_messages.append(m)
        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
        #     input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
        # else:
        #     input_ids = self.tokenizer.apply_chat_template(
        #         new_messages, return_tensors="pt", add_generation_prompt=True
        #     ).to(self.args.device)
        input_ids = self.tokenizer.apply_chat_template(
            new_messages, return_tensors="pt", add_generation_prompt=True
        ).to(self.args.device)
        if (self.last_request_id is not None) and self.last_request_id == thread_id:
            # Sanity check: the re-tokenized prefix should match the ids already in the
            # generated-id buffer; a mismatch means the cached KV state is stale.
            x = self.generated_ids[:, :self.seq_length]
            y = input_ids[:, :self.seq_length]
            unequal_mask = torch.ne(x, y)
            num_unequal_elements = unequal_mask.sum().item()
            logger.warning(f"num_unequal_elements: {num_unequal_elements}")

            # Only the new suffix of the conversation needs to be prefilled.
            input_ids = input_ids[:, self.seq_length:]
        logger.debug(f"get input ids of shape {input_ids.shape}")
        return input_ids

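    # Store a freshly sampled token in the generated-id buffer and return whatever new
    # text the streamer considers printable so far.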
    def append_new_tokens(self, new_tokens: int) -> Optional[str]:
        self.generated_ids[0, self.seq_length] = new_tokens
        self.seq_length += 1
        return self.streamer.put(new_tokens)

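    # Convert last-position logits into a token id: apply temperature, penalize every
    # token already generated in this request (repetition penalty), then sample from
    # the softmax distribution.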
    def logits_to_token(self, logits: torch.Tensor):
        logits = logits / self.args.temperature

        for token_idx in self.ever_generated_ids:
            if logits[token_idx] < 0:
                logits[token_idx] *= self.args.repetition_penalty
            else:
                logits[token_idx] /= self.args.repetition_penalty

        probs = torch.nn.functional.softmax(logits, dim=-1)

        sample = True
        if sample:
            last = torch.multinomial(probs, num_samples=1)
        else:
            _, last = torch.topk(probs, k=1, dim=-1)

        last = last.item()
        self.ever_generated_ids.add(last)
        return last

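    # One incremental decode step: feed only the latest token together with the static
    # KV cache and a full-length attention mask, then sample the next token.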
    def decode_one_tokens(self):
        if self.use_static_cache:
            mask = torch.ones((1, self.seq_length)).to(self.args.device)
            logits = self.model(
                self.current_ids,
                cache_position=self.active_cache_position,
                past_key_values=self.cache,
                attention_mask=mask,
                return_dict=False,
                use_cache=True,
            )[0]
        else:
            logits = self.model(self.current_ids, return_dict=False)[0]
        logits = logits[0, -1, :]

        return self.logits_to_token(logits)

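    # Prefill: encode the prompt into the KV cache. A new thread resets the cache and
    # the generated-id buffer; a continuing thread grows the buffer and encodes only
    # the new tokens. Yields the text for the first generated token (or None).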
    @torch.no_grad
    def prefill(self, input_ids: torch.Tensor, is_new: bool):
        input_ids_length = input_ids.shape[-1]
        self.profiler.set_counter("prefill", input_ids_length)
        logger.debug(f"input_ids: {input_ids.shape}")

        if is_new:
            self.cache.reset()
            self.ever_generated_ids.clear()
            former_seq_length = 0
            self.seq_length = input_ids_length
            self.generated_ids = torch.zeros(
                self.args.batch_size,
                self.seq_length + self.args.max_new_tokens + 1,
                dtype=torch.int,
                device=self.args.device,
            )
        else:
            logger.debug(f"generated_ids: {self.generated_ids.shape}")
            former_seq_length = self.seq_length
            self.seq_length += input_ids_length
            expected_length = self.seq_length + self.args.max_new_tokens + 1
            delta_length = expected_length - self.generated_ids.shape[-1]
            if delta_length > 0:
                new_generate_ids = torch.zeros(
                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
                )
                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

        mask = torch.ones((1, self.seq_length)).to(self.args.device)
        device = input_ids.device
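        # For subclasses the embedding lookup is run with input ids on CPU, and the
        # resulting embeddings are moved back to the original device afterwards.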
        if not (type(self) is TransformersInterface):
            input_ids = input_ids.to("cpu")
        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
        if self.use_static_cache:
            logits = self.model(
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                past_key_values=self.cache,
                return_dict=False,
                use_cache=True,
                attention_mask=mask,
            )[0]
        else:
            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

        next_token = self.logits_to_token(logits[0, -1, :])
        yield self.append_new_tokens(next_token)

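    # Decode loop: sample up to max_new_tokens - 1 further tokens (the first one came
    # from prefill), forcing the math SDP backend, and stop early on EOS; finally flush
    # the streamer.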
    @torch.no_grad
    def generate(self):
        self.profiler.set_counter("decode", 0)
        for _ in range(1, self.args.max_new_tokens):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                next_token = self.decode_one_tokens()
                self.profiler.inc("decode")
                if next_token == self.tokenizer.eos_token_id:
                    assert self.args.batch_size == 1
                    break
                yield self.append_new_tokens(next_token)
        yield self.streamer.end()

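    # A request is treated as "new" when the static cache is disabled or when its
    # thread id differs from the last one served; new requests trigger a cache reset
    # in prefill.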
    def check_is_new(self, thread_id: str):
        if not self.use_static_cache:
            return True
        if self.last_request_id is None:
            self.last_request_id = thread_id
            return True
        else:
            if self.last_request_id == thread_id:
                return False
            else:
                self.last_request_id = thread_id
                return True

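    # Server entry point: tokenize the request (chat messages or a raw prompt string),
    # prefill the cache, then stream decoded text chunks. Each stage is timed with the
    # profiler.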
    async def inference(self, local_messages, thread_id: str):
        self.profiler.create_and_start_timer("tokenize")
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")

        self.profiler.pause_timer("tokenize")

        self.profiler.create_and_start_timer("prefill")
        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
            if t is not None:
                print(t, end="")
                yield t
        self.profiler.pause_timer("prefill")

        self.profiler.create_and_start_timer("decode")
        for t in self.generate():
            if t is not None:
                print(t, end="")
                yield t
        print("")
        self.profiler.pause_timer("decode")
        self.report_last_time_performance()
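

# Minimal usage sketch (an assumption, not part of the server wiring): run inside an
# asyncio event loop with `default_args` pointing at a local model directory; the
# thread id below is a hypothetical value.
#
#   interface = TransformersInterface(default_args)
#   async for chunk in interface.inference("Hello!", thread_id="demo-thread"):
#       print(chunk, end="")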