import torch
from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import (
    TransformersInterface,
    ConfigArgs,
    TransformersThreadContext,
    default_args,
    TextStreamer,
)
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
from ktransformers.server.schemas.endpoints.chat import RawUsage

warm_uped = False

class KTransformersThreadContext(TransformersThreadContext):
    pass


class KTransformersInterface(TransformersInterface):
    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        torch.set_grad_enabled(False)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
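        # Prefer the model's own GenerationConfig; fall back to the sampling settings from the server args if it is missing.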
        try:
            generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except Exception:
            generation_config = GenerationConfig(
                max_length=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True
            )
        
        torch.set_default_dtype(config.torch_dtype)
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            config._attn_implementation = "flash_attention_2"

        with torch.device("meta"):
            self.model = custom_models[config.architectures[0]](config)
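        # Pick the optimize rule file: the architecture's default rules unless an explicit path was configured.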
        if args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = args.optimize_config_path

        # print(optimize_config)

        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
        self.model.generation_config = generation_config
        self.device_map = self.model.gguf_loader.tensor_device_map
        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
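        # Pre-allocate a static KV cache of cache_lens tokens, spread across devices according to the per-layer device map.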
        self.cache = StaticCache(
            config=self.model.config,
            max_batch_size=args.batch_size,
            max_cache_len=args.cache_lens,
            device=self.device_map,
            dtype=self.model.dtype,
        )
        # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")

        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
        self.streamer = TextStreamer(self.tokenizer)

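        # A single lock serializes inference requests, since the model and KV cache are shared.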
        self._infer_lock = asyncio.Lock()

    def decode_one_tokens(self):
        global warm_uped

        device_map = self.model.gguf_loader.tensor_device_map
        torch_device = get_device("blk.0.self_attn", device_map)
        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
        torch.cuda.set_device(torch_device)
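        # Once warmed up, capture the single-token decode step into a CUDA graph and replay it to cut per-step launch overhead.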
        if warm_uped and self.args.use_cuda_graph:
            if not hasattr(self, "cuda_graph_runner"):
                self.cuda_graph_runner = CUDAGraphRunner()
                self.cuda_graph_runner.capture(
                    self.model,
                    self.current_ids,
                    self.active_cache_position.unsqueeze(0),
                    self.active_cache_position,
                    self.cache,
                    main_device=torch_device,
                    return_dict=False,
                    use_cache=True,
                )

            if hasattr(self, "cuda_graph_runner"):
                logits = self.cuda_graph_runner(
                    self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
                )
                self.cache.change_seq_length(1)
                torch.cuda.synchronize()
                logits = logits[0, -1, :]
                return self.logits_to_token(logits)
        
        if self.args.use_cuda_graph:
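            # The first decode step runs eagerly as a warm-up; later steps take the CUDA graph path above.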
            warm_uped = True
            
        if self.use_static_cache:
            logits = self.model(
                self.current_ids.to(torch_device),
                cache_position=self.active_cache_position,
                past_key_values=self.cache,
                return_dict=False,
                use_cache=True,
            )[0]
        else:
            logits = self.model(self.current_ids, return_dict=False)[0]
        logits = logits[0, -1, :]

        return self.logits_to_token(logits)

    @torch.no_grad
    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
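        # Prefill the KV cache with the prompt (reusing any prefix shared with the previous request) and sample the first token.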
        input_ids_length = input_ids.shape[-1]
        if input_ids_length >= self.args.cache_lens:
            logger.warning(f"input_ids_length {input_ids_length} >= cache_lens {self.args.cache_lens}")
            self.seq_length = input_ids_length
            return
        logger.debug(f"input_ids: {input_ids.shape}")
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        device = "cuda:0" if device == "cuda" else device

        if is_new:
            self.ever_generated_ids.clear()
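            # Find the longest prefix shared with the previous request so its cached KV entries can be kept.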
            same_prefix = 0
            flat_input_ids = input_ids.flatten()

            if getattr(self, 'generated_ids', None) is None:
                self.generated_ids = torch.zeros(
                    self.args.batch_size,
                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
                    dtype=torch.int,
                    device=self.args.device,
                )
                self.seq_length = 1            
            
            flat_prev_ids = self.generated_ids.flatten()
            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
                if flat_input_ids[i] == flat_prev_ids[i]:
                    same_prefix += 1
                else:
                    break
            
            logger.debug(f"same prefix len: {same_prefix}")
            self.cache.remove_suffix(same_prefix)
            self.seq_length = same_prefix
            self.generated_ids = self.generated_ids[..., :same_prefix]
            input_ids = input_ids[..., same_prefix:]
            input_ids_length = input_ids.shape[-1]

        self.ever_generated_ids.clear()
        self.profiler.set_counter("prefill", input_ids_length)
        logger.debug(f"input_ids: {input_ids.shape}")
        logger.debug(f"generate_ids: {self.generated_ids.shape}")
        
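        # Grow generated_ids to hold the new prompt tokens plus up to max_new_tokens, capped at cache_lens.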
        former_seq_length = self.seq_length
        self.seq_length += input_ids_length
        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
        delta_length = expected_length - self.generated_ids.shape[-1]
        if delta_length > 0:
            new_generate_ids = torch.zeros(
                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
            )
            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
        else:
            logger.warning(f"seq_length bigger than cache_lens, killed")
            exit(0)
        
        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

        if type(self) is not TransformersInterface:
            input_ids = input_ids.to("cpu")
        
        def chunk_prefill(input_ids, cache_position):
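            # Embed one chunk of tokens and run a forward pass that fills the static KV cache at the given positions.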
            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
            torch.cuda.set_device(device)
            if flashinfer_enabled:
                MLAWrapperSingleton.need_plan_all()
            if self.use_static_cache:
                logits = self.model(
                    inputs_embeds=inputs_embeds,
                    cache_position=cache_position,
                    past_key_values=self.cache,
                    return_dict=False,
                    use_cache=True,
                )[0]
            else:
                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

            return logits

        chunk_start = 0
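        # Feed the prompt through the model chunk_size tokens at a time to bound peak memory during prefill.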
        while chunk_start < input_ids_length:
            chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
            if self.cache is not None:
                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
            chunk_start += self.args.chunk_size
            
        if flashinfer_enabled:
            MLAWrapperSingleton.reset_buffer()
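        # Set up sampling (temperature / top_p) for this request, then sample the first new token from the last prefill logits.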
        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
        next_token = self.logits_to_token(logits[0, -1, :])
        yield self.append_new_tokens(next_token)

    @property
    def active_cache_position(self):
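        # Cache position of the last token in the sequence, used as cache_position for the next decode step.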
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        return torch.tensor([self.seq_length - 1], device=device)
    
    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
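        # Serialize concurrent requests: only one inference may run at a time on the shared model and cache.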
        async with self._infer_lock:
            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                yield v
            
            # yield this inference's raw usage stats
            yield RawUsage(
                tokenize_time = self.profiler.get_timer_sec('tokenize'),
                prefill_time = self.profiler.get_timer_sec('prefill'),
                decode_time = self.profiler.get_timer_sec('decode'),
                prefill_count = self.profiler.get_counter('prefill'),
                decode_count = self.profiler.get_counter('decode'),
            )