Commit 25cee581 authored by Atream

add balance-serve, support concurrency

parent 8d0292aa
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Repetition penalizer penalizes tokens based on their repetition in the input and output.
"""
repetition_penalties: torch.Tensor = None
cumulated_repetition_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_repetition_penalties = (
torch.tensor(
data=[1.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_repetition_penalties)
)
def _teardown(self):
del self.repetition_penalties
del self.cumulated_repetition_penalties
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
mask = input_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
return torch.where(
logits > 0,
logits / self.cumulated_repetition_penalties,
logits * self.cumulated_repetition_penalties,
)
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)
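A minimal sketch of the penalty arithmetic used in _apply above, on a toy batch with hand-built tensors (the values are illustrative, not produced by the orchestrator):

import torch

logits = torch.tensor([[2.0, 1.0, -1.0, -2.0]])    # 1 request, 4-token vocabulary
penalties = torch.tensor([[1.0, 1.2, 1.0, 1.2]])   # tokens 1 and 3 were seen before
# positive logits are divided by the penalty, negative ones multiplied,
# so repeated tokens are pushed away from being sampled either way
penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
print(penalized)  # tensor([[ 2.0000,  0.8333, -1.0000, -2.4000]])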
'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging
import torch
from torch import nn
from transformers import GenerationConfig
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_probs,
top_k_top_p_sampling_from_logits,
top_p_renorm_probs,
)
logger = logging.getLogger(__name__)
class SamplingOptions():
# Batched sampling params
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
# All requests use greedy sampling
is_all_greedy: bool
# Dispatch in CUDA graph
need_min_p_sampling: bool
def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
if pretrained_config is None and temperatures is None:
self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = True
else:
if temperatures is not None:
self.temperatures = temperatures.unsqueeze(-1)
else:
self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)
if top_ps is not None:
self.top_ps = top_ps.unsqueeze(-1)
else:
self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = False
class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
logits: torch.Tensor,
sampling_config: SamplingOptions = None,
):
if sampling_config is None:
sampling_config = SamplingOptions()
logits = logits.contiguous()
origin_logits = logits.clone()
if sampling_config.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
probs = logits
batch_next_token_ids = torch.argmax(logits, -1)
else:
# Post process logits
logits.div_(sampling_config.temperatures)
max_top_k_round, batch_size = 32, logits.shape[0]
if sampling_config.need_min_p_sampling:
probs = torch.softmax(logits, dim=-1)
logits = None
del logits
probs = top_k_renorm_probs(probs, sampling_config.top_ks)
probs = top_p_renorm_probs(probs, sampling_config.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, sampling_config.min_ps
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
else:
# TODO: use a different kernel when top_k / top_p filtering is not needed
# TODO: return real probabilities instead of raw logits here
probs = logits
batch_next_token_ids = top_k_top_p_sampling_from_logits(
logits,
sampling_config.top_ks,
sampling_config.top_ps,
filter_apply_order="joint",
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
return batch_next_token_ids.to(torch.int32), probs
\ No newline at end of file
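A hedged usage sketch for the default greedy path of the Sampler above, assuming flashinfer is installed (required for the module to import) and a CUDA device is available; the import path and vocabulary size are illustrative:

import torch
# from <module defining Sampler and SamplingOptions> import Sampler, SamplingOptions  # assumed path

sampler = Sampler()
opts = SamplingOptions(bsz=2, device=torch.device("cuda"))  # defaults: temperature 0 -> greedy
logits = torch.randn(2, 32000, device="cuda")               # (batch, vocab_size)
next_ids, probs = sampler(logits, opts)                     # per-row argmax, int32 token ids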
from datetime import datetime
import os
from typing import Optional
import zmq
import pickle
import threading
import torch.multiprocessing as mp
import sys
current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import argparse
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
if mp.get_start_method(allow_none=True) is None:
print('set start method')
mp.set_start_method('spawn')
else:
print(f'start method already set to {mp.get_start_method(allow_none=True)}')
class SchedulerServer:
def __init__(self, settings, main_args):
# create and initialize the Scheduler instance
self.sched = sched_ext.create_scheduler(settings)
# initialize the ZeroMQ context and sockets
self.context = zmq.Context()
self.frontend = self.context.socket(zmq.ROUTER)
print(f"sched zmq rpc server on port {main_args.sched_port}")
self.frontend.bind(f"tcp://*:{main_args.sched_port}")
# create the internal DEALER socket used to communicate with the worker threads
self.backend = self.context.socket(zmq.DEALER)
self.backend.bind("inproc://backend")
# start the scheduler
def run_scheduler(self):
self.sched.run()
# stop the scheduler
def stop_scheduler(self):
self.sched.stop()
# handle client requests
def start_proxy(self):
# use ZMQ's built-in proxy to forward frontend requests to the backend worker threads
zmq.proxy(self.frontend, self.backend)
# worker thread that handles requests
def worker_routine(self):
worker = self.context.socket(zmq.REP)
worker.connect("inproc://backend")
while True:
try:
# receive a client request
message = worker.recv()
data = pickle.loads(message)
method = data.get('method')
params = data.get('params', {})
# print(f"Received request: {method}")
if method == 'add_query':
query_add = params.get('query')  # a QueryAdd object, passed as-is
# add the query
query_id = self.sched.add_query(query_add)
# send the response
response = {'status': 'ok', 'query_id': query_id}
worker.send(pickle.dumps(response))
elif method == 'cancel_query':
query_id = params.get('query_id')
# assumes the Scheduler class implements a cancel method
self.sched.cancel(query_id)
response = {'status': 'ok'}
worker.send(pickle.dumps(response))
elif method == 'update_last_batch':
updates = params.get('updates')  # a list of QueryUpdate objects, passed as-is
# update the last batch
batch_todo = self.sched.update_last_batch(updates)
# send the batch_todo object directly
response = {'status': 'ok', 'batch_todo': batch_todo}
# print (batch_todo.query_lengths, batch_todo.query_ids)
worker.send(pickle.dumps(response))
elif method == 'get_inference_context':
inference_context = self.sched.get_inference_context()
data = {
"k_cache":inference_context.k_cache,
"v_cache":inference_context.v_cache
}
print(f"Serializing KVCache")
data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
# print(data)
response = {'status': 'ok', 'inference_context': data}
worker.send(pickle.dumps(response))
# response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
# print("k_cache update")
else:
# unknown method
response = {'status': 'error', 'message': 'Unknown method'}
worker.send(pickle.dumps(response))
except Exception as e:
# handle the exception and send an error response
response = {'status': 'error', 'message': str(e)}
worker.send(pickle.dumps(response))
# start the RPC service
def start_rpc_service(self):
try:
print("Scheduler RPC service is running...")
# run the scheduler in a separate thread
threading.Thread(target=self.run_scheduler, daemon=True).start()
# start the worker threads
for _ in range(10):  # adjust the number of worker threads as needed
threading.Thread(target=self.worker_routine, daemon=True).start()
# start the proxy and begin serving requests
self.start_proxy()
except KeyboardInterrupt:
print("Shutting down scheduler RPC service...")
self.stop_rpc_service()
# stop the RPC service
def stop_rpc_service(self):
self.stop_scheduler()
self.frontend.close()
self.backend.close()
self.context.term()
def start_server(settings, main_args):
server = SchedulerServer(settings, main_args)
server.start_rpc_service()
# Add async client for webserver
class SchedulerClient:
def __init__(self, sched_port):
address=f'tcp://localhost:{sched_port}'
self.address = address
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(self.address)
print(f"Connected to server at {self.address}")
def __del__(self):
self.socket.close()
self.context.term()
def send_request(self, method, params=None):
if params is None:
params = {}
request = {
'method': method,
'params': params
}
# print(f'send request {request}')
self.socket.send(pickle.dumps(request))
response = self.socket.recv()
# print(response)
response = pickle.loads(response)
if response.get('status') == 'ok':
return response
else:
raise Exception(f"Error from server: {response.get('message')}")
def add_query(self, query):
response = self.send_request('add_query', {'query': query})
return response.get('query_id')
def cancel_query(self, query_id):
self.send_request('cancel_query', {'query_id': query_id})
def update_last_batch(self, updates):
response = self.send_request('update_last_batch', {'updates': updates})
# print(f"update_last_batch response {response}")
return response.get('batch_todo')
def rebuild_inferece_context(self,response):
data = response.get('inference_context')
inference_context = sched_ext.InferenceContext()
print('Rebuilding kvcache')
inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]
inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]
return inference_context
def get_inference_context_raw(self):
response = self.send_request('get_inference_context')
return response
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
with open(args.config, "rb") as f:
main_args = pickle.load(f)
settings = create_sched_settings(main_args)
start_server(settings, main_args)
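A hedged sketch of how another process might use SchedulerClient against a running sched_rpc server; the port number and the query/update objects are illustrative and must be built by the caller:

# client = SchedulerClient(sched_port=10008)                 # port is illustrative
# query_id = client.add_query(query)                         # query: a sched_ext.QueryAdd
# batch = client.update_last_batch(updates)                  # updates: list of sched_ext.QueryUpdate
# ctx = client.rebuild_inferece_context(client.get_inference_context_raw())
#                                                            # reattaches the shared k/v cache tensors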
'''
Date: 2024-11-13 09:43:39
LastEditors: djw
LastEditTime: 2024-11-18 16:41:03
'''
import sys, os
import yaml, json
from time import sleep
current_dir = os.path.dirname(__file__)
# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched'))
# sys.path.insert(0, sched_path)
import sched_ext
from transformers import AutoConfig
def create_sched_settings(args):
default_sample_options = sched_ext.SampleOptions()
model_name = os.path.basename(os.path.normpath(args.model_dir))
input_model_settings = sched_ext.ModelSettings()
input_model_settings.model_path = args.model_dir
input_model_settings.params_count = int(0)
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
input_model_settings.layer_count = model_config.num_hidden_layers
input_model_settings.num_k_heads = 1 # model_config["num_key_value_heads"]
input_model_settings.k_head_dim = 576
input_model_settings.bytes_per_params = 2
input_model_settings.bytes_per_kv_cache_element = 2
settings = sched_ext.Settings()
settings.model_name = model_name
settings.quant_type = "BF16"
settings.model_settings = input_model_settings
settings.page_size = args.page_size
settings.gpu_device_count = 1 # tp
settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
# settings.gpu_memory_size = args.cache_lens*576*2
settings.gpu_memory_size = args.gpu_memory_size
settings.memory_utilization_percentage = args.utilization_percentage
max_batch_size = args.max_batch_size
chunk_size = args.chunk_size
max_decode_batch_size = max_batch_size - 2
settings.max_batch_size = max_batch_size
settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
settings.sample_options = default_sample_options
settings.sched_metrics_port = args.sched_metrics_port
settings.gpu_only = args.memory_gpu_only
settings.use_self_defined_head_dim = True
settings.self_defined_head_dim = 576
settings.full_kv_cache_on_each_gpu = True
settings.k_cache_on = True
settings.v_cache_on = False
settings.kvc2_root_path = '/mnt/data/persist-kvc'
settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs")
print(os.path.join(current_dir, "..", "..", "configs"))
settings.memory_pool_size_GB = args.cpu_memory_size_GB
settings.evict_count = 40
settings.kvc2_metrics_port = args.kvc2_metrics_port
settings.load_from_disk = False
settings.save_to_disk = True
settings.strategy_name = args.sched_strategy
settings.auto_derive()
return settings
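A worked example of the batch and chunk arithmetic in create_sched_settings, using illustrative argument values (not repository defaults):

# with args.max_batch_size = 4 and args.chunk_size = 512:
#   max_decode_batch_size                          = 4 - 2          = 2
#   settings.recommended_chunk_prefill_token_count = (512 - 2) // 2 = 255
# i.e. roughly half of each chunk is recommended for prefill tokens
# once the decode slots are accounted for.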
@@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14
 import os
 import shutil
 import yaml
+import psutil
 from ktransformers.server.config.singleton import Singleton
 from typing import Optional
@@ -60,7 +61,7 @@ class Config(metaclass=Singleton):
         self.user_path: str = os.path.expanduser("~")
         self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
         # log configs
-        self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
+        self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
         self.log_file = cfg["log"]["file"]
         self.log_level = cfg["log"]["level"]
         self.backup_count = cfg["log"]["backup_count"]
@@ -74,7 +75,7 @@ class Config(metaclass=Singleton):
         # db configs
         self.db_configs: dict = cfg.get("db", {})
         self.db_type = self.db_configs.get("type", "")
-        self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
+        self.db_host = Config.to_path(self.db_configs.get("host", ""))
         self.db_port = self.db_configs.get("port", "")
         self.db_name = self.db_configs.get("database", "")
         self.db_pool_size = self.db_configs.get("pool_size")
@@ -101,11 +102,6 @@ class Config(metaclass=Singleton):
         self.optimize_config_path: Optional[str] = self.model.get(
             "optimize_config_path", None
         )
-        self.paged = self.model.get("paged", True)
-        self.total_context = self.model.get("total_context", 2**18)
-        self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
-        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
         self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
@@ -138,7 +134,6 @@ class Config(metaclass=Singleton):
         self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
         self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
         self.presence_penalty = self.model.get("presence_penalty", 0.0)
-        self.max_response_tokens = self.model.get("max_response_tokens", 300)
         self.response_chunk = self.model.get("response_chunk", 250)
         self.no_code_formatting = self.model.get("no_code_formatting", False)
         self.cache_8bit = self.model.get("cache_8bit", False)
@@ -155,8 +150,9 @@ class Config(metaclass=Singleton):
         self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
         self.mount_web: bool = self.web.get("mount", False)
+        # ext
         self.ext: dict = cfg.get("ext", {})
-        self.cpu_infer = self.ext.get("cpu_infer", 10)
+        self.cpu_infer = psutil.cpu_count(logical=False) - 3
         # file config
         self.local_store_configs: dict = cfg.get("local_store", {})
@@ -169,7 +165,6 @@ class Config(metaclass=Singleton):
         # long context config
         self.long_context_config: dict = cfg.get("long_context", {})
-        self.chunk_size = self.long_context_config.get("chunk_size", 4096)
         self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
         self.block_size = self.long_context_config.get("block_size", 128)
         self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
@@ -187,3 +182,21 @@ class Config(metaclass=Singleton):
         # local chat
         self.local_chat_config: dict = cfg.get("local_chat", {})
         self.prompt_file = self.local_chat_config.get("prompt_file", None)
+
+        # asyncserver
+        self.sched_strategy = cfg['async_server']['sched_strategy']
+        self.sched_port = cfg['async_server']['sched_port']
+        self.sched_metrics_port = cfg['async_server']['sched_metrics_port']
+        self.kvc2_metrics_port = cfg['async_server']['kvc2_metrics_port']
+        self.max_batch_size = cfg['async_server']['max_batch_size']
+        self.page_size = cfg['attn']['page_size']
+        self.chunk_size = cfg['attn']['chunk_size']
+        self.memory_gpu_only = cfg['kvc2']['gpu_only']
+        self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
+        self.gpu_memory_size = 2*576*61*self.cache_lens
+        self.utilization_percentage = 1.0 #cfg['kvc2']['utilization_percentage']
+        self.cpu_memory_size_GB = cfg['kvc2']['cpu_memory_size_GB']
+        # only support 2 prefill task
+        self.max_prefill_batch_size = 2
+        self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size
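A worked example of the gpu_memory_size arithmetic added above, under the assumption (not stated in the diff) that 576 is the compressed KV head dimension and 61 the layer count of the target model, with illustrative config values:

# page_size = 256, cache_lens = 8000  (illustrative values)
# cache_lens rounded up to a page multiple: ((8000 + 255) // 256) * 256 = 8192
# gpu_memory_size = 2 * 576 * 61 * 8192 = 575_668_224 bytes  (~549 MiB)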
@@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
 import uvicorn.logging
 import uvicorn
 import sys
+import atexit
 project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-sys.path.insert(0, project_dir)
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.args import ArgumentParser
 from ktransformers.server.config.config import Config
-from ktransformers.server.utils.create_interface import create_interface
-from ktransformers.server.backend.args import default_args
+from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
 from fastapi.openapi.utils import get_openapi
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from ktransformers.server.api import router, post_db_creation_operations
 from ktransformers.server.utils.sql_utils import Base, SQLUtil
 from ktransformers.server.config.log import logger
+import subprocess
+import tempfile

 def mount_app_routes(mount_app: FastAPI):
     sql_util = SQLUtil()
@@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):
 def create_app():
     cfg = Config()
-    app = FastAPI()
+    if(hasattr(GlobalInterface.interface, "lifespan")):
+        app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
+    else:
+        app = FastAPI()
     if Config().web_cross_domain:
         app.add_middleware(
             CORSMiddleware,
@@ -108,11 +107,32 @@ def main():
     arg_parser = ArgumentParser(cfg)
+    # initialization
     args = arg_parser.parse_args()
+    if args.backend_type == "balance_serve":
+        import pickle
+        def cleanup():
+            if sched_process.poll() is None:
+                sched_process.terminate()
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            pickle.dump(args, temp_file)
+            temp_file_path = temp_file.name
+        current_file = __file__
+        target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
+        target_file = os.path.normpath(target_file)
+        log_path = os.path.join(args.log_dir, "rpc.log")
+        log = open(log_path, "a")
+        sched_process = subprocess.Popen(
+            ["python3", target_file, "--config", temp_file_path],
+            stdout=log,
+            stderr=log
+        )
+        print("sched_rpc started with PID:", sched_process.pid)
+        atexit.register(cleanup)
+    create_interface(config=cfg, default_args=cfg)
     app = create_app()
     custom_openapi(app)
-    create_interface(config=cfg, default_args=cfg)
     run_api(
         app=app,
         host=args.host,
@@ -121,6 +141,5 @@ def main():
         ssl_certfile=args.ssl_certfile,
     )

 if __name__ == "__main__":
     main()
-torch >= 2.3.0,<=2.3.1
+torch >= 2.3.0
 transformers == 4.43.2
 fastapi >= 0.111.0
 langchain >= 0.2.0
@@ -11,4 +11,6 @@ build
 ninja
 wheel
 colorlog
-fire
\ No newline at end of file
+fire
+zmq
+psutil
\ No newline at end of file
@@ -2,7 +2,7 @@ from typing import List, Optional
 from typing_extensions import Literal
 from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from ktransformers.server.schemas.base import Object
@@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model : str
     stream : bool = False
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    temperature: Optional[float] = Field(default=1.0)
+    top_p: Optional[float] = Field(default=1.0)
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
...
@@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager
 from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
 from ktransformers.server.backend.interfaces.transformers import TransformersInterface
 from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface

 def create_interface(config: Config, default_args: ConfigArgs):
     if config.backend_type=='transformers':
         from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
@@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs):
         from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
     elif config.backend_type == 'ktransformers':
         from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
+    elif config.backend_type == 'balance_serve':
+        from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
     else:
         raise NotImplementedError(f'{config.backend_type} not implemented')
     GlobalInterface.interface = BackendInterface(default_args)
@@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs):
 class GlobalContextManager:
     context_manager: ThreadContextManager
 class GlobalInterface:
     interface: TransformersInterface | KTransformersInterface | ExllamaInterface
-def get_thread_context_manager() -> ThreadContextManager:
+def get_thread_context_manager() -> GlobalContextManager:
     return GlobalContextManager.context_manager
-def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
+def get_interface() -> GlobalInterface:
     return GlobalInterface.interface
\ No newline at end of file
import argparse
import random
import time
import json
import requests
import pandas as pd
from datasets import load_dataset
import os
import concurrent.futures
import threading
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
class DataEvaluator:
def __init__(self):
self.data = []
def load_data(self, file_path):
"""
Load data from the dataset file; each record corresponds to one test instance.
"""
ds = load_dataset(file_path, "all")
df = pd.DataFrame(ds['test'])
for _, row in df.iterrows():
self.data.append(row.to_dict())
def get_prompt(self, record):
"""
Combine the hint and the record to build the full question prompt.
"""
options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
return prompt
def post_processing(self, text):
"""
Post-process the generated text and extract the final answer (return only the last character).
"""
text = text.lstrip('\n').split('\n')[-1]
return text[-1:]
def score(self, pred, answer):
"""
Compare the prediction with the reference answer and return the score.
"""
if pred == answer:
return 1
return 0
def generate_text(api_url, question, model_name, stream=False):
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': 'Bearer '  # fill in an API key here if needed
}
data = {
"messages": [{"content": question, "role": "user"}],
"model": model_name,
"stream": stream,
}
print("POST data:", data)
response = requests.post(api_url, headers=headers, json=data, timeout=5000000)
if response.status_code == 200:
result = response.json()
return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
else:
print(f"API Request failed with status code {response.status_code}")
return None
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
start_total_time = time.time()
total_score = 0
results = []
file_lock = threading.Lock()
# shuffle the data and select the number of instances to test
random.seed(42)
random.shuffle(data_evaluator.data)
data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]
batch_size = 10  # at most 10 instances per batch
def worker(index, data_item):
nonlocal total_score
question = data_evaluator.get_prompt(data_item)
start_time = time.time()
try:
prediction = generate_text(api_url, question, model_name)
if prediction is None:
raise Exception(f"Failed to get prediction for question: {question}")
# reference answer: convert the index to a letter (0->A, 1->B, 2->C, 3->D)
answer = chr(data_item['answer'] + 65)
processed_prediction = data_evaluator.post_processing(prediction)
score = data_evaluator.score(processed_prediction, answer)
elapsed_time = time.time() - start_time
result_data = {
"question_id": index,
"answer": answer,
"prediction": processed_prediction,
"real_prediction": prediction,
"score": score,
"time": elapsed_time
}
# lock around file writes to keep them thread-safe
with file_lock:
with open(result_file, 'a', encoding='utf-8') as f:
json.dump(result_data, f, ensure_ascii=False, indent=4)
f.write("\n")
return result_data
except Exception as e:
print(f"Error processing request {index}: {e}")
return None
# process in batches, at most 10 tasks per batch
for batch_start in range(0, len(data_subset), batch_size):
batch = data_subset[batch_start: batch_start + batch_size]
with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)]
for future in concurrent.futures.as_completed(futures):
res = future.result()
if res is not None:
results.append(res)
total_score += res['score']
total_time = time.time() - start_total_time
throughput = len(data_subset) / total_time if total_time > 0 else 0
with open(log_file, 'a', encoding='utf-8') as log_f:
log_f.write(f"Total Time: {total_time:.2f} seconds\n")
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
average_score = total_score / len(data_subset) if data_subset else 0
log_f.write(f"Average Score: {average_score}\n")
log_f.write('-' * 40 + '\n')
print(f"Results saved to {result_file}")
print(f"Log saved to {log_file}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Generate Tester")
parser.add_argument("--concurrent", type=int, default=1000, help="total number of instances to test")
parser.add_argument("--file", type=str, default="cais/mmlu", help="dataset path")
parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="path to save the result file")
parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="path to save the log file")
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="model name or path")
parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
args = parser.parse_args()
data_evaluator = DataEvaluator()
data_evaluator.load_data(args.file)
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
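A small illustration of the answer extraction above (the strings are illustrative, not dataset output):

# evaluator = DataEvaluator()
# evaluator.post_processing("The answer is:\nB")   # -> "B": last character of the last line
# evaluator.score("B", "B")                        # -> 1
# evaluator.score("C", "B")                        # -> 0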
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
prompt_list = ['请你介绍下秦始皇', '3.9 和 3.11 哪个大', '抗衰老有何妙招', '给我讲个故事']
async def fetch_event_stream(session, request_id):
try:
payload = {
"messages": [
{"role": "system", "content": ""},
{"role": "user", "content": prompt_list[request_id]}
],
"model": "DeepSeek-V3",
"temperature": 0.3,
"top_p": 1.0,
"stream": True  # enable streaming output
}
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:
print(f"Request {request_id}: Connected, status {response.status}")
if response.status != 200:
print(f"Request {request_id}: Error, status {response.status}")
return
output_text = ""  # accumulates all tokens of the current response
total_tokens = 0  # total token count
decode_start_time = None  # start time of the decode phase
decode_end_time = None  # end time of the decode phase
async for line in response.content:
try:
decoded_line = line.decode("utf-8").strip()
# skip empty lines
if not decoded_line or not decoded_line.startswith("data: "):
continue
decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
# make sure the JSON payload is non-empty
if not decoded_line:
continue
response_data = json.loads(decoded_line)  # parse the JSON chunk
# make sure choices exists
choices = response_data.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
token = delta.get("content", "")
if token:
if decode_start_time is None:
decode_start_time = time.time()  # record decode start time
output_text += token  # append the token
sys.stdout.write(token)  # print the token directly
sys.stdout.flush()  # flush immediately so the token shows up in the terminal right away
total_tokens += 1  # increment the token count
decode_end_time = time.time()  # update decode end time on every received token
# check whether the stream is finished
finish_reason = choices[0].get("finish_reason", None)
if finish_reason:
# print(f"\nRequest {request_id}: Done")
break  # stop streaming
except json.JSONDecodeError as e:
print(f"\nRequest {request_id}: JSON Decode Error - {e}")
except IndexError:
print(f"\nRequest {request_id}: List Index Error - choices is empty")
except Exception as e:
print(f"\nRequest {request_id}: Error parsing stream - {e}")
# compute decode speed
if decode_start_time and decode_end_time and total_tokens > 0:
decode_time = decode_end_time - decode_start_time
decode_speed = total_tokens / decode_time if decode_time > 0 else 0
# print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
except Exception as e:
print(f"\nRequest {request_id}: Exception - {e}")
async def main(prompt_id):
async with aiohttp.ClientSession() as session:
tasks = [fetch_event_stream(session, prompt_id)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--question_id", type=int, default=0, required=False)
args = parser.parse_args()
output_file = "ktransformer_test_results.txt"
asyncio.run(main(args.question_id))
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
decodesz_list = [128]
ktansformer_prompt1024="""在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。想。
请简述这个故事的内涵 故事的内涵这个故事的内涵写10000个字"""
async def fetch_event_stream(session, request_id , prompt):
try:
payload = {
"messages": [
{"role": "system", "content": ""},
{"role": "user", "content": prompt}
],
"model": "DeepSeek-V3",
"temperature": 0.3,
"top_p": 1.0,
"stream": True  # enable streaming output
}
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:
print(f"Request {request_id}: Connected, status {response.status}")
if response.status != 200:
print(f"Request {request_id}: Error, status {response.status}")
return
output_text = ""  # accumulates all tokens of the current response
total_tokens = 0  # total token count
decode_start_time = None  # start time of the decode phase
decode_end_time = None  # end time of the decode phase
async for line in response.content:
try:
decoded_line = line.decode("utf-8").strip()
# skip empty lines
if not decoded_line or not decoded_line.startswith("data: "):
continue
decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
# make sure the JSON payload is non-empty
if not decoded_line:
continue
response_data = json.loads(decoded_line)  # parse the JSON chunk
# make sure choices exists
choices = response_data.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
token = delta.get("content", "")
if token:
if decode_start_time is None:
decode_start_time = time.time()  # record decode start time
output_text += token  # append the token
sys.stdout.write(str(request_id))
sys.stdout.write(token)  # print the token directly
sys.stdout.flush()  # flush immediately so the token shows up in the terminal right away
total_tokens += 1  # increment the token count
decode_end_time = time.time()  # update decode end time on every received token
# check whether the stream is finished
finish_reason = choices[0].get("finish_reason", None)
if finish_reason:
# print(f"\nRequest {request_id}: Done")
break  # stop streaming
except json.JSONDecodeError as e:
print(f"\nRequest {request_id}: JSON Decode Error - {e}")
except IndexError:
print(f"\nRequest {request_id}: List Index Error - choices is empty")
except Exception as e:
print(f"\nRequest {request_id}: Error parsing stream - {e}")
# compute decode speed
if decode_start_time and decode_end_time and total_tokens > 0:
decode_time = decode_end_time - decode_start_time
decode_speed = total_tokens / decode_time if decode_time > 0 else 0
# print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
except Exception as e:
print(f"\nRequest {request_id}: Exception - {e}")
async def main(concurrent_requests , prompt ):
async with aiohttp.ClientSession() as session:
tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
args = parser.parse_args()
SERVER_URL = args.api_url
if args.prompt_lens == 1024:
prompt = ktansformer_prompt1024
elif args.prompt_lens == 2048:
prompt = ktansformer_prompt1024 * 2
asyncio.run(main(args.concurrent, prompt))
@@ -18,9 +18,26 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
 from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+import socket

 warm_uped = False

+def get_free_ports(n: int, continue_prot: list):
+    sockets = []
+    ports = []
+    for _ in range(n):
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(("", 0))
+        port = s.getsockname()[1]
+        if port in continue_prot:
+            s.close()
+            continue
+        ports.append(port)
+        sockets.append(s)
+    for s in sockets:
+        s.close()
+    return ports
+
 def get_compute_capability(device:torch.device = None):
     if torch.cuda.is_available():
         if device is None:
@@ -110,7 +127,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
+                         mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -202,11 +219,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         chunk_start = 0
         while chunk_start < seq_length:
-            chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+            chunk_end = min(chunk_start + chunk_size, seq_length)
             if past_key_values != None:
                 past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
             logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
-            chunk_start += chunk_prefill_size
+            chunk_start += chunk_size

         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
...
@@ -3,7 +3,7 @@
 import os
 # insert the path of the project
 import sys
-sys.path.insert(0, "/home/azure/ktransformers")
+# sys.path.insert(0, "/home/azure/ktransformers")
 import argparse
 import torch
 from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
...
@@ -6,4 +6,4 @@ packaging
 cpufeature
 protobuf
 tiktoken
-blobfile
+blobfile
\ No newline at end of file
@@ -35,6 +35,8 @@ try:
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
 except ImportError:
     MUSA_HOME=None
+
+with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"

 class CpuInstructInfo:
     CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -212,7 +214,7 @@ class VersionInfo:
         cpu_instruct = self.get_cpu_instruct()
         backend_version = ""
         if CUDA_HOME is not None:
-            backend_version = f""
+            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
         elif MUSA_HOME is not None:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
         elif ROCM_HOME is not None:
@@ -274,11 +276,10 @@ PLAT_TO_CMAKE = {
 class CMakeExtension(Extension):
-    def __init__(self, name: str, sourcedir: str = "") -> None:
+    def __init__(self, name: str, sourcedir: str) -> None:
         super().__init__(name, sources=[])
-        self.sourcedir = os.fspath(
-            Path(sourcedir).resolve() / "ktransformers" / "ktransformers_ext")
+        print(name, sourcedir)
+        self.sourcedir = sourcedir

 class CMakeBuild(BuildExtension):
@@ -342,16 +343,17 @@ class CMakeBuild(BuildExtension):
                 f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]
         if self.compiler.compiler_type != "msvc":
             if not cmake_generator or cmake_generator == "Ninja":
-                try:
-                    import ninja
-                    ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
-                    cmake_args += [
-                        "-GNinja",
-                        f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
-                    ]
-                except ImportError:
-                    pass
+                pass
+                # try:
+                #     import ninja
+                #     ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
+                #     cmake_args += [
+                #         "-GNinja",
+                #         f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
+                #     ]
+                # except ImportError:
+                #     pass
             else:
                 # Single config generators are handled "normally"
@@ -387,10 +389,12 @@ class CMakeBuild(BuildExtension):
             build_args += [f"--parallel={cpu_count}"]
         print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
+        print("build_temp:", build_temp)
         if not build_temp.exists():
             build_temp.mkdir(parents=True)
         result = subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True
+            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True, text=True
         )
         print("Standard output:", result.stdout)
         print("Standard error:", result.stderr)
@@ -400,9 +404,9 @@
 if CUDA_HOME is not None or ROCM_HOME is not None:
     ops_module = CUDAExtension('KTransformersOps', [
-        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
-        'ktransformers/ktransformers_ext/cuda/binding.cpp',
-        'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
+        'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu',
+        'csrc/ktransformers_ext/cuda/binding.cpp',
+        'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
         ],
         extra_compile_args={
             'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
@@ -415,7 +419,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None:
         }
     )
 elif MUSA_HOME is not None:
-    SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
+    SimplePorting(cuda_dir_path="csrc/ktransformers_ext/cuda", mapping_rule={
         # Common rules
         "at::cuda": "at::musa",
         "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
@@ -423,10 +427,10 @@ elif MUSA_HOME is not None:
         "nv_bfloat16": "mt_bfloat16",
     }).run()
     ops_module = MUSAExtension('KTransformersOps', [
-        'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
-        'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
+        'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
+        'csrc/ktransformers_ext/cuda_musa/binding.cpp',
         # TODO: Add Marlin support for MUSA.
-        # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
+        # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
         ],
         extra_compile_args={
             'cxx': ['force_mcc'],
@@ -440,12 +444,30 @@ elif MUSA_HOME is not None:
 else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

+ext_modules = [
+    CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+    ops_module,
+    CUDAExtension(
+        'vLLMMarlin', [
+            'csrc/custom_marlin/binding.cpp',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'],
+            'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
+        },
+    )
+]
+if with_balance:
+    print("using balance_serve")
+    ext_modules.append(
+        CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
+    )
+
 setup(
     name=VersionInfo.PACKAGE_NAME,
     version=VersionInfo().get_package_version(),
     cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
-    ext_modules=[
-        CMakeExtension("cpuinfer_ext"),
-        ops_module,
-    ]
+    ext_modules=ext_modules
 )
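The new with_balance switch above means the balance_serve CMake project is only built when explicitly requested; a hedged note on enabling it (the exact pip invocation is an assumption, not part of this diff):

# USE_BALANCE_SERVE=1 adds CMakeExtension("balance_serve", ...) to ext_modules;
# an illustrative build command (assumed, not from this commit):
#   USE_BALANCE_SERVE=1 pip install . -v --no-build-isolation
# leaving it unset (default "0") skips the csrc/balance_serve build entirely.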