from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from transformers import (
    AutoTokenizer,
    AutoConfig,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
    ConfigArgs,
    default_args,
    TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner 
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from multiprocessing.synchronize import Event
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os
import pickle
import subprocess
import atexit
import signal


ktransformer_rules_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "optimize", "optimize_rules", ""
)
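# Default optimize-rule YAML per supported architecture; used in Engine.__init__
# when args.optimize_config_path is not set.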
default_optimize_rules = {
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}


async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
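    """Yield decoded text chunks for tokens arriving on ``queue``.

    Tokens are pushed through a ``TextStreamer`` so partial output can be
    buffered until it forms printable text; a ``None`` token marks the end of
    generation and flushes anything the streamer is still holding.
    """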
    streamer = TextStreamer(tokenizer)
    while True:
        token = await queue.get()
        if token is None:
            s = streamer.end()
            if s is not None:
                yield s
            break

        yield streamer.put(token)
        


def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
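    """Write the sampled tokens into the scheduler query updates.

    For decode-phase queries the token is also stored in the query's token
    buffer at its active position (bounds-checked against ``max_length``).
    """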
    for i in range(generated_tokens.size(0)):
        query_updates[i].generated_token = generated_tokens[i].item()
        if not query_manager.query_map[query_updates[i].id].is_prefill:
            pos = query_updates[i].active_position
            if pos < query_manager.query_map[query_updates[i].id].max_length:
                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

def report_last_time_performance(profiler: Profiler):
    """Log prefill/decode throughput (tokens per second) and per-phase timings for the last request."""
    try:
        tokenize_time = profiler.get_timer_sec('tokenize')
        prefill_time = profiler.get_timer_sec('prefill')
        decode_time = profiler.get_timer_sec('decode')
        prefill_count = profiler.get_counter('prefill')
        decode_count = profiler.get_counter('decode')

        logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
    except Exception:
        logger.info('Performance statistics not recorded')

class Engine:
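    """Inference engine that runs inside the spawned worker process.

    It pulls batches from the scheduler via ``SchedulerClient``, runs the model
    with ``ModelRunner``, samples next tokens, and pushes ``(query_id, token)``
    pairs back to the server through a multiprocessing queue. Each scheduled
    batch is also re-broadcast to subscriber processes over a ZMQ PUB socket.
    """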
    sched_client : SchedulerClient
    updates : list[sched_ext.QueryUpdate]
    batch : sched_ext.BatchQueryTodo
    model_runner: ModelRunner
    sampler: Sampler
    query_manager: QueryManager
    cache: KDeepSeekV3Cache
    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
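        """Build the model, connect to the scheduler, and attach the shared KV cache.

        ``generated_token_queue`` carries sampled tokens back to the server
        process, ``broadcast_endpoint`` is the IPC path bound by the PUB socket,
        and ``kvcache_event`` is set once the model weights are loaded so the
        parent process can launch the scheduler RPC service.
        """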
        self.args = args

        # The child process cannot share the Config singleton with the parent, so copy args into it here.
        for key, value in vars(args).items():
            if value is not None and hasattr(Config(), key):
                setattr(Config(), key, value)

        self.device = self.args.device
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        self.cache = KDeepSeekV3Cache(config, self.args.page_size)

        self.gen_queue = generated_token_queue

        # Instantiate the model on the meta device; real weights are attached by optimize_and_load_gguf below.
        with torch.device("meta"):
            if config.architectures[0] == "DeepseekV3ForCausalLM":
                self.model = KDeepseekV3ForCausalLM(config, self.cache)
            elif config.architectures[0] == "DeepseekV2ForCausalLM":
                self.model = KDeepseekV2ForCausalLM(config, self.cache)
            else:
                raise NotImplementedError(f"Unsupported architecture for balance_serve: {config.architectures[0]}")

        context = zmq.Context()

        # PUB socket used to re-broadcast each scheduled batch to subscriber processes (see Engine.loop).
        self.pub_socket = context.socket(zmq.PUB)
        self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
        # time.sleep(1) # make sure all subscribers are ready


        try:
            generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except Exception:
            generation_config = GenerationConfig(
                max_length=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True
            )
            
        if args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = args.optimize_config_path
        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your GGUF file (all GGUF files in its directory must belong to the current model): "
            )
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
        self.model.generation_config = generation_config
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id

        self.model.eval()
        kvcache_event.set()
        # Load the KV cache pages exported by the scheduler.
        print("Getting inference context from sched_client.")
        inference_context = self.sched_client.get_inference_context_raw()
        print("Got inference context from sched_client.")
        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
        self.cache.load(inference_context)
        print("kv_cache loaded successfully.")

        self.block_num = inference_context.k_cache[0].size(1)
        # @TODO add config
        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size=args.page_size)
        self.sampler = Sampler()
        self.query_manager = QueryManager(device=self.device, page_size=args.page_size)

    def sampling(self, forward_output: ForwardBatchOutput):
        """Sample one token per sequence from each batch's logits, using any per-query temperature/top_p the forward output carries."""
        generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
        probs = None
        for i in range(forward_output.num_batchs):
            logit = forward_output.logits[i]
            temperatures = forward_output.temperatures[i] if hasattr(forward_output, "temperatures") else None
            top_ps = forward_output.top_ps[i] if hasattr(forward_output, "top_ps") else None

            sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
            generated_tokens, probs = self.sampler(logit, sample_options)
        return generated_tokens, probs
    
    def loop(self):
        """Main scheduling loop: run the current batch, report tokens from the previous step,
        fetch the next batch from the scheduler, then sample from the finished forward pass."""
        next_batch = None

        while True:
            self.batch = next_batch
            if self.batch is not None:
                self.model_runner.run(self.batch, self.query_manager)

            # Push tokens generated in the previous iteration back to the server process.
            if len(self.updates) > 0:
                for q in self.updates:
                    if q.is_prefill:
                        continue
                    try:
                        self.gen_queue.put((q.id, q.generated_token if not q.decode_done else None), timeout=5)
                    except queue.Full:
                        # Queue is still full after the timeout; drop the token rather than stall the loop.
                        pass

            next_batch = self.sched_client.update_last_batch(self.updates)
            if next_batch.query_ids == []:
                next_batch = None
            self.pub_socket.send_pyobj(next_batch)

            if next_batch is not None:
                self.query_manager.add_query(next_batch)

            if self.batch is not None:
                self.model_runner.sync()
                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")

                generated_tokens, probs = self.sampling(self.model_runner.output)

                self.updates = self.query_manager.update(self.batch)
                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
            else:
                self.updates = []

class BalanceServeThreadContext(ThreadContext):
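    """Thread context that flattens stored chat messages into plain role/content dicts."""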
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})

        return local_messages
    

def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):
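    """Entry point of the inference worker process: build the Engine, warm up CUDA graphs
    if enabled, signal readiness via ``event``, then run the scheduling loop forever."""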
    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
    if args.use_cuda_graph:
        engine.model_runner.warmup()
        
    event.set()
    engine.loop()


class BalanceServeInterface(BackendInterfaceBase):
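    """Server-side backend that spawns the inference Engine in a separate process,
    launches the sched_rpc scheduler service, and streams generated tokens back to
    per-request asyncio queues."""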
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args):
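        """Spawn the inference engine process, start the sched_rpc service, and install
        signal handlers that clean both up on shutdown."""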
        self.args = args
        self.queue_map: dict[int, asyncio.Queue] = {}
        self.thread_map: dict[int, int] = {}
        processes = []
        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name  # @TODO add to config
        ctx = mp.get_context("spawn")
        self.token_queue = ctx.Queue(maxsize=1000)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
        self.sched_client = SchedulerClient(args.sched_port)
        self.streamer = TextStreamer(self.tokenizer)

        start_event = ctx.Event()
        kvcache_event = ctx.Event()

        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))
        p.start()
        processes.append(p)
        kvcache_event.wait()


        # Serialize args to a temp file and launch the scheduler RPC service as a
        # separate process; its stdout/stderr are appended to <log_dir>/rpc.log.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            pickle.dump(args, temp_file)
            temp_file_path = temp_file.name
        current_file = __file__
        target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py")
        target_file = os.path.normpath(target_file)
        log_path = os.path.join(args.log_dir, "rpc.log")
        log = open(log_path, "a")
        sched_process = subprocess.Popen(
            ["python3", target_file, "--config", temp_file_path],
            stdout=log,
            stderr=log,
        )
        print("sched_rpc started with PID:", sched_process.pid)

        def signal_handler(signum, frame):
            print(f"Received signal {signum}, shutting down...")
            cleanup()
            os._exit(0) 

        def cleanup():
            print("Cleaning up...")

            for p in processes:
                if p.is_alive():
                    print(f"Terminating subprocess {p.pid}")
                    p.terminate()
                    p.join()

            if sched_process and sched_process.poll() is None:
                print(f"Terminating sched_process {sched_process.pid}")
                sched_process.terminate()
                sched_process.wait()
        signal.signal(signal.SIGINT, signal_handler)   
        signal.signal(signal.SIGTERM, signal_handler)

        start_event.wait()

    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
                   max_tokens: Optional[int] = None, max_completion_tokens: Optional[int] = None) -> tuple[float, float, int]:
        """Get sampling parameters, handling default values and edge cases."""
        if max_tokens is not None:
            max_completion_tokens = max_tokens
        if max_completion_tokens is None:
            max_completion_tokens = self.args.max_new_tokens
        else:
            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
        if temperature is None:
            temperature = self.args.temperature
        if top_p is None:
            top_p = self.args.top_p

        # Replace exact zeros with a small epsilon so downstream sampling never sees temperature or top_p of 0.
        if temperature == 0:
            temperature = 0.0001
        if top_p == 0:
            top_p = 0.0001

        return temperature, top_p, max_completion_tokens

    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.queue_proxy())

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
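        """FastAPI lifespan hook: start the token queue proxy task for the lifetime of the app."""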
        asyncio.create_task(self.queue_proxy())
        yield

    async def queue_proxy(self):
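        """Forward ``(query_id, token)`` pairs from the multiprocessing token queue to the
        asyncio queue of the matching query, yielding to the event loop when idle."""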
        print("Queue Proxy Started")
        while True:
            try:
                query_id, token = self.token_queue.get_nowait()
                try:
                    # query id might not be allocated yet
                    self.queue_map[query_id].put_nowait(token)
                except asyncio.QueueFull:
                    await self.queue_map[query_id].put(token)

            except queue.Empty:
                # Nothing queued yet; yield control to the event loop and poll again.
                await asyncio.sleep(0)

    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
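        """Normalize chat messages (system turns become user turns, adjacent user turns are
        merged), apply the chat template with a generation prompt, strip a trailing
        ``<think>`` marker emitted by the template, and return input ids on the target device."""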
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"

        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += '\n' + m["content"]
            else:
                new_messages.append(m)
        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
        # drop <think> token in chat template
        if input_str.endswith('<think>\n'):
            input_str = input_str[:-len('<think>\n')]
        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
        logger.debug(f"get input ids of shape {input_ids.shape}")
        return input_ids
    
    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, 
                        max_tokens: Optional[int] = None, max_completion_tokens: Optional[int] = None):
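        """Stream generated text for one request.

        Tokenizes the input, registers the query with the scheduler, then yields
        ``(text, finish_reason)`` pairs as tokens arrive, finishing with a
        ``RawUsage`` record carrying timing and token counts.
        """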
        profiler = Profiler()
        profiler.create_and_start_timer("tokenize")
        
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")
        if Config().user_force_think:
            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
            input_ids = torch.cat([input_ids, token_thinks], dim=1)

        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")
        
        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
        query_length = input_ids[0].shape[0]
        query_add.query_length = query_length
        profiler.set_counter("prefill", query_length)
        # @TODO add server
        stop_criteria = [
            self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),
            self.tokenizer.encode("<|im_end|>"),
        ]
        query_add.stop_criteria = stop_criteria

        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)

        query_add.sample_options.temperature = temperature
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length + max_new_tokens)

        if query_add.estimated_length < query_add.query_length:
            raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')

        query_id = self.sched_client.add_query(query_add)
        queue = asyncio.Queue(maxsize=max_new_tokens)
        self.queue_map[query_id] = queue
        self.thread_map[thread_id] = query_id
        is_first_token = True
        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
            if is_first_token:
                is_first_token = False
                profiler.pause_timer("prefill")
                profiler.create_and_start_timer("decode")
                profiler.set_counter("decode", 0)
                if Config().user_force_think:
                    think = '<think>\n'
                    print(think, end="", flush=True)
                    yield think, None
            else:
                profiler.inc("decode")
            yield token, None
        profiler.pause_timer("decode")
        report_last_time_performance(profiler)
        yield self.streamer.end(), None
        if profiler.get_counter('decode') >= max_new_tokens - 1:
            yield "", "length"
        else:
            yield "", "stop"
        
        
        yield RawUsage(
                tokenize_time = profiler.get_timer_sec('tokenize'),
                prefill_time = profiler.get_timer_sec('prefill'),
                decode_time = profiler.get_timer_sec('decode'),
                prefill_count = profiler.get_counter('prefill'),
                decode_count = profiler.get_counter('decode'),
            )