Unverified Commit 877aec85 authored by Yuhao Tsui, committed by GitHub

Merge branch 'kvcache-ai:main' into main

parents 84164f58 9037bf30
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import pickle
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch
from torch.distributed import TCPStore
import server.envs as envs
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
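# Example: divide(10, 2) == 5, while divide(10, 3) raises
# AssertionError("10 is not divisible by 3").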
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
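# Illustrative example (hypothetical shapes): a (2, 6) tensor split into 3 partitions
# yields three (2, 2) views:
#   >>> parts = split_tensor_along_last_dim(torch.randn(2, 6), 3)
#   >>> [tuple(p.shape) for p in parts]
#   [(2, 2), (2, 2), (2, 2)]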
def get_pp_indices(
num_hidden_layers: int, pp_rank: int, pp_size: int
) -> Tuple[int, int]:
"""Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
"""
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
if partition_list_str is not None:
try:
partitions = [int(layer) for layer in partition_list_str.split(",")]
except ValueError as err:
raise ValueError(
"Invalid partition string: {}".format(partition_list_str)
) from err
if len(partitions) != pp_size:
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
else:
layers_per_partition = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition
if pp_rank == pp_size - 1:
end_layer = num_hidden_layers
return (start_layer, end_layer)
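# Illustrative example (assuming envs.VLLM_PP_LAYER_PARTITION is unset):
#   >>> get_pp_indices(num_hidden_layers=7, pp_rank=0, pp_size=2)
#   (0, 3)
#   >>> get_pp_indices(num_hidden_layers=7, pp_rank=1, pp_size=2)
#   (3, 7)   # the last partition absorbs the remaining layers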
@dataclasses.dataclass
class StatelessProcessGroup:
"""A dataclass to hold a metadata store, and the rank, world_size of the
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
# A deque to store the data entries, with key and timestamp.
entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
def __post_init__(self):
assert self.rank < self.world_size
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
def expire_data(self):
"""Expire data that is older than `data_expiration_seconds` seconds."""
while self.entries:
# check the oldest entry
key, timestamp = self.entries[0]
if time.time() - timestamp > self.data_expiration_seconds:
self.store.delete_key(key)
self.entries.popleft()
else:
break
def recv_obj(self, src: int) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
)
self.recv_src_counter[src] += 1
return obj
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
"""Broadcast an object from a source rank to all other ranks.
It does not clean up after all ranks have received the object.
Use it only a limited number of times, e.g., during initialization.
"""
if self.rank == src:
self.expire_data()
key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return recv_obj
def all_gather_obj(self, obj: Any) -> list[Any]:
"""All gather an object from all ranks."""
gathered_objs = []
for i in range(self.world_size):
if i == self.rank:
gathered_objs.append(obj)
self.broadcast_obj(obj, src=self.rank)
else:
recv_obj = self.broadcast_obj(None, src=i)
gathered_objs.append(recv_obj)
return gathered_objs
def barrier(self):
"""A barrier to synchronize all ranks."""
for i in range(self.world_size):
if i == self.rank:
self.broadcast_obj(None, src=self.rank)
else:
self.broadcast_obj(None, src=i)
@staticmethod
def create(
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
If processes A and B call `torch.distributed.init_process_group` to form a group,
and we then want to form another group with processes A, B, C, and D, that is not
possible in PyTorch, because A and B have already formed a group that C and D
cannot join. This function is a workaround for that limitation.
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It returns a `StatelessProcessGroup` object that can be
used for exchanging metadata. With this function, processes A and B
can call `StatelessProcessGroup.create` to form a group, and then processes A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
data_expiration_seconds=data_expiration_seconds,
)
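# Illustrative usage (a sketch; host/port are hypothetical, run one process per rank):
#   >>> pg = StatelessProcessGroup.create("127.0.0.1", 29500, rank=0, world_size=2)
#   >>> pg.broadcast_obj({"handle": 123}, src=0)   # the other rank passes obj=None
#   {'handle': 123}
#   >>> pg.barrier()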
'''
Date: 2024-11-12 14:15:16
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-26 08:12:49
'''
import torch
from ktransformers.server.balance_serve.settings import sched_ext
from ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo
import time
from ktransformers.server.config.config import Config
class ForwardBatchInput:
class ForwardMiniBatch:
q_indptr: torch.Tensor
kv_indptr: torch.Tensor
kv_indices: torch.Tensor
kv_last_page_len: torch.Tensor
kv_len: torch.Tensor
position_ids: torch.Tensor
tokens: torch.Tensor
batch_indices: torch.Tensor
positions: torch.Tensor
chunk_size: int
decode_batch: int
is_last_prefill_chunk: bool
logits_start: list
temperatures: torch.Tensor
top_ps: torch.Tensor
def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):
batch_decode = len(decode_querys_info)
batch_prefill = len(prefill_querys_info)
self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)
self.kv_len = torch.tensor([], device=device, dtype=torch.int32)
self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)
self.position_ids = torch.tensor([], device=device, dtype=torch.int32)
self.tokens = torch.tensor([], device=device, dtype=torch.int32)
self.temperatures = torch.tensor([], device=device, dtype=torch.float32)
self.top_ps = torch.tensor([], device=device, dtype=torch.float32)
self.logits_start = []
self.decode_batch = batch_decode
self.num_tokens = batch_decode + sum(prefill_l)
self.batch_size = batch_decode + batch_prefill
for i, prefill_query_info in enumerate(prefill_querys_info):
if prefill_query_info is not None:
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)
self.position_ids = torch.concat((self.position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
self.tokens = torch.concat((self.tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
for decode_query_info in decode_querys_info:
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)
self.position_ids = torch.concat((self.position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
if decode_query_info.active_position > 0:
self.tokens = torch.concat((self.tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
else:
self.tokens = torch.concat((self.tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
self.q_indptr = self.q_indptr.contiguous()
self.kv_indptr = self.kv_indptr.contiguous()
self.kv_indices = self.kv_indices.contiguous()
self.kv_len = self.kv_len.contiguous()
self.kv_last_page_len = self.kv_last_page_len.contiguous()
self.position_ids = self.position_ids.contiguous()
self.tokens = self.tokens.contiguous()
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
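# Worked example of the CSR-style layout built above (hypothetical sizes, page_size=256):
# one prefill query (active_position=0, prefill_l=300) plus one decode query
# (active_position=511) produce
#   q_indptr         = [0, 300, 301]
#   kv_indptr        = [0, 2, 4]          # two KV pages each: ceil(300/256), ceil(512/256)
#   kv_len           = [300, 512]
#   kv_last_page_len = [44, 256]          # 300 % 256 = 44; 512 % 256 == 0 -> full page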
def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):
batch_decode = len(decode_querys_info)
batch_prefill = len(prefill_querys_info)
self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)
self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)
self.kv_len = torch.tensor([], device=device, dtype=torch.int32)
self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)
new_position_ids = torch.tensor([], device=device, dtype=torch.int32)
new_tokens = torch.tensor([], device=device, dtype=torch.int32)
self.temperatures = torch.tensor([], device=device, dtype=torch.float32)
self.top_ps = torch.tensor([], device=device, dtype=torch.float32)
self.logits_start = []
self.decode_batch = batch_decode
self.num_tokens = batch_decode + sum(prefill_l)
self.batch_size = batch_decode + batch_prefill
for i, prefill_query_info in enumerate(prefill_querys_info):
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0
# print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}")
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)
new_position_ids = torch.concat((new_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)
new_tokens = torch.concat((new_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)
self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
for decode_query_info in decode_querys_info:
decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)
self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)
self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)
self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)
new_position_ids = torch.concat((new_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)
if decode_query_info.active_position > 0:
new_tokens = torch.concat((new_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)
else:
new_tokens = torch.concat((new_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)
self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)
self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)
self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)
self.q_indptr = self.q_indptr.contiguous()
self.kv_indptr = self.kv_indptr.contiguous()
self.kv_indices = self.kv_indices.contiguous()
self.kv_len = self.kv_len.contiguous()
self.kv_last_page_len = self.kv_last_page_len.contiguous()
self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)
# copy new_position_ids and new_tokens to self.position_ids and self.tokens
# print("new_position_ids: ", new_position_ids)
# self.print()
self.position_ids[:new_position_ids.size(0)].copy_(new_position_ids)
self.position_ids[new_position_ids.size(0):].zero_()
self.tokens[:new_tokens.size(0)].copy_(new_tokens)
forward_minibatchs: list[ForwardMiniBatch]
batch_size: int
minibatch: ForwardMiniBatch
def __init__(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, device=None, tokens: torch.Tensor = None):
if batch is None:
return
prefill_minibatches = batch.prefill_mini_batches
decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]
prefill_querys_info = []
prefill_s = []
prefill_l = []
decode_querys_info = []
self.batch_size = 1
for (id, s, l) in prefill_minibatches:
prefill_querys_info.append(query_manager.query_map[id])
prefill_s.append(s)
prefill_l.append(l)
for decode_batch_idx in decode_mini_batches:
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
query_manager.query_map[decode_batch_idx].decode_start_time = time.time()
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)
self.minibatch = minibatch
@classmethod
def gen_max_forward_batch(
cls,
device=None,
tokens: torch.Tensor = None,
num_mini_batches: int = 1,
max_seq_length: int = 1024, # TODO: add to yaml
prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config
prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size,
gen_prefill: bool = True,
decode_batch_size: int = Config().max_decode_batch_size,
decode_active_position: torch.Tensor = None,
page_size = 256,
cuda_lens = 1
):
instance = cls()
instance.batch_size = num_mini_batches
page_size = page_size
prefill_query_info = []
offset = 0
if gen_prefill and prefill_query_length != 0:
for i in range(Config().max_prefill_batch_size):
prefill_query_info.append(QueryInfo(i, prefill_query_length, max_seq_length, page_size, device, offset=offset))
offset += max_seq_length // page_size
decode_querys_info = []
for i in range(min(decode_batch_size, cuda_lens)):
query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset)
offset += max_seq_length // page_size
if tokens is not None:
query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens)
if decode_active_position is None:
query_info.active_position = prefill_active_length
else:
query_info.active_position = decode_active_position[i]
decode_querys_info.append(query_info)
if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
decode_querys_info.append(query_info)
instance.minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size)
return instance
def fill(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, page_size = 256):
if batch is None:
return
prefill_minibatches = batch.prefill_mini_batches
decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]
prefill_querys_info = []
prefill_s = []
prefill_l = []
decode_querys_info = []
self.batch_size = 1
for (id, s, l) in prefill_minibatches:
prefill_querys_info.append(query_manager.query_map[id])
prefill_s.append(s)
prefill_l.append(l)
for decode_batch_idx in decode_mini_batches:
if query_manager.query_map[decode_batch_idx].decode_start_time is None:
query_manager.query_map[decode_batch_idx].decode_start_time = time.time()
decode_querys_info.append(query_manager.query_map[decode_batch_idx])
self.minibatch.fill(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device=query_manager.device, page_size=page_size)
class ForwardBatchOutput:
logits: list[torch.Tensor]
num_batchs: int
batch_sizes: list[int]
generated_tokens_num: list[int]
lm_start: list[int]
temperatures: list[torch.Tensor]
top_ps: list[torch.Tensor]
def __init__(self):
self.logits = []
self.batch_sizes = []
self.generated_tokens_num = []
self.top_ps = []
self.temperatures = []
pass
"""
Date: 2024-11-07 07:02:20
LastEditors: djw
LastEditTime: 2024-12-10 08:48:32
"""
import torch
from torch import nn
import queue
import signal
from typing import AsyncIterable
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
import asyncio
import multiprocessing
import time
import torch.multiprocessing as mp
import random
import torch.distributed as dist
import zmq
import tempfile
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.config.config import Config
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.settings import sched_ext
def pad_num_tokens(num_tokens):
return (num_tokens + 63) // 64 * 64
def deduplicate_and_sort(lst):
return sorted(set(lst))
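# Examples: pad_num_tokens(65) == 128, pad_num_tokens(64) == 64;
# deduplicate_and_sort([3, 1, 64, 3]) == [1, 3, 64]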
class ModelRunner:
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
model: KDeepseekV3ForCausalLM
input: ForwardBatchInput | list[ForwardBatchInput]
output: ForwardBatchOutput
def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256):
self.stream = torch.cuda.Stream(device=device)
# commented out for now
self.model = model # Compile and move model to the specified device
self.device = device
self.input = None
self.features_buf = None
self.output = None
self.graph_memory_pool = None
self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
self.use_cuda_graph = use_cuda_graph
self.model_time = 0
self.page_size = page_size
# GPU timing for model execution
self.start_model_event = torch.cuda.Event(enable_timing=True)
self.end_model_event = torch.cuda.Event(enable_timing=True)
if isinstance(self.cuda_graphs, list):
self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
else:
self.graphs = torch.cuda.CUDAGraph()
self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
self.num_mini_batches = num_mini_batches
self.max_chunk_size = max_chunk_size
self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
def warmup(self):
def capture_graphs(cuda_graph_idx=-1):
if cuda_graph_idx != -1:
with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
else:
with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream):
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
self.graph_memory_pool = self.graphs.pool()
if isinstance(self.cuda_graphs, list):
self.input = []
self.features_buf = []
self.outputs_buf = []
self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
for i in range(len(self.cuda_graphs)):
prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 # @TODO: only supports 2 prefill batches
self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens = self.cuda_graphs[i]))
self.features_buf.append(self.model.batch_embeddings(self.input[i]))
batch_size = self.input[i].minibatch.q_indptr.size(0)-1
num_tokens = self.features_buf[i][0].size(0)
print("capturing cuda graph", batch_size, num_tokens)
self.bsz_tensor_buf[0] = batch_size
self.num_tokens_tensor_buf[0] = num_tokens
self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)
self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)
self.outputs_buf.append(None)
torch.cuda.synchronize()
for warm_up_iters in range(11):
with torch.cuda.stream(self.stream):
self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i])
torch.cuda.synchronize()
capture_graphs(i)
with torch.cuda.stream(self.stream):
self.graphs[i].replay()
self.sync(calc_time=False)
print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.")
else:
self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches)
self.features_buf = self.model.batch_embeddings(self.input)
batch_size = self.input.minibatch.q_indptr.size(0)-1
num_tokens = self.features_buf[0].size(0)
self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device)
self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device)
self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
torch.cuda.synchronize()
for warm_up_iters in range(11):
with torch.cuda.stream(self.stream):
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
torch.cuda.synchronize()
def capture_graphs():
with torch.cuda.graph(self.graphs, stream=self.stream):
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
# self.graph_memory_pool = self.graphs.pool()
capture_graphs()
with torch.cuda.stream(self.stream):
self.graphs.replay()
self.sync(calc_time=False)
print("warmup finished.")
def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):
with torch.cuda.stream(self.stream):
batch_size = len(batch.prefill_mini_batches) # TODO: calc this
num_tokens = 0
for i in range(len(batch.decode_mini_batches)):
batch_size += len(batch.decode_mini_batches[i])
num_tokens += len(batch.decode_mini_batches[i])
print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},')
for i in range(len(batch.prefill_mini_batches)):
num_tokens += batch.prefill_mini_batches[i][2]
print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},')
if isinstance(self.cuda_graphs, list):
# cuda_graph_idx is the smallest i such that self.cuda_graphs[i] >= num_tokens
cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
if cuda_graph_idx == len(self.cuda_graphs):
assert False, "num_tokens is too large"
else:
cuda_graph_idx = -1
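# Example (bucket values depend on Config): if self.cuda_graphs == [1, 2, 3, 64, 256, 4096]
# and num_tokens == 50, the smallest captured bucket >= 50 is 64, so that graph is replayed;
# if num_tokens exceeds every bucket, the assert above fires.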
if self.use_cuda_graph:
if cuda_graph_idx != -1:
self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
else:
self.input.fill(batch, query_manager, self.page_size)
else:
self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)
if cuda_graph_idx != -1 and self.use_cuda_graph:
self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)
else:
self.features = self.model.batch_embeddings(self.input, device=self.device)
self.bsz_tensor_buf.copy_(batch_size)
self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))
if self.use_cuda_graph:
if cuda_graph_idx != -1:
self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
else:
self.features_buf[0].copy_(self.features[0], non_blocking=True)
"""
if num_tokens_0 > 64:
padded_num_tokens_0 = pad_num_tokens(num_tokens_0)
self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0
"""
#self.input.forward_minibatchs[0].print()
# print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches])
# print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}")
# self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors)
"""
if self.use_cuda_graph:
print("before replay features_buf", self.features_buf[0])
print("features_buf addr", self.features_buf[0].data_ptr())
else:
print("before run features", self.features[0])
"""
if cuda_graph_idx != -1 and self.use_cuda_graph:
self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
self.start_model_event.record(self.stream)
page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
if self.use_cuda_graph:
self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
self.replay(cuda_graph_idx)
self.output = ForwardBatchOutput()
self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
else:
self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
self.end_model_event.record(self.stream)
else:
self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
self.start_model_event.record(self.stream)
page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
if self.use_cuda_graph:
self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
self.replay(cuda_graph_idx)
self.output = ForwardBatchOutput()
self.output.top_ps.append(self.input.minibatch.top_ps)
self.output.temperatures.append(self.input.minibatch.temperatures)
self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone())
else:
self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]
self.output.top_ps.append(self.input.minibatch.top_ps)
self.output.temperatures.append(self.input.minibatch.temperatures)
self.end_model_event.record(self.stream)
if not self.use_cuda_graph:
self.output.num_batchs = self.input.batch_size
else:
self.output.num_batchs = self.input[cuda_graph_idx].batch_size
def replay(self, cuda_graph_idx=-1):
with torch.cuda.stream(self.stream):
if cuda_graph_idx != -1:
self.graphs[cuda_graph_idx].replay()
else:
self.graphs.replay()
def sync(self, calc_time = True):
self.stream.synchronize()
if calc_time:
self.model_time = self.start_model_event.elapsed_time(self.end_model_event) # In ms
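# Illustrative driver loop (a sketch; `model`, `batch` and `query_manager` come from the
# surrounding serving code and are assumed here):
#   runner = ModelRunner(model, device="cuda", use_cuda_graph=True)
#   runner.warmup()
#   runner.run(batch, query_manager)
#   runner.sync()                      # waits for the stream and records model_time (ms)
#   logits = runner.output.logits[0]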
'''
Date: 2024-11-14 12:23:45
LastEditors: djw
LastEditTime: 2024-11-20 04:06:23
'''
import torch
from ktransformers.server.balance_serve.settings import sched_ext
import random
import time
class QueryInfo:
id: int
active_position: int
query_length: int
is_prefill: int
block_index: torch.Tensor
query_tokens: torch.Tensor
stop_criteria: list[torch.Tensor]
temperature: float
top_p: float
max_length: int
def __init__(self, id, query_length: int, max_length: int, page_size: int, device: torch.device, is_prefill: bool = True, offset: int = 0, active_position: int = 0, temperature: float = 0.01, top_p: float = 1.0):
self.id = id
self.is_prefill = is_prefill
self.active_position = active_position
self.max_length = max_length - 1
self.query_tokens = torch.zeros((max_length,), dtype=torch.int, device = device)
self.stop_criteria = []
self.block_index = torch.arange(offset, offset + (max_length + active_position + page_size - 1) // page_size, dtype=torch.int, device = device)
self.query_length = query_length
self.enqueue_time = time.time()
self.decode_start_time = None
self.speculative_token = {} # {position: (accept, token)}
self.temperature = temperature
self.top_p = top_p
def check_stop(self):
if self.active_position >= self.max_length - 2:
return True
# iterate over every stop condition
for stop_tensor in self.stop_criteria:
stop_len = len(stop_tensor)
# if the stop condition is longer than the tokens generated so far, skip it
if stop_len >= self.active_position:
continue
#print(f"stop_tensor: {stop_tensor}, stop_len: {stop_len}, active_position: {self.active_position}, query_token: {self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1]}")
if (torch.equal(self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1], stop_tensor) and self.active_position) or self.max_length <= self.active_position + 3:
self.life_time = time.time() - self.enqueue_time
self.decode_duration_time = time.time() - self.decode_start_time
self.decode_tps = (self.active_position - self.query_length) / self.decode_duration_time
print(f"prefill length: {self.query_length}, prefill time: {self.prefill_duration_time}, prefill tps {self.prefill_tps}, decode length: {self.active_position - self.query_length}, decode time: {self.decode_duration_time}, decode tps {self.decode_tps}")
return True # a matching stop condition was found
return False # no stop condition matched
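# Example of the stop-window check above (hypothetical values): with stop_tensor of
# length 1 and active_position = 10, the slice query_tokens[8:9] is compared against the
# stop sequence, i.e. a window that ends one token before the current position.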
def print(self):
print(f"active_position: {self.active_position}, query_length: {self.query_length}, is_prefill: {self.is_prefill}")
print(f"block_index_shape: {self.block_index.shape}, query_tokens_shape: {self.query_tokens.shape}")
class QueryManager:
max_length: int = 65536
page_size: int = 256
device: torch.device
query_map : dict[int, QueryInfo]
def __init__(self, max_length = 65536, page_size = 256, device = torch.device('cuda')):
self.max_length = max_length
self.page_size = page_size
self.device = device
self.query_map = {}
def add_query(self, batch: sched_ext.BatchQueryTodo):
for i in range(len(batch.query_ids)):
id = batch.query_ids[i]
if id not in self.query_map:
print(f"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, batch_query_tokens: {batch.query_tokens[i].shape}, batch.block_indexes: {batch.block_indexes[i]}")
assert batch.query_tokens[i].size(0) < self.max_length, "query max length in batchquerytodo exceeds internal max_length"
query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p)
query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device))
for stop_token_list in batch.stop_criteria[i]:
query_info.stop_criteria.append(torch.tensor(stop_token_list, dtype=torch.int, device = self.device))
block_num = batch.block_indexes[i].size(0)
query_info.block_index[:block_num].copy_(batch.block_indexes[i].to(self.device))
self.query_map[id] = query_info
prefill_mini_batches = batch.prefill_mini_batches
for (prefill_id, s, l) in prefill_mini_batches:
if prefill_id == id:
self.query_map[prefill_id].active_position = s
def update(self, batch: sched_ext.BatchQueryTodo) -> list[sched_ext.QueryUpdate]:
query_updates = []
prefill_mini_batches = batch.prefill_mini_batches
for (id, s, l) in prefill_mini_batches:
if id not in self.query_map:
assert False, f"query id {id} not found in query_map"
# update query_info
query_info = self.query_map[id]
query_info.active_position += l
if query_info.active_position >= query_info.query_length and query_info.is_prefill:
query_info.is_prefill = False
query_info.prefill_duration_time = time.time() - query_info.enqueue_time
query_info.prefill_tps = query_info.query_length / query_info.prefill_duration_time
# generate schedule query_update
query_update = sched_ext.QueryUpdate()
query_update.id = id
query_update.ok = True
query_update.is_prefill = query_info.is_prefill
query_update.active_position = query_info.active_position
# if(not query_info.is_prefill):
query_updates.append(query_update)
decode_mini_batches = batch.decode_mini_batches
for ids in decode_mini_batches:
for id in ids:
if id not in self.query_map:
assert False, f"query id {id} not found in query_map"
query_info = self.query_map[id]
query_info.active_position += 1
query_update = sched_ext.QueryUpdate()
query_update.id = id
query_update.ok = True
query_update.is_prefill = query_info.is_prefill
query_update.decode_done = query_info.check_stop()
query_update.active_position = query_info.active_position
query_updates.append(query_update)
return query_updates
from .orchestrator import BatchedPenalizerOrchestrator
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
from .penalizers.presence_penalty import BatchedPresencePenalizer
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
__all__ = [
"BatchedFrequencyPenalizer",
"BatchedMinNewTokensPenalizer",
"BatchedPresencePenalizer",
"BatchedRepetitionPenalizer",
"BatchedPenalizerOrchestrator",
]
import abc
import dataclasses
import typing
import torch
@dataclasses.dataclass
class _ReqLike:
origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
@dataclasses.dataclass
class _BatchLike:
reqs: typing.List[_ReqLike]
def batch_size(self):
return len(self.reqs)
class BatchedPenalizerOrchestrator:
batch: _BatchLike
device: str
vocab_size: int
penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
def __init__(
self,
vocab_size: int,
batch: _BatchLike,
device: str,
Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
):
self.vocab_size = vocab_size
self.batch = batch
self.device = device
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
is_required = False
for penalizer in self.penalizers.values():
pen_is_required = penalizer.prepare_if_required()
is_required |= pen_is_required
self.is_required = is_required
if self.is_required:
self.cumulate_input_tokens(
input_ids=[req.origin_input_ids for req in self.reqs()]
)
def reqs(self):
return self.batch.reqs
def batch_size(self):
return self.batch.batch_size()
def cumulate_input_tokens(
self,
input_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
"""
Feed the input tokens to the penalizers.
Args:
input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
"""
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_input_tokens(input_ids=token_ids)
def cumulate_output_tokens(
self,
output_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
"""
Feed the output tokens to the penalizers.
Args:
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
"""
if not self.is_required:
return
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_output_tokens(output_ids=token_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
"""
Apply the penalizers to the logits.
Note that it may apply the penalizers in-place.
Args:
logits (torch.Tensor): The logits to apply the penalizers to.
Returns:
torch.Tensor: The logits after applying the penalizers.
"""
if not self.is_required:
return
for penalizer in self.penalizers.values():
logits = penalizer.apply(logits)
return logits
def filter(
self,
indices_to_keep: typing.List[int],
indices_tensor_to_keep: torch.Tensor = None,
):
"""
Filter the penalizers based on the indices to keep in the batch.
Args:
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
"""
if not self.is_required:
return
empty_indices = len(indices_to_keep) == 0
is_required = False
for penalizer in self.penalizers.values():
tmp_is_required = penalizer.is_required()
is_required = is_required or tmp_is_required
if not tmp_is_required or empty_indices:
penalizer.teardown()
else:
# create tensor index only when it's needed
if indices_tensor_to_keep is None:
indices_tensor_to_keep = torch.tensor(
indices_to_keep, dtype=torch.int32, device=self.device
)
penalizer.filter(
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
self.is_required = is_required
def merge(self, their: "BatchedPenalizerOrchestrator"):
"""
Merge the penalizers of another orchestrator into this one.
Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
Each unprepared penalizer has to be prepared (creating tensors, etc.) before merging.
This step requires the original batch.reqs, before it gets merged with other batch.reqs.
Args:
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
"""
if not self.is_required and not their.is_required:
return
self.is_required |= their.is_required
for Penalizer, their_penalizer in their.penalizers.items():
if Penalizer not in self.penalizers:
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
self.penalizers[Penalizer].merge(their_penalizer)
class _TokenIDs:
"""
A class that wraps token IDs to provide additional utility functions to penalizers.
Attributes:
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that these token IDs belong to.
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
cached_counts (torch.Tensor): The cached occurrence count tensor.
"""
orchestrator: BatchedPenalizerOrchestrator
token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
cached_counts: torch.Tensor = None
def __init__(
self,
orchestrator: BatchedPenalizerOrchestrator,
token_ids: typing.Union[
typing.List[torch.Tensor], typing.List[typing.List[int]]
],
):
self.orchestrator = orchestrator
if not isinstance(token_ids[0], torch.Tensor):
token_ids = [
torch.tensor(
data=ids, dtype=torch.int64, device=self.orchestrator.device
)
for ids in token_ids
]
self.token_ids = token_ids
def occurrence_count(self) -> torch.Tensor:
"""
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in that row's sequence.
Returns:
torch.Tensor: The occurrence count tensor.
"""
if self.cached_counts is not None:
return self.cached_counts
token_ids = self.token_ids
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.unsqueeze(1)
# needs to be long to be used as index in scatter_add
if token_ids.dtype != torch.int64:
token_ids = token_ids.to(torch.int64)
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=token_ids,
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.int64,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_token_ids,
src=torch.ones_like(padded_token_ids),
)[
:, : self.orchestrator.vocab_size
]
return self.cached_counts
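# Worked example (hypothetical ids, vocab_size = 10): for token_ids [[5, 5, 7], [7]] the
# sequences are padded with vocab_size (10) to [[5, 5, 7], [7, 10, 10]], scatter_add_
# accumulates one per occurrence, and the padding column is sliced off, giving
#   row 0: count 2 at index 5, count 1 at index 7
#   row 1: count 1 at index 7 (the pads land in the dropped column 10)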
class _BatchedPenalizer(abc.ABC):
"""
An abstract class for a batched penalizer.
"""
orchestrator: BatchedPenalizerOrchestrator
_is_prepared: bool = False
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
def is_prepared(self) -> bool:
return self._is_prepared
def is_required(self) -> bool:
return self._is_required()
def prepare(self):
if not self.is_prepared():
self._prepare()
self._is_prepared = True
def prepare_if_required(self):
if self.is_required():
self.prepare()
return True
else:
return False
def teardown(self):
if self.is_prepared():
self._teardown()
self._is_prepared = False
def cumulate_input_tokens(self, input_ids: _TokenIDs):
if not self.is_prepared():
return
self._cumulate_input_tokens(input_ids=input_ids)
def cumulate_output_tokens(self, output_ids: _TokenIDs):
if not self.is_prepared():
return
self._cumulate_output_tokens(output_ids=output_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.is_prepared():
return logits
return self._apply(logits=logits)
def filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
if not self.is_prepared():
return
self._filter(
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
def merge(self, their: "_BatchedPenalizer"):
if not self.is_prepared() and not their.is_prepared():
return
self.prepare()
their.prepare()
self._merge(their)
@abc.abstractmethod
def _is_required(self) -> bool:
"""
Check if the penalizer is required to be prepared.
"""
pass
@abc.abstractmethod
def _prepare(self):
"""
Prepare the penalizer.
Usually, this is where the penalizer initializes its tensors.
"""
pass
@abc.abstractmethod
def _teardown(self):
"""
Tear down the penalizer.
Usually, this is where the penalizer frees its tensors.
"""
pass
@abc.abstractmethod
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
"""
Cumulate the input tokens.
Orchestrator will call this function to feed the input tokens to the penalizer.
"""
pass
@abc.abstractmethod
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
"""
Cumulate the output tokens.
Orchestrator will call this function to feed the output tokens to the penalizer.
"""
pass
@abc.abstractmethod
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
"""
Apply the penalizer to the logits.
Penalizers can modify the logits in-place if needed.
"""
pass
@abc.abstractmethod
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
"""
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
"""
pass
@abc.abstractmethod
def _merge(self, their: "_BatchedPenalizer"):
"""
Merge the penalizer with another penalizer.
"""
pass
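# Illustrative wiring (a sketch; the request objects are assumed to expose the
# sampling_params fields that the concrete penalizers read):
#   orchestrator = BatchedPenalizerOrchestrator(
#       vocab_size=vocab_size, batch=batch, device="cuda",
#       Penalizers={BatchedRepetitionPenalizer, BatchedPresencePenalizer},
#   )
#   if orchestrator.is_required:
#       logits = orchestrator.apply(logits)
#   orchestrator.cumulate_output_tokens(output_ids=sampled_ids)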
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedFrequencyPenalizer(_BatchedPenalizer):
"""
Frequency penalizer penalizes tokens based on their frequency in the output.
"""
frequency_penalties: torch.Tensor = None
cumulated_frequency_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.frequency_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_frequency_penalties = (
torch.tensor(
data=[0.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.frequency_penalties = (
torch.tensor(
data=[
req.sampling_params.frequency_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_frequency_penalties)
)
def _teardown(self):
del self.frequency_penalties
del self.cumulated_frequency_penalties
self.frequency_penalties = None
self.cumulated_frequency_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
self.cumulated_frequency_penalties += (
self.frequency_penalties * output_ids.occurrence_count()
)
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_frequency_penalties
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedFrequencyPenalizer"):
self.frequency_penalties = torch.cat(
[self.frequency_penalties, their.frequency_penalties], dim=0
)
self.cumulated_frequency_penalties = torch.cat(
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
dim=0,
)
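# Worked example: with frequency_penalty = 0.5, a token that has appeared 3 times in the
# output accumulates a penalty of 1.5, which _apply subtracts from that token's logit.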
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
"""
Min-new-tokens penalizer suppresses stop tokens until the output reaches the requested minimum length.
"""
min_new_tokens: torch.Tensor = None
stop_token_penalties: torch.Tensor = None
len_output_tokens: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
)
def _prepare(self):
self.min_new_tokens = torch.tensor(
data=[
req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
],
dtype=torch.int32,
device=self.orchestrator.device,
).unsqueeze_(1)
padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=[
torch.tensor(
data=(
list(
(req.sampling_params.stop_token_ids or set())
| (req.tokenizer.additional_stop_token_ids or set())
| {req.tokenizer.eos_token_id}
)
),
dtype=torch.int64,
device=self.orchestrator.device,
)
for req in self.orchestrator.reqs()
],
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.stop_token_penalties = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.float32,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_stop_token_ids,
src=torch.full_like(
input=padded_stop_token_ids,
dtype=torch.float32,
fill_value=float("-inf"),
device=self.orchestrator.device,
),
)[
:, : self.orchestrator.vocab_size
]
self.len_output_tokens = torch.zeros(
size=(self.orchestrator.batch_size(), 1),
dtype=torch.int32,
device=self.orchestrator.device,
)
def _teardown(self):
del self.min_new_tokens
del self.stop_token_penalties
del self.len_output_tokens
self.min_new_tokens = None
self.stop_token_penalties = None
self.len_output_tokens = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
self.len_output_tokens += 1
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
logits[mask] += self.stop_token_penalties[mask]
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
self.min_new_tokens = torch.cat(
[self.min_new_tokens, their.min_new_tokens], dim=0
)
self.stop_token_penalties = torch.cat(
[self.stop_token_penalties, their.stop_token_penalties], dim=0
)
self.len_output_tokens = torch.cat(
[self.len_output_tokens, their.len_output_tokens], dim=0
)
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedPresencePenalizer(_BatchedPenalizer):
"""
Presence penalizer penalizes tokens based on their presence in the output.
"""
presence_penalties: torch.Tensor = None
cumulated_presence_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.presence_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_presence_penalties = (
torch.tensor(
data=[0.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.presence_penalties = (
torch.tensor(
data=[
req.sampling_params.presence_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_presence_penalties)
)
def _teardown(self):
del self.presence_penalties
del self.cumulated_presence_penalties
self.presence_penalties = None
self.cumulated_presence_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_presence_penalties
return logits
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedPresencePenalizer"):
self.presence_penalties = torch.cat(
[self.presence_penalties, their.presence_penalties], dim=0
)
self.cumulated_presence_penalties = torch.cat(
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
dim=0,
)
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Repetition penalizer penalizes tokens based on their repetition in the input and output.
"""
repetition_penalties: torch.Tensor = None
cumulated_repetition_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_repetition_penalties = (
torch.tensor(
data=[1.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_repetition_penalties)
)
def _teardown(self):
del self.repetition_penalties
del self.cumulated_repetition_penalties
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
mask = input_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
return torch.where(
logits > 0,
logits / self.cumulated_repetition_penalties,
logits * self.cumulated_repetition_penalties,
)
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)
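# --- Illustrative sketch (not part of this commit): the repetition-penalty rule in _apply
# above divides positive logits and multiplies negative logits by the penalty, so repeated
# tokens become less likely in both cases. Toy numbers only:
import torch

toy_logits = torch.tensor([[2.0, -2.0]])
penalties = torch.tensor([[1.5, 1.5]])   # 1.0 would mean "no penalty"
penalized = torch.where(toy_logits > 0, toy_logits / penalties, toy_logits * penalties)
# -> tensor([[ 1.3333, -3.0000]])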
'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging
import torch
from torch import nn
from transformers import GenerationConfig
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_probs,
top_k_top_p_sampling_from_logits,
top_p_renorm_probs,
)
logger = logging.getLogger(__name__)
class SamplingOptions():
# Batched sampling params
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
# All requests use greedy sampling
is_all_greedy: bool
# Dispatch in CUDA graph
need_min_p_sampling: bool
def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
if pretrained_config is None and temperatures is None:
self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = True
else:
if temperatures is not None:
self.temperatures = temperatures.unsqueeze(-1)
else:
self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)
if top_ps is not None:
self.top_ps = top_ps.unsqueeze(-1)
else:
self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = False
class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
logits: torch.Tensor,
sampling_config: SamplingOptions = None,
):
if sampling_config is None:
sampling_config = SamplingOptions()
logits = logits.contiguous()
origin_logits = logits.clone()
if sampling_config.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
probs = logits
batch_next_token_ids = torch.argmax(logits, -1)
else:
# Post process logits
logits.div_(sampling_config.temperatures)
max_top_k_round, batch_size = 32, logits.shape[0]
if sampling_config.need_min_p_sampling:
probs = torch.softmax(logits, dim=-1)
logits = None
del logits
probs = top_k_renorm_probs(probs, sampling_config.top_ks)
probs = top_p_renorm_probs(probs, sampling_config.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, sampling_config.min_ps
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
else:
# TODO: use a different kernel when top_k / top_p are not needed
# TODO: compute and return real probabilities here as well
probs = logits
batch_next_token_ids = top_k_top_p_sampling_from_logits(
logits,
sampling_config.top_ks,
sampling_config.top_ps,
filter_apply_order="joint",
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
return batch_next_token_ids.to(torch.int32), probs
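# --- Illustrative usage (not part of this commit), assuming a CUDA device, flashinfer
# installed, and logits of shape [batch_size, vocab_size]:
#
#   sampler = Sampler()
#   opts = SamplingOptions(bsz=logits.shape[0], device=logits.device,
#                          pretrained_config=generation_config)  # both None -> greedy decoding
#   next_token_ids, probs = sampler(logits, opts)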
from datetime import datetime
import os
from typing import Optional
import zmq
import pickle
import threading
import torch.multiprocessing as mp
import sys
current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import argparse
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
if mp.get_start_method(allow_none=True) is None:
print('set start method')
mp.set_start_method('spawn')
else:
print(f'start method already set to {mp.get_start_method(allow_none=True)}')
class SchedulerServer:
def __init__(self, settings, main_args):
# Create and initialize the Scheduler instance
self.sched = sched_ext.create_scheduler(settings)
# Initialize the ZeroMQ context and sockets
self.context = zmq.Context()
self.frontend = self.context.socket(zmq.ROUTER)
print(f"sched zmq rpc server on port {main_args.sched_port}")
self.frontend.bind(f"tcp://*:{main_args.sched_port}")
# Create an internal DEALER socket for communicating with worker threads
self.backend = self.context.socket(zmq.DEALER)
self.backend.bind("inproc://backend")
# Start the scheduler
def run_scheduler(self):
self.sched.run()
# Stop the scheduler
def stop_scheduler(self):
self.sched.stop()
# Handle client requests
def start_proxy(self):
# Use ZMQ's built-in proxy to dispatch frontend requests to the backend worker threads
zmq.proxy(self.frontend, self.backend)
# Worker thread: process incoming requests
def worker_routine(self):
worker = self.context.socket(zmq.REP)
worker.connect("inproc://backend")
while True:
try:
# Receive a client request
message = worker.recv()
data = pickle.loads(message)
method = data.get('method')
params = data.get('params', {})
# print(f"Received request: {method}")
if method == 'add_query':
query_add = params.get('query') # a QueryAdd object, passed through directly
# Add the query
query_id = self.sched.add_query(query_add)
# Send the response
response = {'status': 'ok', 'query_id': query_id}
worker.send(pickle.dumps(response))
elif method == 'cancel_query':
query_id = params.get('query_id')
# Assumes the Scheduler class implements a cancel() method
self.sched.cancel(query_id)
response = {'status': 'ok'}
worker.send(pickle.dumps(response))
elif method == 'update_last_batch':
updates = params.get('updates') # a list of QueryUpdate objects, passed through directly
# Update the last batch
batch_todo = self.sched.update_last_batch(updates)
# Send the batch_todo object directly
response = {'status': 'ok', 'batch_todo': batch_todo}
# print (batch_todo.query_lengths, batch_todo.query_ids)
worker.send(pickle.dumps(response))
elif method == 'get_inference_context':
inference_context = self.sched.get_inference_context()
data = {
"k_cache":inference_context.k_cache,
"v_cache":inference_context.v_cache
}
print(f"Serializing KVCache")
data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
# print(data)
response = {'status': 'ok', 'inference_context': data}
worker.send(pickle.dumps(response))
# response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
# print("k_cache update")
else:
# Unknown method
response = {'status': 'error', 'message': 'Unknown method'}
worker.send(pickle.dumps(response))
except Exception as e:
# Handle the exception and send an error response
response = {'status': 'error', 'message': str(e)}
worker.send(pickle.dumps(response))
# Start the RPC service
def start_rpc_service(self):
try:
print("Scheduler RPC service is running...")
# Run the scheduler in a separate thread
threading.Thread(target=self.run_scheduler, daemon=True).start()
# Start the worker threads
for _ in range(10): # adjust the number of worker threads as needed
threading.Thread(target=self.worker_routine, daemon=True).start()
# Start the proxy and begin listening for requests
self.start_proxy()
except KeyboardInterrupt:
print("Shutting down scheduler RPC service...")
self.stop_rpc_service()
# Stop the RPC service
def stop_rpc_service(self):
self.stop_scheduler()
self.frontend.close()
self.backend.close()
self.context.term()
def start_server(settings, main_args):
server = SchedulerServer(settings, main_args)
server.start_rpc_service()
# Add async client for webserver
class SchedulerClient:
def __init__(self, sched_port):
address=f'tcp://localhost:{sched_port}'
self.address = address
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(self.address)
print(f"Connected to server at {self.address}")
def __del__(self):
self.socket.close()
self.context.term()
def send_request(self, method, params=None):
if params is None:
params = {}
request = {
'method': method,
'params': params
}
# print(f'send request {request}')
self.socket.send(pickle.dumps(request))
response = self.socket.recv()
# print(response)
response = pickle.loads(response)
if response.get('status') == 'ok':
return response
else:
raise Exception(f"Error from server: {response.get('message')}")
def add_query(self, query):
response = self.send_request('add_query', {'query': query})
return response.get('query_id')
def cancel_query(self, query_id):
self.send_request('cancel_query', {'query_id': query_id})
def update_last_batch(self, updates):
response = self.send_request('update_last_batch', {'updates': updates})
# print(f"update_last_batch response {response}")
return response.get('batch_todo')
def rebuild_inferece_context(self,response):
data = response.get('inference_context')
inference_context = sched_ext.InferenceContext()
print('Rebuilding kvcache')
inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]
inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]
return inference_context
def get_inference_context_raw(self):
response = self.send_request('get_inference_context')
return response
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
with open(args.config, "rb") as f:
main_args = pickle.load(f)
settings = create_sched_settings(main_args)
start_server(settings, main_args)
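# --- Illustrative client-side flow (not part of this commit), assuming the RPC server
# above is already listening on main_args.sched_port:
#
#   client = SchedulerClient(sched_port=main_args.sched_port)
#   query_id = client.add_query(query_add)              # query_add: sched_ext.QueryAdd
#   batch = client.update_last_batch(query_updates)     # query_updates: list of QueryUpdate
#   ctx = client.rebuild_inferece_context(client.get_inference_context_raw())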
'''
Date: 2024-11-13 09:43:39
LastEditors: djw
LastEditTime: 2024-11-18 16:41:03
'''
import sys, os
import yaml, json
from time import sleep
import sched_ext
from transformers import AutoConfig
def create_sched_settings(args):
default_sample_options = sched_ext.SampleOptions()
model_name = os.path.basename(os.path.normpath(args.model_dir))
input_model_settings = sched_ext.ModelSettings()
input_model_settings.model_path = args.model_dir
input_model_settings.params_count = int(0)
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
input_model_settings.layer_count = model_config.num_hidden_layers
input_model_settings.num_k_heads = 1 # model_config["num_key_value_heads"]
input_model_settings.k_head_dim = 576
input_model_settings.bytes_per_params = 2
input_model_settings.bytes_per_kv_cache_element = 2
settings = sched_ext.Settings()
settings.model_name = model_name
settings.quant_type = "BF16"
settings.model_settings = input_model_settings
settings.page_size = args.page_size
settings.gpu_device_count = 1 # tp
settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
# settings.gpu_memory_size = args.cache_lens*576*2
settings.gpu_memory_size = args.gpu_memory_size
settings.memory_utilization_percentage = args.utilization_percentage
max_batch_size = args.max_batch_size
chunk_size = args.chunk_size
max_decode_batch_size = max_batch_size - 2
settings.max_batch_size = max_batch_size
settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
settings.sample_options = default_sample_options
settings.sched_metrics_port = args.sched_metrics_port
settings.gpu_only = args.memory_gpu_only
settings.use_self_defined_head_dim = True
settings.self_defined_head_dim = 576
settings.full_kv_cache_on_each_gpu = True
settings.k_cache_on = True
settings.v_cache_on = False
settings.kvc2_root_path = '/mnt/data/persist-kvc'
settings.kvc2_config_path = args.kvc2_config_dir
settings.memory_pool_size_GB = args.cpu_memory_size_GB
settings.evict_count = 40
settings.kvc2_metrics_port = args.kvc2_metrics_port
settings.load_from_disk = False
settings.save_to_disk = True
settings.strategy_name = args.sched_strategy
settings.auto_derive()
return settings
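# --- Illustrative note (not part of this commit): the GPU KV-cache budget the scheduler
# works with is roughly bytes_per_kv_cache_element * k_head_dim * layer_count * cache_lens.
# With the 576-dim K head, 2 bytes per element, and v_cache_on = False as set above, a
# hypothetical 61-layer model (the constant used in the Config change below) gives:
#
#   gpu_memory_size = 2 * 576 * 61 * cache_lens   # bytes of K cache to reserve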
@@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14
import os
import shutil
import yaml
+import psutil
from ktransformers.server.config.singleton import Singleton
from typing import Optional
@@ -33,12 +34,15 @@ class Config(metaclass=Singleton):
user_path: str = os.path.expanduser("~")
localstore_path: str = os.path.join(user_path, ".ktransformers")
+kvc2_config_dir = os.path.join(localstore_path, "kvc2")
config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}")
exit(-1)
if not os.path.exists(localstore_path):
os.mkdir(localstore_path)
+if not os.path.exists(kvc2_config_dir):
+os.mkdir(kvc2_config_dir)
if not os.path.exists(config_path):
shutil.copyfile(config_yaml, config_path)
with open(config_path, "r", encoding="utf-8") as fp:
@@ -60,11 +64,14 @@ class Config(metaclass=Singleton):
self.user_path: str = os.path.expanduser("~")
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
# log configs
-self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
+self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
+if not os.path.exists(self.log_dir):
+os.mkdir(self.log_dir)
self.log_file = cfg["log"]["file"]
self.log_level = cfg["log"]["level"]
self.backup_count = cfg["log"]["backup_count"]
+self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2")
# server configs
self.server: dict = cfg.get("server", {})
self.server_ip = self.server.get("ip", "0.0.0.0")
@@ -74,7 +81,7 @@ class Config(metaclass=Singleton):
# db configs
self.db_configs: dict = cfg.get("db", {})
self.db_type = self.db_configs.get("type", "")
-self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
+self.db_host = self.localstore_path
self.db_port = self.db_configs.get("port", "")
self.db_name = self.db_configs.get("database", "")
self.db_pool_size = self.db_configs.get("pool_size")
@@ -101,11 +108,6 @@ class Config(metaclass=Singleton):
self.optimize_config_path: Optional[str] = self.model.get(
"optimize_config_path", None
)
-self.paged = self.model.get("paged", True)
-self.total_context = self.model.get("total_context", 2**18)
-self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
-self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False)
@@ -138,7 +140,6 @@ class Config(metaclass=Singleton):
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
self.presence_penalty = self.model.get("presence_penalty", 0.0)
-self.max_response_tokens = self.model.get("max_response_tokens", 300)
self.response_chunk = self.model.get("response_chunk", 250)
self.no_code_formatting = self.model.get("no_code_formatting", False)
self.cache_8bit = self.model.get("cache_8bit", False)
@@ -155,8 +156,9 @@ class Config(metaclass=Singleton):
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
self.mount_web: bool = self.web.get("mount", False)
+# ext
self.ext: dict = cfg.get("ext", {})
-self.cpu_infer = self.ext.get("cpu_infer", 10)
+self.cpu_infer = psutil.cpu_count(logical=False) - 3
# file config
self.local_store_configs: dict = cfg.get("local_store", {})
@@ -169,7 +171,6 @@ class Config(metaclass=Singleton):
# long context config
self.long_context_config: dict = cfg.get("long_context", {})
-self.chunk_size = self.long_context_config.get("chunk_size", 4096)
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
self.block_size = self.long_context_config.get("block_size", 128)
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
@@ -187,3 +188,21 @@ class Config(metaclass=Singleton):
# local chat
self.local_chat_config: dict = cfg.get("local_chat", {})
self.prompt_file = self.local_chat_config.get("prompt_file", None)
+# asyncserver
+self.sched_strategy = cfg["async_server"]["sched_strategy"]
+self.sched_port = cfg["async_server"]["sched_port"]
+self.sched_metrics_port = cfg["async_server"]["sched_metrics_port"]
+self.kvc2_metrics_port = cfg["async_server"]["kvc2_metrics_port"]
+self.max_batch_size = cfg["async_server"]["max_batch_size"]
+self.page_size = cfg["attn"]["page_size"]
+self.chunk_size = cfg["attn"]["chunk_size"]
+self.memory_gpu_only = cfg["kvc2"]["gpu_only"]
+self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
+self.gpu_memory_size = 2*576*61*self.cache_lens
+self.utilization_percentage = 1.0 #cfg["kvc2"]["utilization_percentage"]
+self.cpu_memory_size_GB = cfg["kvc2"]["cpu_memory_size_GB"]
+# only support 2 prefill task
+self.max_prefill_batch_size = 2
+self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size
@@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import uvicorn
import sys
+import atexit
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
-from ktransformers.server.utils.create_interface import create_interface
+from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
-from ktransformers.server.backend.args import default_args
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
+import subprocess
+import tempfile
def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil()
@@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):
def create_app():
cfg = Config()
-app = FastAPI()
+if(hasattr(GlobalInterface.interface, "lifespan")):
+app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
+else:
+app = FastAPI()
if Config().web_cross_domain:
app.add_middleware(
CORSMiddleware,
@@ -108,11 +107,32 @@ def main():
arg_parser = ArgumentParser(cfg)
+# initialization
args = arg_parser.parse_args()
+if args.backend_type == "balance_serve":
+import pickle
+def cleanup():
+if sched_process.poll() is None:
+sched_process.terminate()
+with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+pickle.dump(args, temp_file)
+temp_file_path = temp_file.name
+current_file = __file__
+target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
+target_file = os.path.normpath(target_file)
+log_path = os.path.join(args.log_dir, "rpc.log")
+log = open(log_path, "a")
+sched_process = subprocess.Popen(
+["python3", target_file, "--config", temp_file_path],
+stdout=log,
+stderr=log
+)
+print("sched_rpc started with PID:", sched_process.pid)
+atexit.register(cleanup)
+create_interface(config=cfg, default_args=cfg)
app = create_app()
custom_openapi(app)
-create_interface(config=cfg, default_args=cfg)
run_api(
app=app,
host=args.host,
@@ -121,6 +141,5 @@ def main():
ssl_certfile=args.ssl_certfile,
)
if __name__ == "__main__":
main()
-torch >= 2.3.0,<=2.3.1
+torch >= 2.3.0
transformers == 4.43.2
fastapi >= 0.111.0
langchain >= 0.2.0
@@ -11,4 +11,6 @@ build
ninja
wheel
colorlog
fire
+zmq
+psutil
@@ -2,7 +2,7 @@ from typing import List, Optional
from typing_extensions import Literal
from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from ktransformers.server.schemas.base import Object
@@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel):
messages: List[Message]
model : str
stream : bool = False
-temperature: Optional[float] = None
+temperature: Optional[float] = Field(default=1.0)
-top_p: Optional[float] = None
+top_p: Optional[float] = Field(default=1.0)
def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]
@@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
def create_interface(config: Config, default_args: ConfigArgs):
if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
@@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs):
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
elif config.backend_type == 'ktransformers':
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
+elif config.backend_type == 'balance_serve':
+from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
else:
raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)
@@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs):
class GlobalContextManager:
context_manager: ThreadContextManager
class GlobalInterface:
interface: TransformersInterface | KTransformersInterface | ExllamaInterface
-def get_thread_context_manager() -> ThreadContextManager:
+def get_thread_context_manager() -> GlobalContextManager:
return GlobalContextManager.context_manager
-def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
+def get_interface() -> GlobalInterface:
return GlobalInterface.interface
import argparse
import random
import time
import json
import requests
import pandas as pd
from datasets import load_dataset
import os
import concurrent.futures
import threading
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
class DataEvaluator:
def __init__(self):
self.data = []
def load_data(self, file_path):
"""
Load data from the dataset; each record corresponds to one instance.
"""
ds = load_dataset(file_path, "all")
df = pd.DataFrame(ds['test'])
for _, row in df.iterrows():
self.data.append(row.to_dict())
def get_prompt(self, record):
"""
Combine the hint with the record to build the full question prompt.
"""
options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
return prompt
def post_processing(self, text):
"""
Post-process the generated text and extract the final answer (return only the last character).
"""
text = text.lstrip('\n').split('\n')[-1]
return text[-1:]
def score(self, pred, answer):
"""
Compare the prediction with the reference answer and return the score.
"""
if pred == answer:
return 1
return 0
def generate_text(api_url, question, model_name, stream=False):
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': 'Bearer ' # fill in an API key here if needed
}
data = {
"messages": [{"content": question, "role": "user"}],
"model": model_name,
"stream": stream,
}
print("POST data:", data)
response = requests.post(api_url, headers=headers, json=data, timeout=5000000)
if response.status_code == 200:
result = response.json()
return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
else:
print(f"API Request failed with status code {response.status_code}")
return None
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
start_total_time = time.time()
total_score = 0
results = []
file_lock = threading.Lock()
# Shuffle the data and take the number of instances to test
random.seed(42)
random.shuffle(data_evaluator.data)
data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]
batch_size = 10 # at most 10 instances per batch
def worker(index, data_item):
nonlocal total_score
question = data_evaluator.get_prompt(data_item)
start_time = time.time()
try:
prediction = generate_text(api_url, question, model_name)
if prediction is None:
raise Exception(f"Failed to get prediction for question: {question}")
# Correct answer: map the index to a letter (0->A, 1->B, 2->C, 3->D)
answer = chr(data_item['answer'] + 65)
processed_prediction = data_evaluator.post_processing(prediction)
score = data_evaluator.score(processed_prediction, answer)
elapsed_time = time.time() - start_time
result_data = {
"question_id": index,
"answer": answer,
"prediction": processed_prediction,
"real_prediction": prediction,
"score": score,
"time": elapsed_time
}
# Take the lock while writing results to keep file access thread-safe
with file_lock:
with open(result_file, 'a', encoding='utf-8') as f:
json.dump(result_data, f, ensure_ascii=False, indent=4)
f.write("\n")
return result_data
except Exception as e:
print(f"Error processing request {index}: {e}")
return None
# Process in batches, at most 10 tasks per batch
for batch_start in range(0, len(data_subset), batch_size):
batch = data_subset[batch_start: batch_start + batch_size]
with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)]
for future in concurrent.futures.as_completed(futures):
res = future.result()
if res is not None:
results.append(res)
total_score += res['score']
total_time = time.time() - start_total_time
throughput = len(data_subset) / total_time if total_time > 0 else 0
with open(log_file, 'a', encoding='utf-8') as log_f:
log_f.write(f"Total Time: {total_time:.2f} seconds\n")
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
average_score = total_score / len(data_subset) if data_subset else 0
log_f.write(f"Average Score: {average_score}\n")
log_f.write('-' * 40 + '\n')
print(f"Results saved to {result_file}")
print(f"Log saved to {log_file}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Generate Tester")
parser.add_argument("--concurrent", type=int, default=1000, help="需要测试的实例总数")
parser.add_argument("--file", type=str, default="cais/mmlu", help="数据文件路径")
parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="结果文件保存路径")
parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="日志文件保存路径")
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="模型名称或路径")
parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
args = parser.parse_args()
data_evaluator = DataEvaluator()
data_evaluator.load_data(args.file)
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
async def fetch_event_stream(session, request_id):
try:
payload = {
"messages": [
{"role": "system", "content": ""},
{"role": "user", "content": prompt_list[request_id]}
],
"model": "DeepSeek-V3",
"temperature": 0.3,
"top_p": 1.0,
"stream": True # 开启流式输出
}
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:
print(f"Request {request_id}: Connected, status {response.status}")
if response.status != 200:
print(f"Request {request_id}: Error, status {response.status}")
return
output_text = "" # 存储当前 response 的所有 token
total_tokens = 0 # 统计总 tokens 数
decode_start_time = None # 记录 decode 阶段开始时间
decode_end_time = None # 记录 decode 结束时间
async for line in response.content:
try:
decoded_line = line.decode("utf-8").strip()
# Skip empty lines
if not decoded_line or not decoded_line.startswith("data: "):
continue
decoded_line = decoded_line[6:].strip() # strip the leading "data: "
# Make sure there is a JSON payload left to parse
if not decoded_line:
continue
response_data = json.loads(decoded_line) # parse the JSON payload
# Make sure choices is present
choices = response_data.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
token = delta.get("content", "")
if token:
if decode_start_time is None:
decode_start_time = time.time() # record when decoding starts
output_text += token # append the token
sys.stdout.write(token) # print the token directly
sys.stdout.flush() # flush immediately so the token appears in the terminal right away
total_tokens += 1 # bump the token count
decode_end_time = time.time() # update the decode end time on every received token
# Check whether generation has finished
finish_reason = choices[0].get("finish_reason", None)
if finish_reason:
# print(f"\nRequest {request_id}: Done")
break # stop processing the stream
except json.JSONDecodeError as e:
print(f"\nRequest {request_id}: JSON Decode Error - {e}")
except IndexError:
print(f"\nRequest {request_id}: List Index Error - choices is empty")
except Exception as e:
print(f"\nRequest {request_id}: Error parsing stream - {e}")
# Compute the decode speed
if decode_start_time and decode_end_time and total_tokens > 0:
decode_time = decode_end_time - decode_start_time
decode_speed = total_tokens / decode_time if decode_time > 0 else 0
# print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
except Exception as e:
print(f"\nRequest {request_id}: Exception - {e}")
async def main(prompt_id):
async with aiohttp.ClientSession() as session:
tasks = [fetch_event_stream(session, prompt_id)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--question_id", type=int, default=0, required=False)
args = parser.parse_args()
output_file = "ktransformer_test_results.txt"
asyncio.run(main(args.question_id))