Commit 25cee581 authored by Atream

add balance-serve, support concurrency

parent 8d0292aa
import typing
import torch
from ..orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Repetition penalizer penalizes tokens based on their repetition in the input and output.
"""
repetition_penalties: torch.Tensor = None
cumulated_repetition_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_repetition_penalties = (
torch.tensor(
data=[1.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_repetition_penalties)
)
def _teardown(self):
del self.repetition_penalties
del self.cumulated_repetition_penalties
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
mask = input_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
return torch.where(
logits > 0,
logits / self.cumulated_repetition_penalties,
logits * self.cumulated_repetition_penalties,
)
def _filter(
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)
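# --- Hedged sketch (not part of the commit): a standalone illustration of the penalty rule in
# _apply() above. Positive logits are divided by the cumulated penalty and negative logits are
# multiplied by it, so a repeated token always becomes less likely regardless of sign.
import torch
example_logits = torch.tensor([2.0, -1.0, 0.5])
example_penalties = torch.tensor([1.2, 1.2, 1.0])  # 1.0 means the token never occurred
penalized = torch.where(example_logits > 0, example_logits / example_penalties, example_logits * example_penalties)
# penalized -> tensor([ 1.6667, -1.2000,  0.5000])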
'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging
import torch
from torch import nn
from transformers import GenerationConfig
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_probs,
top_k_top_p_sampling_from_logits,
top_p_renorm_probs,
)
logger = logging.getLogger(__name__)
class SamplingOptions():
# Batched sampling params
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
# All requests use greedy sampling
is_all_greedy: bool
# Dispatch in CUDA graph
need_min_p_sampling: bool
def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
if pretrained_config is None and temperatures is None:
self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = True
else:
if temperatures is not None:
self.temperatures = temperatures.unsqueeze(-1)
else:
self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)
if top_ps is not None:
self.top_ps = top_ps.unsqueeze(-1)
else:
self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
self.need_min_p_sampling = False
self.is_all_greedy = False
class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
logits: torch.Tensor,
sampling_config: SamplingOptions = None,
):
if sampling_config is None:
sampling_config = SamplingOptions()
logits = logits.contiguous()
origin_logits = logits.clone()
if sampling_config.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
probs = logits
batch_next_token_ids = torch.argmax(logits, -1)
else:
# Post process logits
logits.div_(sampling_config.temperatures)
max_top_k_round, batch_size = 32, logits.shape[0]
if sampling_config.need_min_p_sampling:
probs = torch.softmax(logits, dim=-1)
logits = None
del logits
probs = top_k_renorm_probs(probs, sampling_config.top_ks)
probs = top_p_renorm_probs(probs, sampling_config.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, sampling_config.min_ps
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
else:
# TODO: use a different kernel when top_k or top_p is not needed
# @TODO get probs
probs = logits
batch_next_token_ids = top_k_top_p_sampling_from_logits(
logits,
sampling_config.top_ks,
sampling_config.top_ps,
filter_apply_order="joint",
)
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
return batch_next_token_ids.to(torch.int32), probs
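# --- Hedged sketch (not part of the commit): with the default SamplingOptions() (temperatures
# filled with 0 and is_all_greedy=True), Sampler.forward reduces to an argmax over the logits.
# This standalone snippet mirrors that greedy path without requiring flashinfer or a GPU.
import torch
dummy_logits = torch.randn(2, 32000)  # (batch_size, vocab_size)
next_token_ids = torch.argmax(dummy_logits, dim=-1).to(torch.int32)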
from datetime import datetime
import os
from typing import Optional
import zmq
import pickle
import threading
import torch.multiprocessing as mp
import sys
current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import argparse
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
if mp.get_start_method(allow_none=True) is None:
print('set start method')
mp.set_start_method('spawn')
else:
print(f'start method already set to {mp.get_start_method(allow_none=True)}')
class SchedulerServer:
def __init__(self, settings, main_args):
# Create and initialize the Scheduler instance
self.sched = sched_ext.create_scheduler(settings)
# Initialize the ZeroMQ context and sockets
self.context = zmq.Context()
self.frontend = self.context.socket(zmq.ROUTER)
print(f"sched zmq rpc server on port {main_args.sched_port}")
self.frontend.bind(f"tcp://*:{main_args.sched_port}")
# Create an internal DEALER socket for communicating with the worker threads
self.backend = self.context.socket(zmq.DEALER)
self.backend.bind("inproc://backend")
# Start the scheduler
def run_scheduler(self):
self.sched.run()
# Stop the scheduler
def stop_scheduler(self):
self.sched.stop()
# Handle client requests
def start_proxy(self):
# Use ZMQ's built-in proxy to dispatch frontend requests to the backend worker threads
zmq.proxy(self.frontend, self.backend)
# Worker thread that processes requests
def worker_routine(self):
worker = self.context.socket(zmq.REP)
worker.connect("inproc://backend")
while True:
try:
# Receive a client request
message = worker.recv()
data = pickle.loads(message)
method = data.get('method')
params = data.get('params', {})
# print(f"Received request: {method}")
if method == 'add_query':
query_add = params.get('query')  # a QueryAdd object, passed through as-is
# Add the query
query_id = self.sched.add_query(query_add)
# Send the response
response = {'status': 'ok', 'query_id': query_id}
worker.send(pickle.dumps(response))
elif method == 'cancel_query':
query_id = params.get('query_id')
# Assumes the Scheduler class implements a cancel method
self.sched.cancel(query_id)
response = {'status': 'ok'}
worker.send(pickle.dumps(response))
elif method == 'update_last_batch':
updates = params.get('updates')  # a list of QueryUpdate objects, passed through as-is
# Update the last batch
batch_todo = self.sched.update_last_batch(updates)
# Send the batch_todo object directly
response = {'status': 'ok', 'batch_todo': batch_todo}
# print (batch_todo.query_lengths, batch_todo.query_ids)
worker.send(pickle.dumps(response))
elif method == 'get_inference_context':
inference_context = self.sched.get_inference_context()
data = {
"k_cache":inference_context.k_cache,
"v_cache":inference_context.v_cache
}
print(f"Serializing KVCache")
data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]
data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]
# print(data)
response = {'status': 'ok', 'inference_context': data}
worker.send(pickle.dumps(response))
# response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1
# print("k_cache update")
else:
# Unknown method
response = {'status': 'error', 'message': 'Unknown method'}
worker.send(pickle.dumps(response))
except Exception as e:
# Handle the exception and send an error response
response = {'status': 'error', 'message': str(e)}
worker.send(pickle.dumps(response))
# Start the RPC service
def start_rpc_service(self):
try:
print("Scheduler RPC service is running...")
# Run the scheduler in a separate thread
threading.Thread(target=self.run_scheduler, daemon=True).start()
# Start the worker threads
for _ in range(10):  # adjust the number of worker threads as needed
threading.Thread(target=self.worker_routine, daemon=True).start()
# Start the proxy and begin listening for requests
self.start_proxy()
except KeyboardInterrupt:
print("Shutting down scheduler RPC service...")
self.stop_rpc_service()
# Stop the RPC service
def stop_rpc_service(self):
self.stop_scheduler()
self.frontend.close()
self.backend.close()
self.context.term()
def start_server(settings, main_args):
server = SchedulerServer(settings, main_args)
server.start_rpc_service()
# Add async client for webserver
class SchedulerClient:
def __init__(self, sched_port):
address=f'tcp://localhost:{sched_port}'
self.address = address
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(self.address)
print(f"Connected to server at {self.address}")
def __del__(self):
self.socket.close()
self.context.term()
def send_request(self, method, params=None):
if params is None:
params = {}
request = {
'method': method,
'params': params
}
# print(f'send request {request}')
self.socket.send(pickle.dumps(request))
response = self.socket.recv()
# print(response)
response = pickle.loads(response)
if response.get('status') == 'ok':
return response
else:
raise Exception(f"Error from server: {response.get('message')}")
def add_query(self, query):
response = self.send_request('add_query', {'query': query})
return response.get('query_id')
def cancel_query(self, query_id):
self.send_request('cancel_query', {'query_id': query_id})
def update_last_batch(self, updates):
response = self.send_request('update_last_batch', {'updates': updates})
# print(f"update_last_batch response {response}")
return response.get('batch_todo')
def rebuild_inferece_context(self,response):
data = response.get('inference_context')
inference_context = sched_ext.InferenceContext()
print('Rebuilding kvcache')
inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]
inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]
return inference_context
def get_inference_context_raw(self):
response = self.send_request('get_inference_context')
return response
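# --- Hedged sketch (not part of the commit): the RPC framing used by SchedulerServer and
# SchedulerClient above is a pickled dict with 'method'/'params' keys, answered by a pickled
# dict carrying a 'status' field. This standalone round-trip shows only the framing; the real
# payloads are sched_ext objects and travel over the ZMQ REQ/ROUTER sockets.
import pickle
request = pickle.dumps({'method': 'cancel_query', 'params': {'query_id': 42}})
assert pickle.loads(request)['method'] == 'cancel_query'
response = pickle.dumps({'status': 'ok'})
assert pickle.loads(response)['status'] == 'ok'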
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
with open(args.config, "rb") as f:
main_args = pickle.load(f)
settings = create_sched_settings(main_args)
start_server(settings, main_args)
'''
Date: 2024-11-13 09:43:39
LastEditors: djw
LastEditTime: 2024-11-18 16:41:03
'''
import sys, os
import yaml, json
from time import sleep
current_dir = os.path.dirname(__file__)
# sched_path = os.path.abspath(os.path.join(current_dir, '../../../build/balance_serve/sched'))
# sys.path.insert(0, sched_path)
import sched_ext
from transformers import AutoConfig
def create_sched_settings(args):
default_sample_options = sched_ext.SampleOptions()
model_name = os.path.basename(os.path.normpath(args.model_dir))
input_model_settings = sched_ext.ModelSettings()
input_model_settings.model_path = args.model_dir
input_model_settings.params_count = int(0)
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
input_model_settings.layer_count = model_config.num_hidden_layers
input_model_settings.num_k_heads = 1 # model_config["num_key_value_heads"]
input_model_settings.k_head_dim = 576
input_model_settings.bytes_per_params = 2
input_model_settings.bytes_per_kv_cache_element = 2
settings = sched_ext.Settings()
settings.model_name = model_name
settings.quant_type = "BF16"
settings.model_settings = input_model_settings
settings.page_size = args.page_size
settings.gpu_device_count = 1 # tp
settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
# settings.gpu_memory_size = args.cache_lens*576*2
settings.gpu_memory_size = args.gpu_memory_size
settings.memory_utilization_percentage = args.utilization_percentage
max_batch_size = args.max_batch_size
chunk_size = args.chunk_size
max_decode_batch_size = max_batch_size - 2
settings.max_batch_size = max_batch_size
settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
settings.sample_options = default_sample_options
settings.sched_metrics_port = args.sched_metrics_port
settings.gpu_only = args.memory_gpu_only
settings.use_self_defined_head_dim = True
settings.self_defined_head_dim = 576
settings.full_kv_cache_on_each_gpu = True
settings.k_cache_on = True
settings.v_cache_on = False
settings.kvc2_root_path = '/mnt/data/persist-kvc'
settings.kvc2_config_path = os.path.join(current_dir, "..", "..", "configs")
print(os.path.join(current_dir, "..", "..", "configs"))
settings.memory_pool_size_GB = args.cpu_memory_size_GB
settings.evict_count = 40
settings.kvc2_metrics_port = args.kvc2_metrics_port
settings.load_from_disk = False
settings.save_to_disk = True
settings.strategy_name = args.sched_strategy
settings.auto_derive()
return settings
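# --- Hedged sketch (not part of the commit): the batch/chunk bookkeeping above with concrete
# numbers; max_batch_size=4 and chunk_size=8192 are illustrative assumptions, not defaults.
example_max_batch_size = 4
example_chunk_size = 8192
example_max_decode_batch_size = example_max_batch_size - 2  # two slots stay reserved for prefill
example_recommended_chunk_prefill_token_count = (example_chunk_size - example_max_decode_batch_size) // 2  # 4095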
......@@ -11,6 +11,7 @@ LastEditTime : 2024-08-12 06:31:14
import os
import shutil
import yaml
import psutil
from ktransformers.server.config.singleton import Singleton
from typing import Optional
......@@ -60,7 +61,7 @@ class Config(metaclass=Singleton):
self.user_path: str = os.path.expanduser("~")
self.localstore_path: str = os.path.join(self.user_path, ".ktransformers")
# log configs
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"])
self.log_file = cfg["log"]["file"]
self.log_level = cfg["log"]["level"]
self.backup_count = cfg["log"]["backup_count"]
......@@ -74,7 +75,7 @@ class Config(metaclass=Singleton):
# db configs
self.db_configs: dict = cfg.get("db", {})
self.db_type = self.db_configs.get("type", "")
self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
self.db_host = Config.to_path(self.db_configs.get("host", ""))
self.db_port = self.db_configs.get("port", "")
self.db_name = self.db_configs.get("database", "")
self.db_pool_size = self.db_configs.get("pool_size")
......@@ -101,11 +102,6 @@ class Config(metaclass=Singleton):
self.optimize_config_path: Optional[str] = self.model.get(
"optimize_config_path", None
)
self.paged = self.model.get("paged", True)
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False)
......@@ -138,7 +134,6 @@ class Config(metaclass=Singleton):
self.repetition_penalty = self.model.get("repetition_penalty", 1.01)
self.frequency_penalty = self.model.get("frequency_penalty", 0.0)
self.presence_penalty = self.model.get("presence_penalty", 0.0)
self.max_response_tokens = self.model.get("max_response_tokens", 300)
self.response_chunk = self.model.get("response_chunk", 250)
self.no_code_formatting = self.model.get("no_code_formatting", False)
self.cache_8bit = self.model.get("cache_8bit", False)
......@@ -155,8 +150,9 @@ class Config(metaclass=Singleton):
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
self.mount_web: bool = self.web.get("mount", False)
# ext
self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10)
self.cpu_infer = psutil.cpu_count(logical=False) - 3
# file config
self.local_store_configs: dict = cfg.get("local_store", {})
......@@ -169,7 +165,6 @@ class Config(metaclass=Singleton):
# long context config
self.long_context_config: dict = cfg.get("long_context", {})
self.chunk_size = self.long_context_config.get("chunk_size", 4096)
self.max_seq_len = self.long_context_config.get("max_seq_len", 32000)
self.block_size = self.long_context_config.get("block_size", 128)
self.local_windows_len = self.long_context_config.get("local_windows_len", 4096)
......@@ -187,3 +182,21 @@ class Config(metaclass=Singleton):
# local chat
self.local_chat_config: dict = cfg.get("local_chat", {})
self.prompt_file = self.local_chat_config.get("prompt_file", None)
# asyncserver
self.sched_strategy = cfg['async_server']['sched_strategy']
self.sched_port = cfg['async_server']['sched_port']
self.sched_metrics_port = cfg['async_server']['sched_metrics_port']
self.kvc2_metrics_port = cfg['async_server']['kvc2_metrics_port']
self.max_batch_size = cfg['async_server']['max_batch_size']
self.page_size = cfg['attn']['page_size']
self.chunk_size = cfg['attn']['chunk_size']
self.memory_gpu_only = cfg['kvc2']['gpu_only']
self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size
self.gpu_memory_size = 2*576*61*self.cache_lens
self.utilization_percentage = 1.0 #cfg['kvc2']['utilization_percentage']
self.cpu_memory_size_GB = cfg['kvc2']['cpu_memory_size_GB']
# only support 2 prefill task
self.max_prefill_batch_size = 2
self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size
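# --- Hedged sketch (not part of the commit): the gpu_memory_size formula above reads as
# bytes_per_kv_cache_element (2) * MLA head dim (576) * layer count (61, matching DeepSeek-V3)
# * cache_lens tokens, i.e. the compressed KV-cache size in bytes. The values below are
# illustrative only.
example_page_size = 256
example_cache_lens = ((8192 + example_page_size - 1) // example_page_size) * example_page_size  # round up to whole pages
example_gpu_memory_size = 2 * 576 * 61 * example_cache_lens
print(f"{example_gpu_memory_size / 2**30:.2f} GiB")  # ~0.54 GiB of KV cache for 8192 tokens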
......@@ -5,24 +5,20 @@ from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import uvicorn
import sys
import atexit
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.args import ArgumentParser
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args
from ktransformers.server.utils.create_interface import create_interface, GlobalInterface
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
import subprocess
import tempfile
def mount_app_routes(mount_app: FastAPI):
sql_util = SQLUtil()
......@@ -34,7 +30,10 @@ def mount_app_routes(mount_app: FastAPI):
def create_app():
cfg = Config()
app = FastAPI()
if(hasattr(GlobalInterface.interface, "lifespan")):
app = FastAPI(lifespan=GlobalInterface.interface.lifespan)
else:
app = FastAPI()
if Config().web_cross_domain:
app.add_middleware(
CORSMiddleware,
......@@ -108,11 +107,32 @@ def main():
arg_parser = ArgumentParser(cfg)
# Initialization messages
args = arg_parser.parse_args()
if args.backend_type == "balance_serve":
import pickle
def cleanup():
if sched_process.poll() is None:
sched_process.terminate()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
pickle.dump(args, temp_file)
temp_file_path = temp_file.name
current_file = __file__
target_file = os.path.join(os.path.dirname(current_file), "balance_serve", "sched_rpc.py")
target_file = os.path.normpath(target_file)
log_path = os.path.join(args.log_dir, "rpc.log")
log = open(log_path, "a")
sched_process = subprocess.Popen(
["python3", target_file, "--config", temp_file_path],
stdout=log,
stderr=log
)
print("sched_rpc started with PID:", sched_process.pid)
atexit.register(cleanup)
create_interface(config=cfg, default_args=cfg)
app = create_app()
custom_openapi(app)
create_interface(config=cfg, default_args=cfg)
run_api(
app=app,
host=args.host,
......@@ -121,6 +141,5 @@ def main():
ssl_certfile=args.ssl_certfile,
)
if __name__ == "__main__":
main()
torch >= 2.3.0,<=2.3.1
torch >= 2.3.0
transformers == 4.43.2
fastapi >= 0.111.0
langchain >= 0.2.0
......@@ -11,4 +11,6 @@ build
ninja
wheel
colorlog
fire
zmq
psutil
......@@ -2,7 +2,7 @@ from typing import List, Optional
from typing_extensions import Literal
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Field
from ktransformers.server.schemas.base import Object
......@@ -30,8 +30,8 @@ class ChatCompletionCreate(BaseModel):
messages: List[Message]
model : str
stream : bool = False
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = Field(default=1.0)
top_p: Optional[float] = Field(default=1.0)
def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]
......
......@@ -15,6 +15,7 @@ from ktransformers.server.backend.context_manager import ThreadContextManager
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
def create_interface(config: Config, default_args: ConfigArgs):
if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
......@@ -22,6 +23,8 @@ def create_interface(config: Config, default_args: ConfigArgs):
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
elif config.backend_type == 'ktransformers':
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
elif config.backend_type == 'balance_serve':
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface
else:
raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)
......@@ -30,9 +33,9 @@ def create_interface(config: Config, default_args: ConfigArgs):
class GlobalContextManager:
context_manager: ThreadContextManager
class GlobalInterface:
interface: TransformersInterface | KTransformersInterface | ExllamaInterface
def get_thread_context_manager() -> ThreadContextManager:
def get_thread_context_manager() -> GlobalContextManager:
return GlobalContextManager.context_manager
def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
def get_interface() -> GlobalInterface:
return GlobalInterface.interface
import argparse
import random
import time
import json
import requests
import pandas as pd
from datasets import load_dataset
import os
import concurrent.futures
import threading
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
class DataEvaluator:
def __init__(self):
self.data = []
def load_data(self, file_path):
"""
Load data from the data file; each record corresponds to one instance.
"""
ds = load_dataset(file_path, "all")
df = pd.DataFrame(ds['test'])
for _, row in df.iterrows():
self.data.append(row.to_dict())
def get_prompt(self, record):
"""
Combine the hint with the record to build the complete question prompt.
"""
options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
return prompt
def post_processing(self, text):
"""
Post-process the generated text and extract the final answer (return only the last character).
"""
text = text.lstrip('\n').split('\n')[-1]
return text[-1:]
def score(self, pred, answer):
"""
Compare the prediction with the correct answer and return the score.
"""
if pred == answer:
return 1
return 0
def generate_text(api_url, question, model_name, stream=False):
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': 'Bearer '  # fill in an API key here if needed
}
data = {
"messages": [{"content": question, "role": "user"}],
"model": model_name,
"stream": stream,
}
print("POST data:", data)
response = requests.post(api_url, headers=headers, json=data, timeout=5000000)
if response.status_code == 200:
result = response.json()
return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
else:
print(f"API Request failed with status code {response.status_code}")
return None
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
start_total_time = time.time()
total_score = 0
results = []
file_lock = threading.Lock()
# Shuffle the data and select the number of instances to test
random.seed(42)
random.shuffle(data_evaluator.data)
data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]
batch_size = 10  # at most 10 instances per batch
def worker(index, data_item):
nonlocal total_score
question = data_evaluator.get_prompt(data_item)
start_time = time.time()
try:
prediction = generate_text(api_url, question, model_name)
if prediction is None:
raise Exception(f"Failed to get prediction for question: {question}")
# Correct answer: convert the index to a letter (0->A, 1->B, 2->C, 3->D)
answer = chr(data_item['answer'] + 65)
processed_prediction = data_evaluator.post_processing(prediction)
score = data_evaluator.score(processed_prediction, answer)
elapsed_time = time.time() - start_time
result_data = {
"question_id": index,
"answer": answer,
"prediction": processed_prediction,
"real_prediction": prediction,
"score": score,
"time": elapsed_time
}
# Acquire the lock when writing results to keep file access thread-safe
with file_lock:
with open(result_file, 'a', encoding='utf-8') as f:
json.dump(result_data, f, ensure_ascii=False, indent=4)
f.write("\n")
return result_data
except Exception as e:
print(f"Error processing request {index}: {e}")
return None
# Process in batches, at most 10 tasks per batch
for batch_start in range(0, len(data_subset), batch_size):
batch = data_subset[batch_start: batch_start + batch_size]
with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)]
for future in concurrent.futures.as_completed(futures):
res = future.result()
if res is not None:
results.append(res)
total_score += res['score']
total_time = time.time() - start_total_time
throughput = len(data_subset) / total_time if total_time > 0 else 0
with open(log_file, 'a', encoding='utf-8') as log_f:
log_f.write(f"Total Time: {total_time:.2f} seconds\n")
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
average_score = total_score / len(data_subset) if data_subset else 0
log_f.write(f"Average Score: {average_score}\n")
log_f.write('-' * 40 + '\n')
print(f"Results saved to {result_file}")
print(f"Log saved to {log_file}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Generate Tester")
parser.add_argument("--concurrent", type=int, default=1000, help="需要测试的实例总数")
parser.add_argument("--file", type=str, default="cais/mmlu", help="数据文件路径")
parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="结果文件保存路径")
parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="日志文件保存路径")
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="模型名称或路径")
parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
args = parser.parse_args()
data_evaluator = DataEvaluator()
data_evaluator.load_data(args.file)
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
prompt_list = ['请你介绍下秦始皇', '3.9 和 3.11 哪个大', '抗衰老有何妙招', '给我讲个故事']
async def fetch_event_stream(session, request_id):
try:
payload = {
"messages": [
{"role": "system", "content": ""},
{"role": "user", "content": prompt_list[request_id]}
],
"model": "DeepSeek-V3",
"temperature": 0.3,
"top_p": 1.0,
"stream": True # 开启流式输出
}
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:
print(f"Request {request_id}: Connected, status {response.status}")
if response.status != 200:
print(f"Request {request_id}: Error, status {response.status}")
return
output_text = "" # 存储当前 response 的所有 token
total_tokens = 0 # 统计总 tokens 数
decode_start_time = None # 记录 decode 阶段开始时间
decode_end_time = None # 记录 decode 结束时间
async for line in response.content:
try:
decoded_line = line.decode("utf-8").strip()
# Skip empty lines
if not decoded_line or not decoded_line.startswith("data: "):
continue
decoded_line = decoded_line[6:].strip()  # strip the leading `data: `
# Make sure the JSON payload is non-empty
if not decoded_line:
continue
response_data = json.loads(decoded_line)  # parse the JSON
# Make sure choices exists
choices = response_data.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
token = delta.get("content", "")
if token:
if decode_start_time is None:
decode_start_time = time.time()  # record the decode start time
output_text += token  # append the token
sys.stdout.write(token)  # write the token straight to the terminal
sys.stdout.flush()  # flush immediately so the token appears right away
total_tokens += 1  # increment the token count
decode_end_time = time.time()  # update the decode end time on every received token
# Check whether generation has finished
finish_reason = choices[0].get("finish_reason", None)
if finish_reason:
# print(f"\nRequest {request_id}: Done")
break  # stop stream processing
except json.JSONDecodeError as e:
print(f"\nRequest {request_id}: JSON Decode Error - {e}")
except IndexError:
print(f"\nRequest {request_id}: List Index Error - choices is empty")
except Exception as e:
print(f"\nRequest {request_id}: Error parsing stream - {e}")
# Compute the decode speed
if decode_start_time and decode_end_time and total_tokens > 0:
decode_time = decode_end_time - decode_start_time
decode_speed = total_tokens / decode_time if decode_time > 0 else 0
# print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
except Exception as e:
print(f"\nRequest {request_id}: Exception - {e}")
async def main(prompt_id):
async with aiohttp.ClientSession() as session:
tasks = [fetch_event_stream(session, prompt_id)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--question_id", type=int, default=0, required=False)
args = parser.parse_args()
output_file = "ktransformer_test_results.txt"
asyncio.run(main(args.question_id))
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
decodesz_list = [128]
ktansformer_prompt1024="""在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。想。
请简述这个故事的内涵 故事的内涵这个故事的内涵写10000个字"""
async def fetch_event_stream(session, request_id , prompt):
try:
payload = {
"messages": [
{"role": "system", "content": ""},
{"role": "user", "content": prompt}
],
"model": "DeepSeek-V3",
"temperature": 0.3,
"top_p": 1.0,
"stream": True # 开启流式输出
}
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
}
async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:
print(f"Request {request_id}: Connected, status {response.status}")
if response.status != 200:
print(f"Request {request_id}: Error, status {response.status}")
return
output_text = "" # 存储当前 response 的所有 token
total_tokens = 0 # 统计总 tokens 数
decode_start_time = None # 记录 decode 阶段开始时间
decode_end_time = None # 记录 decode 结束时间
async for line in response.content:
try:
decoded_line = line.decode("utf-8").strip()
# Skip empty lines
if not decoded_line or not decoded_line.startswith("data: "):
continue
decoded_line = decoded_line[6:].strip()  # strip the leading `data: `
# Make sure the JSON payload is non-empty
if not decoded_line:
continue
response_data = json.loads(decoded_line)  # parse the JSON
# Make sure choices exists
choices = response_data.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
token = delta.get("content", "")
if token:
if decode_start_time is None:
decode_start_time = time.time()  # record the decode start time
output_text += token  # append the token
sys.stdout.write(str(request_id))
sys.stdout.write(token)  # write the token straight to the terminal
sys.stdout.flush()  # flush immediately so the token appears right away
total_tokens += 1  # increment the token count
decode_end_time = time.time()  # update the decode end time on every received token
# Check whether generation has finished
finish_reason = choices[0].get("finish_reason", None)
if finish_reason:
# print(f"\nRequest {request_id}: Done")
break  # stop stream processing
except json.JSONDecodeError as e:
print(f"\nRequest {request_id}: JSON Decode Error - {e}")
except IndexError:
print(f"\nRequest {request_id}: List Index Error - choices is empty")
except Exception as e:
print(f"\nRequest {request_id}: Error parsing stream - {e}")
# Compute the decode speed
if decode_start_time and decode_end_time and total_tokens > 0:
decode_time = decode_end_time - decode_start_time
decode_speed = total_tokens / decode_time if decode_time > 0 else 0
# print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
except Exception as e:
print(f"\nRequest {request_id}: Exception - {e}")
async def main(concurrent_requests , prompt ):
async with aiohttp.ClientSession() as session:
tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
args = parser.parse_args()
SERVER_URL = args.api_url
if args.prompt_lens == 1024:
prompt = ktansformer_prompt1024
elif args.prompt_lens == 2048:
prompt = ktansformer_prompt1024 * 2
asyncio.run(main(args.concurrent, prompt))
......@@ -18,9 +18,26 @@ from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket
warm_uped = False
def get_free_ports(n: int, continue_prot: list):
sockets = []
ports = []
while len(ports) < n:  # keep trying until n ports are collected, even if some candidates are skipped
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
if port in continue_prot:
s.close()
continue
ports.append(port)
sockets.append(s)
for s in sockets:
s.close()
return ports
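# --- Hedged usage sketch (not part of the commit): reserve two free ports for internal
# services while avoiding a port that is already spoken for (the value 10002 is illustrative).
# sched_port, metrics_port = get_free_ports(2, continue_prot=[10002])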
def get_compute_capability(device:torch.device = None):
if torch.cuda.is_available():
if device is None:
......@@ -110,7 +127,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
module.load()
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
......@@ -202,11 +219,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_prefill_size
chunk_start += chunk_size
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
......
......@@ -3,7 +3,7 @@
import os
# insert the path of the project
import sys
sys.path.insert(0, "/home/azure/ktransformers")
# sys.path.insert(0, "/home/azure/ktransformers")
import argparse
import torch
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
......
......@@ -6,4 +6,4 @@ packaging
cpufeature
protobuf
tiktoken
blobfile
......@@ -35,6 +35,8 @@ try:
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
MUSA_HOME=None
with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"
class CpuInstructInfo:
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
......@@ -212,7 +214,7 @@ class VersionInfo:
cpu_instruct = self.get_cpu_instruct()
backend_version = ""
if CUDA_HOME is not None:
backend_version = f""
backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
elif MUSA_HOME is not None:
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
elif ROCM_HOME is not None:
......@@ -274,11 +276,10 @@ PLAT_TO_CMAKE = {
class CMakeExtension(Extension):
def __init__(self, name: str, sourcedir: str = "") -> None:
def __init__(self, name: str, sourcedir: str) -> None:
super().__init__(name, sources=[])
self.sourcedir = os.fspath(
Path(sourcedir).resolve() / "ktransformers" / "ktransformers_ext")
print(name, sourcedir)
self.sourcedir = sourcedir
class CMakeBuild(BuildExtension):
......@@ -342,16 +343,17 @@ class CMakeBuild(BuildExtension):
f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]
if self.compiler.compiler_type != "msvc":
if not cmake_generator or cmake_generator == "Ninja":
try:
import ninja
ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
cmake_args += [
"-GNinja",
f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
]
except ImportError:
pass
pass
# try:
# import ninja
# ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
# cmake_args += [
# "-GNinja",
# f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
# ]
# except ImportError:
# pass
else:
# Single config generators are handled "normally"
......@@ -387,10 +389,12 @@ class CMakeBuild(BuildExtension):
build_args += [f"--parallel={cpu_count}"]
print("CMake args:", cmake_args)
build_temp = Path(ext.sourcedir) / "build"
print("build_temp:", build_temp)
if not build_temp.exists():
build_temp.mkdir(parents=True)
result = subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True, text=True
)
print("Standard output:", result.stdout)
print("Standard error:", result.stderr)
......@@ -400,9 +404,9 @@ class CMakeBuild(BuildExtension):
if CUDA_HOME is not None or ROCM_HOME is not None:
ops_module = CUDAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'ktransformers/ktransformers_ext/cuda/binding.cpp',
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'csrc/ktransformers_ext/cuda/binding.cpp',
'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
],
extra_compile_args={
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
......@@ -415,7 +419,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None:
}
)
elif MUSA_HOME is not None:
SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
SimplePorting(cuda_dir_path="csrc/ktransformers_ext/cuda", mapping_rule={
# Common rules
"at::cuda": "at::musa",
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
......@@ -423,10 +427,10 @@ elif MUSA_HOME is not None:
"nv_bfloat16": "mt_bfloat16",
}).run()
ops_module = MUSAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
'csrc/ktransformers_ext/cuda_musa/binding.cpp',
# TODO: Add Marlin support for MUSA.
# 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
# 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
],
extra_compile_args={
'cxx': ['force_mcc'],
......@@ -440,12 +444,30 @@ elif MUSA_HOME is not None:
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
ops_module,
CUDAExtension(
'vLLMMarlin', [
'csrc/custom_marlin/binding.cpp',
'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
},
)
]
if with_balance:
print("using balance_serve")
ext_modules.append(
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
)
setup(
name=VersionInfo.PACKAGE_NAME,
version=VersionInfo().get_package_version(),
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=[
CMakeExtension("cpuinfer_ext"),
ops_module,
]
ext_modules=ext_modules
)
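# --- Hedged usage note (not part of the commit): given the USE_BALANCE_SERVE check above,
# building the optional balance_serve extension presumably looks like
#   USE_BALANCE_SERVE=1 pip install -e .
# or USE_BALANCE_SERVE=1 python setup.py bdist_wheel; without the flag only the default
# extensions (cpuinfer_ext, KTransformersOps, vLLMMarlin) are built.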