#!/usr/bin/env python
# coding=utf-8
'''
Description  :  
Author       : Boxin Zhang, Azure-Tang
Version      : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
'''
import torch
from torch import nn
import itertools
import time
import enum
from ktransformers.util.custom_gguf import translate_name_to_gguf
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

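# Module-level flag: the first generation runs one eager decode step as a warm-up
# before capturing a CUDA graph; once set, later generations can capture (or reuse)
# the graph on their first decode step.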
warm_uped = False

def get_compute_capability(device:torch.device = None):
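    """Return the CUDA compute capability major version.

    With no `device`, returns the minimum major version across all visible GPUs;
    otherwise returns the major version of the given device. Returns None when
    CUDA is unavailable.
    """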
    if torch.cuda.is_available():
        if device is None:
            num_gpus = torch.cuda.device_count()
            min_compute_capability_major = 100
            for gpu_id in range(num_gpus):
                gpu_props = torch.cuda.get_device_properties(gpu_id)
                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
            return min_compute_capability_major
        else:
            return torch.cuda.get_device_properties(device).major

def set_module(model, submodule_key, module):
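    """Replace the submodule addressed by the dotted path `submodule_key` with `module`.

    Attribute components are followed with getattr; numeric components index into
    containers such as nn.ModuleList.
    """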
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        if hasattr(cur_mod, s):
            cur_mod = getattr(cur_mod, s)
        else:  # indexed container such as nn.ModuleList
            cur_mod = cur_mod[int(s)]
    if hasattr(cur_mod, tokens[-1]):
        setattr(cur_mod, tokens[-1], module)
    else:  # indexed container such as nn.ModuleList
        cur_mod[int(tokens[-1])] = module

def set_param(module: nn.Module, name: str, weights: torch.Tensor):
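    """Attach `weights` to `module` under `name` as a frozen (requires_grad=False) Parameter.

    A 1-D weight assigned to an nn.Linear is unsqueezed to shape (1, n).
    """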
    
    param=nn.parameter.Parameter(weights, requires_grad=False)
    if isinstance(module, nn.Linear) and len(weights.shape)==1:
        param.unsqueeze_(0)
    setattr(module, name, param)

def get_device(gguf_module_key:str, device_map:dict):
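    """Return the generation device configured for `gguf_module_key`, defaulting to "cuda"."""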
    if gguf_module_key in device_map:
        return device_map[gguf_module_key]["generate_device"]
    else:
        return "cuda"

def get_all_used_cuda_device(device_map:dict):
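    """Collect every device named in the map's generate/prefill entries, excluding "cpu"."""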
    all_device_list = set()
    for key in device_map:
        if "generate_device" in device_map[key]:
            all_device_list.add(device_map[key]["generate_device"])
        if "prefill_device" in device_map[key]:
            all_device_list.add(device_map[key]["prefill_device"])
    if "cpu" in all_device_list:
        all_device_list.remove("cpu")
    all_device_list = list(all_device_list)
    return all_device_list

def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
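    """Load this module's own parameters and persistent buffers (non-recursively).

    Keys are translated to GGUF naming, dequantized by the GGUF or safetensor loader,
    and placed on the device given by the loader's tensor device map.
    """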
    prefix = prefix.replace("orig_module.", "")
    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    for name, param in local_state.items():
        key = prefix + name
        translated_key = translate_name_to_gguf(key)
        
        # TODO: Merge all loaders.
        # I know this is ugly, but let's do it for now.
        if gguf_loader.safetensor_loader is not None:
            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
        else:
            load_dequantized_tensor = gguf_loader.load_gguf_tensor
            tensor_file_map = gguf_loader.tensor_file_map
        
        if translated_key in tensor_file_map:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            print(f"loading {translated_key} to {device}")
            torch.cuda.empty_cache()
            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
            set_param(module, name, weights)
            del weights
        else:
            #print(load_config.tensor_file_map.keys())
            raise Exception(f"can't find {translated_key} in GGUF file!")
        
def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
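    """Recursively load weights for `module` and all of its children.

    Modules injected by ktransformers (BaseInjectedModule) load themselves via
    `module.load()`; plain modules fall back to load_cur_state_dict.
    """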
    #print(f"recursively loading weights {prefix}")
    if not isinstance(module, base_operator.BaseInjectedModule):
        load_cur_state_dict(module, gguf_loader, prefix)
        for name, child in module._modules.items():
            load_weights(child, gguf_loader, prefix+name+".")
    else:
        module.load()

def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
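    """Prefill the prompt, then decode up to `max_new_tokens` tokens one at a time.

    Decoding optionally runs through a captured CUDA graph and, when `use_flashinfer_mla`
    is set, through the FlashInfer MLA wrapper. Tokens are sampled or argmaxed according
    to the generation config, streamed to stdout, and simple prefill/decode throughput
    statistics are printed. Returns the list of generated token ids.

    Illustrative call (prompt/tokenizer setup is the caller's responsibility):
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        generated = prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=512)
    """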
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
    batch_size, seq_length = inputs.shape
    device_map = model.gguf_loader.tensor_device_map
    torch_device = get_device('blk.0.self_attn', device_map)
    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
    inputs = inputs.to(torch_device)
    all_cuda_device = get_all_used_cuda_device(device_map)

    tokens = []
    
    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
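        """Run a single decode step and return the next token id.

        Uses the captured CUDA graph when available, otherwise a normal forward pass;
        then advances the cache length, synchronizes all used CUDA devices, and samples
        (or argmaxes) from the warped logits.
        """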
        if cuda_graph_runner is None:
            use_cuda_graph = False
        if use_cuda_graph:
            logits = cuda_graph_runner(cur_token, position_ids, cache_position)
        else:
            # custom_stream = torch.cuda.Stream()
            torch.cuda.set_device(torch_device)
            inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
            # with torch.cuda.stream(custom_stream):
            logits=model(inputs_embeds=inputs_embeds,
                        position_ids=position_ids,
                        cache_position=cache_position,
                        past_key_values=past_key_values,
                        return_dict=False, use_cache=True)[0]
        if past_key_values is not None:
            past_key_values.change_seq_length(1)
        for device in all_cuda_device:
            torch.cuda.synchronize(device)
        #print(logits)
        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)
        return next_token
    
    torch.cuda.set_device(torch_device)
    with torch.no_grad():
        stream = TextStreamer(tokenizer)
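        # In 'long_context' mode the static KV cache is skipped and past_key_values stays None.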
        if mode != 'long_context':
            past_key_values = StaticCache(
                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
            )
        else:
            past_key_values = None
        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
        generated_ids = torch.zeros(
            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
        )
        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
        if past_key_values is not None:
            past_key_values.cur_idx = cache_position
        start_time = time.time()

        if mode == "long_context":
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
        else:
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
        if use_flashinfer_mla:
            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
            MLAWrapperSingleton.need_plan_all()
            
        logits = model(
            inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
        )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
        generation_config, model_kwargs = model._prepare_generation_config(
            None, do_sample=True
            # adjust this call to change the generation config, e.g.:
            # top_k=5, top_p=0.85, temperature=0.1
        )
        try:  # transformers==4.43
            logits_warper = (
                model._get_logits_warper(generation_config, device=inputs.device)
            )
        except Exception:
            logits_warper = (
                model._get_logits_warper(generation_config)
            )
        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)
        first_token_time = time.time() - start_time
        
        if use_flashinfer_mla:
            MLAWrapperSingleton.reset_buffer()

        prefill_count = seq_length
        prefill_time = first_token_time
        if force_think:
            print("<think>\n")
        print(stream.put(next_token.item()), end="", flush=True)
        generated_ids[:, seq_length] = next_token
        tokens.append(int(next_token))
        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
        position_ids = cache_position.unsqueeze(0)
        seq_length += 1
        
        cuda_graph_runner = None
            
        start_time = time.time()
        for i in range(1, max_new_tokens):
            if use_flashinfer_mla:
                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
                                             num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
                                             q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
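            # The first generation runs one eager decode step as a warm-up before capturing
            # the CUDA graph (i == 2); once warm_uped is set, later calls capture at i == 1.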
            global warm_uped
            if use_cuda_graph and ((warm_uped and i == 1) or (not warm_uped and i == 2)):
                warm_uped = True
                cuda_graph_runner = CUDAGraphRunner()
                cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
            generated_ids[:, cache_position] = next_token.int()
            tokens.append(int(next_token))
            seq_length += 1
            
            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
                print(stream.end(), end="", flush=True)
                break
            else:
                print(stream.put(next_token.item()), end="", flush=True)
            cache_position += 1
            position_ids = cache_position.unsqueeze(0)
        

    total_time = time.time() - start_time
    tokens_generated = len(tokens)
    tokens_per_second = tokens_generated / total_time

    print("")

    print(f"prompt eval count:    {prefill_count} token(s)")
    print(f"prompt eval duration: {prefill_time}s")
    print(f"prompt eval rate:     {prefill_count/prefill_time} tokens/s")
    print(f"eval count:           {tokens_generated} token(s)")
    print(f"eval duration:        {total_time}s")
    print(f"eval rate:            {tokens_per_second} tokens/s")

    return tokens

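# Inference lifecycle states; presumably consumed by operators elsewhere in ktransformers
# to switch module weights/devices between prefill and generation phases (the consumers
# are not defined in this file).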
class InferenceState(enum.Enum):
    UNLOAD = 0
    PREFILL = 1
    GENERATE = 2
    RESTORE = 3