import torch
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import (
    TransformersInterface,
    ConfigArgs,
    TransformersThreadContext,
    default_args,
    TextStreamer,
)
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device


class KTransformersThreadContext(TransformersThreadContext):
    pass


class KTransformersInterface(TransformersInterface):
    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        torch.set_default_dtype(torch.bfloat16)
        torch.set_grad_enabled(False)
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code
        )
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            config._attn_implementation = "flash_attention_2"

        with torch.device("meta"):
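            # Build the model skeleton on the meta device so no real weights are
            # allocated yet; optimize_and_load_gguf() below materializes them from GGUF.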
            self.model = custom_models[config.architectures[0]](config)
        if args.optimize_config_path is None:
            optimize_rule_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_rule_path = args.optimize_config_path

        # print(optimize_config)

        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
        optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)

        self.device_map = self.model.gguf_loader.tensor_device_map
        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
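        # Pre-allocate a static KV cache for the full context window; each layer's cache
        # is placed on the device given by the per-tensor device map.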
        self.cache = StaticCache(
            config=self.model.config,
            max_batch_size=args.batch_size,
            max_cache_len=args.cache_lens,
            device=self.device_map,
            dtype=self.model.dtype,
        )
        # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
        try:
            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except Exception:
            # Fall back to a default sampling config when the model directory ships no
            # generation_config.json.
            gen_config = GenerationConfig(
                max_length=128,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
            self.model.generation_config = gen_config
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
        self.streamer = TextStreamer(self.tokenizer)

        self._infer_lock = asyncio.Lock()  # serializes concurrent inference requests

    def decode_one_tokens(self):
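        # Decode a single token against the KV cache, replaying a captured CUDA graph
        # when use_cuda_graph is enabled.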
        device_map = self.model.gguf_loader.tensor_device_map
        torch_device = get_device("blk.0.self_attn", device_map)
        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
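        # When CUDA graphs are enabled, capture the single-token forward pass once and
        # replay the graph on every subsequent decode step instead of relaunching kernels.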
        if self.args.use_cuda_graph:
            if not hasattr(self, "cuda_graph_runner"):
                self.cuda_graph_runner = CUDAGraphRunner()
                self.cuda_graph_runner.capture(
                    self.model,
                    self.current_ids,
                    self.active_cache_position.unsqueeze(0),
                    self.active_cache_position,
                    self.cache,
                    main_device=torch_device,
                    return_dict=False,
                    use_cache=True,
                )

            if hasattr(self, "cuda_graph_runner"):
                logits = self.cuda_graph_runner(
                    self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
                )
                self.cache.change_seq_length(1)
                torch.cuda.synchronize()
                logits = logits[0, -1, :]
                return self.logits_to_token(logits)

        if self.use_static_cache:
            mask = torch.ones((1, self.seq_length)).to(torch_device)
            logits = self.model(
                self.current_ids.to(torch_device),
                cache_position=self.active_cache_position,
                past_key_values=self.cache,
                attention_mask=mask,
                return_dict=False,
                use_cache=True,
            )[0]
        else:
            logits = self.model(self.current_ids, return_dict=False)[0]
        logits = logits[0, -1, :]

        return self.logits_to_token(logits)

    @torch.no_grad
    def prefill(self, input_ids: torch.Tensor, is_new: bool):
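        # Run the (new or appended) prompt through the model once to fill the KV cache,
        # then yield the first token produced from the prompt's last position.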
        input_ids_length = input_ids.shape[-1]
        self.profiler.set_counter("prefill", input_ids_length)
        logger.debug(f"input_ids: {input_ids.shape}")

        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")

        if is_new:
            # New conversation: reset the KV cache and allocate a fresh generated_ids buffer.
            self.cache.reset()
            self.ever_generated_ids.clear()
            former_seq_length = 0
            self.seq_length = input_ids_length
            self.generated_ids = torch.zeros(
                self.args.batch_size,
                self.seq_length + self.args.max_new_tokens + 1,
                dtype=torch.int,
                device=self.args.device,
            )
        else:
            logger.debug(f"generate_ids: {self.generated_ids.shape}")
            former_seq_length = self.seq_length
            self.seq_length += input_ids_length
            expected_length = self.seq_length + self.args.max_new_tokens + 1
            delta_length = expected_length - self.generated_ids.shape[-1]
            if delta_length > 0:
                new_generate_ids = torch.zeros(
                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
                )
                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

        mask = torch.ones((1, self.seq_length)).to(device)
        if not (type(self) is TransformersInterface):
            # For the injected KTransformers model the embedding lookup is expected to run
            # on CPU, so move the token ids there before computing inputs_embeds.
            input_ids = input_ids.to("cpu")
        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
        if self.use_static_cache:
            logits = self.model(
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                past_key_values=self.cache,
                return_dict=False,
                use_cache=True,
                attention_mask=mask,
            )[0]
        else:
            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

        next_token = self.logits_to_token(logits[0, -1, :])
        yield self.append_new_tokens(next_token)

    @property
    def active_cache_position(self):
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
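        # Cache position of the token currently being decoded (the last filled slot).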
        return torch.tensor([self.seq_length - 1], device=device)
    
    async def inference(self, local_messages, thread_id: str):
        async with self._infer_lock:
            async for v in super().inference(local_messages, thread_id):
                yield v
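

# Minimal usage sketch (illustrative only; in practice the server constructs the
# interface from its own ConfigArgs instance and drives inference() per request):
#
#     interface = KTransformersInterface(default_args)
#     async for delta in interface.inference(local_messages, thread_id):
#         ...  # stream each decoded chunk to the client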