#!/usr/bin/env python
# coding=utf-8
'''
Description  :  
Author       : Boxin Zhang, Azure-Tang
Version      : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
'''
import torch
from torch import nn
import itertools
import time
import enum
from ktransformers.util.custom_gguf import translate_name_to_gguf
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer

def set_module(model, submodule_key, module):
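    """Replace the submodule addressed by a dotted key (e.g. "layers.0.self_attn") with `module`,
    indexing with integers where the path passes through containers such as nn.ModuleList."""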
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        if hasattr(cur_mod, s):
            cur_mod = getattr(cur_mod, s)
        else: # indexed container such as nn.ModuleList
            cur_mod = cur_mod[int(s)]
    if hasattr(cur_mod, tokens[-1]):
        setattr(cur_mod, tokens[-1], module)
    else: # indexed container such as nn.ModuleList
        cur_mod[int(tokens[-1])] = module

def set_param(module: nn.Module, name: str, weights: torch.Tensor):
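    """Attach `weights` to `module` as a frozen (requires_grad=False) Parameter called `name`;
    a 1-D weight assigned to an nn.Linear is first unsqueezed to shape (1, n)."""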
    
    param = nn.parameter.Parameter(weights, requires_grad=False)
    if isinstance(module, nn.Linear) and len(weights.shape) == 1:
        param.unsqueeze_(0)
    setattr(module, name, param)

def get_device(gguf_module_key:str, device_map:dict):
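    """Return the generate-phase device configured for `gguf_module_key`, defaulting to "cuda"."""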
    if gguf_module_key in device_map:
        return device_map[gguf_module_key]["generate_device"]
    else:
        return "cuda"

def get_all_used_cuda_device(device_map:dict):
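    """Collect every generate/prefill device referenced in `device_map`, excluding "cpu"."""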
    all_device_list = set()
    for key in device_map:
        if "generate_device" in device_map[key]:
            all_device_list.add(device_map[key]["generate_device"])
        if "prefill_device" in device_map[key]:
            all_device_list.add(device_map[key]["prefill_device"])
    if "cpu" in all_device_list:
        all_device_list.remove("cpu")
    all_device_list = list(all_device_list)
    return all_device_list

def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
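    """Load this module's own parameters and persistent buffers from the GGUF file, placing each
    tensor on the device recorded in the loader's tensor_device_map."""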
    prefix = prefix.replace("orig_module.", "")
    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    for name, param in local_state.items():
        key = prefix + name
        translated_key = translate_name_to_gguf(key)
        if translated_key in gguf_loader.tensor_file_map:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            print(f"loading {translated_key} to {device}")
            # device = "cpu" if "embd" in translated_key else "cuda"
            weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
            set_param(module, name, weights)
            del weights
        else:
            #print(load_config.tensor_file_map.keys())
            raise Exception(f"can't find {translated_key} in GGUF file!")
        
def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
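    """Recursively load weights: a BaseInjectedModule loads itself via `module.load()`, while a
    plain module has its own tensors loaded from the GGUF file before its children are visited."""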
    # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
    if not isinstance(module, base_operator.BaseInjectedModule):
        load_cur_state_dict(module, gguf_loader, prefix)
        for name, child in module._modules.items():
            load_weights(child, gguf_loader, prefix+name+".")
    else:
        module.load()

def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
                         mode = 'normal'):
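    """Prefill the model on `inputs`, then decode up to `max_new_tokens` tokens one at a time
    (optionally through a captured CUDA graph), streaming the decoded text to stdout and printing
    prefill/decode throughput statistics. Returns the list of generated token tensors."""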
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
    batch_size, seq_length = inputs.shape
    device_map = model.gguf_loader.tensor_device_map
    torch_device = get_device('blk.0.self_attn', device_map)
    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
    inputs = inputs.to(torch_device)
    all_cuda_device = get_all_used_cuda_device(device_map)

    tokens = []
    
    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
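        """Produce one token: run the model (via the captured CUDA graph or a plain forward pass),
        advance the static cache by one position, and sample or argmax from the warped logits."""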
        if use_cuda_graph:
            logits = cuda_graph_runner(cur_token, position_ids, cache_position)
        else:
            # custom_stream = torch.cuda.Stream()
            torch.cuda.set_device(torch_device)
            inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
            # with torch.cuda.stream(custom_stream):
            logits=model(inputs_embeds=inputs_embeds,
                        position_ids=position_ids,
                        cache_position=cache_position,
                        past_key_values=past_key_values,
                        return_dict=False, use_cache=True)[0]
        if past_key_values is not None:
            past_key_values.change_seq_length(1)
        for device in all_cuda_device:
            torch.cuda.synchronize(device)
        #print(logits)
        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)
        return next_token
    
    torch.cuda.set_device(torch_device)
    with torch.no_grad():
        stream = TextStreamer(tokenizer)
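        # Pre-allocate a static KV cache sized for the prompt plus all new tokens; in
        # 'long_context' mode no StaticCache is created here and past_key_values stays None.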
        if mode != 'long_context':
            past_key_values = StaticCache(
                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
            )
        else:
            past_key_values = None
        cache_position = torch.arange(seq_length, device=torch_device)
        generated_ids = torch.zeros(
            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
        )
        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
        if past_key_values is not None:
            past_key_values.cur_idx = cache_position
        start_time = time.time()

        if mode == "long_context":
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
        else:
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
        logits = model(
            inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
        )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
        generation_config, model_kwargs = model._prepare_generation_config(
            None, max_length=max_new_tokens,
            do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
        )
        try:  # transformers==4.43: _get_logits_warper accepts a device argument
            logits_warper = model._get_logits_warper(generation_config, device=inputs.device)
        except TypeError:  # older transformers versions take only the generation config
            logits_warper = model._get_logits_warper(generation_config)
        # Sample the first generated token from the prefill logits.
        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)
        first_token_time = time.time() - start_time

        prefill_count = seq_length
        prefill_time = first_token_time
        print(stream.put(next_token.item()), end="", flush=True)
        generated_ids[:, seq_length] = next_token
        tokens.append(next_token)
        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
        cache_position = torch.tensor([seq_length], device=torch_device)
        position_ids = cache_position.unsqueeze(0)
        seq_length += 1
        
        if use_cuda_graph:
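            # Capture the single-token decode step as a CUDA graph once; the decode loop below
            # replays it instead of re-launching kernels from Python on every step.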
            cuda_graph_runner = CUDAGraphRunner()
            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
        else:
            cuda_graph_runner = None
            
        start_time = time.time()
        for _ in range(1, max_new_tokens):
            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
            generated_ids[:, cache_position] = next_token.int()
            tokens.append(next_token.int())
            seq_length += 1
            
            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
                print(stream.end(), end="", flush=True)
                break
            else:
                print(stream.put(next_token.item()), end="", flush=True)
            cache_position += 1
            position_ids = cache_position.unsqueeze(0)
        

    total_time = time.time() - start_time
    tokens_generated = len(tokens)
    tokens_per_second = tokens_generated / total_time

    print("")

    print(f"prompt eval count:    {prefill_count} token(s)")
    print(f"prompt eval duration: {prefill_time}s")
    print(f"prompt eval rate:     {prefill_count/prefill_time} tokens/s")
    print(f"eval count:           {tokens_generated} token(s)")
    print(f"eval duration:        {total_time}s")
    print(f"eval rate:            {tokens_per_second} tokens/s")

    return tokens

class InferenceState(enum.Enum):
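    """Inference phases a module can be switched into: UNLOAD, PREFILL, GENERATE, or RESTORE."""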
    UNLOAD = 0
    PREFILL = 1
    GENERATE = 2
    RESTORE = 3