Unverified Commit a41d2163 authored by wang jiahao, committed by GitHub

Merge pull request #1013 from kvcache-ai/work-concurrent

In v0.2.4 we've added the community's highly requested multi-concurrency support through a major refactor of the whole architecture.
parents f142f4df 4ed9744e
import argparse
import random
import time
import json
import requests
import pandas as pd
from datasets import load_dataset
import os
import concurrent.futures
import threading
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
class DataEvaluator:
    def __init__(self):
        self.data = []

    def load_data(self, file_path):
        """
        Load the dataset; each record becomes one evaluation instance.
        """
        ds = load_dataset(file_path, "all")
        df = pd.DataFrame(ds['test'])
        for _, row in df.iterrows():
            self.data.append(row.to_dict())

    def get_prompt(self, record):
        """
        Combine the hint with a record to build the full question prompt.
        """
        options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
        prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
        return prompt

    def post_processing(self, text):
        """
        Post-process the generated text and extract the final answer
        (only the last character is returned).
        """
        text = text.lstrip('\n').split('\n')[-1]
        return text[-1:]

    def score(self, pred, answer):
        """
        Compare the prediction with the reference answer and return the score.
        """
        if pred == answer:
            return 1
        return 0
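# For reference, a minimal sketch of how DataEvaluator is used on a single record
# (the field names follow the cais/mmlu schema read by load_data; the values here
# are made up for illustration):
#
#   evaluator = DataEvaluator()
#   record = {"question": "What is 2 + 2?", "choices": ["3", "4", "5", "6"], "answer": 1}
#   prompt = evaluator.get_prompt(record)               # hint + question + lettered options
#   pred = evaluator.post_processing("...\nB")          # keeps only the last character -> "B"
#   assert evaluator.score(pred, chr(record["answer"] + 65)) == 1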
def generate_text(api_url, question, model_name, stream=False):
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
        'Authorization': 'Bearer '  # fill in an API key here if required
    }
    data = {
        "messages": [{"content": question, "role": "user"}],
        "model": model_name,
        "stream": stream,
    }
    print("POST data:", data)
    response = requests.post(api_url, headers=headers, json=data, timeout=5000000)
    if response.status_code == 200:
        result = response.json()
        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
    else:
        print(f"API Request failed with status code {response.status_code}")
        return None
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
    start_total_time = time.time()
    total_score = 0
    results = []
    file_lock = threading.Lock()

    # Shuffle the data and pick the number of instances to test
    random.seed(42)
    random.shuffle(data_evaluator.data)
    data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]
    batch_size = 10  # at most 10 instances per batch

    def worker(index, data_item):
        question = data_evaluator.get_prompt(data_item)
        start_time = time.time()
        try:
            prediction = generate_text(api_url, question, model_name)
            if prediction is None:
                raise Exception(f"Failed to get prediction for question: {question}")
            # Reference answer: convert the index to a letter (0->A, 1->B, 2->C, 3->D)
            answer = chr(data_item['answer'] + 65)
            processed_prediction = data_evaluator.post_processing(prediction)
            score = data_evaluator.score(processed_prediction, answer)
            elapsed_time = time.time() - start_time
            result_data = {
                "question_id": index,
                "answer": answer,
                "prediction": processed_prediction,
                "real_prediction": prediction,
                "score": score,
                "time": elapsed_time
            }
            # Lock around file writes to keep them thread-safe
            with file_lock:
                with open(result_file, 'a', encoding='utf-8') as f:
                    json.dump(result_data, f, ensure_ascii=False, indent=4)
                    f.write("\n")
            return result_data
        except Exception as e:
            print(f"Error processing request {index}: {e}")
            return None

    # Process in batches of at most 10 tasks each
    for batch_start in range(0, len(data_subset), batch_size):
        batch = data_subset[batch_start: batch_start + batch_size]
        with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
            futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)]
            for future in concurrent.futures.as_completed(futures):
                res = future.result()
                if res is not None:
                    results.append(res)
                    total_score += res['score']

    total_time = time.time() - start_total_time
    throughput = len(data_subset) / total_time if total_time > 0 else 0
    with open(log_file, 'a', encoding='utf-8') as log_f:
        log_f.write(f"Total Time: {total_time:.2f} seconds\n")
        log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
        average_score = total_score / len(data_subset) if data_subset else 0
        log_f.write(f"Average Score: {average_score}\n")
        log_f.write('-' * 40 + '\n')
    print(f"Results saved to {result_file}")
    print(f"Log saved to {log_file}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="API Generate Tester")
    parser.add_argument("--concurrent", type=int, default=1000, help="total number of instances to test")
    parser.add_argument("--file", type=str, default="cais/mmlu", help="dataset path or name")
    parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="path of the result file")
    parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="path of the log file")
    parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="model name or path")
    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
    args = parser.parse_args()

    data_evaluator = DataEvaluator()
    data_evaluator.load_data(args.file)
    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
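# Example invocation (the script file name below is hypothetical; adjust --api_url
# and --model to match your deployment):
#   python mmlu_eval_api.py --concurrent 100 --api_url http://localhost:10006/v1/chat/completions --model DeepSeek-V3
# Despite its name, --concurrent sets the total number of MMLU instances sampled;
# the requests themselves are issued in batches of at most 10 concurrent workers.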
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
prompt_list = ['请你介绍下秦始皇', '3.9 和 3.11 哪个大', '抗衰老有何妙招', '给我讲个故事']
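# English gloss of the test prompts above: "Please introduce Qin Shi Huang",
# "Which is larger, 3.9 or 3.11?", "Any good tips for anti-aging?", "Tell me a story".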
async def fetch_event_stream(session, request_id):
    try:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt_list[request_id]}
            ],
            "model": "DeepSeek-V3",
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True  # enable streaming output
        }
        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
        }

        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:
            print(f"Request {request_id}: Connected, status {response.status}")
            if response.status != 200:
                print(f"Request {request_id}: Error, status {response.status}")
                return

            output_text = ""          # all tokens of the current response
            total_tokens = 0          # total number of tokens received
            decode_start_time = None  # time when the decode phase started
            decode_end_time = None    # time when the decode phase ended

            async for line in response.content:
                try:
                    decoded_line = line.decode("utf-8").strip()
                    # Skip empty lines and anything that is not an SSE data line
                    if not decoded_line or not decoded_line.startswith("data: "):
                        continue
                    decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
                    # Make sure there is JSON left to parse
                    if not decoded_line:
                        continue

                    response_data = json.loads(decoded_line)  # parse the JSON payload
                    # Make sure choices is present
                    choices = response_data.get("choices", [])
                    if not choices:
                        continue

                    delta = choices[0].get("delta", {})
                    token = delta.get("content", "")
                    if token:
                        if decode_start_time is None:
                            decode_start_time = time.time()  # record when decoding started
                        output_text += token      # append the token
                        sys.stdout.write(token)   # print the token directly
                        sys.stdout.flush()        # flush so the token shows up immediately
                        total_tokens += 1         # count the token
                        decode_end_time = time.time()  # update the decode end time on every token

                    # Check whether the stream is finished
                    finish_reason = choices[0].get("finish_reason", None)
                    if finish_reason:
                        # print(f"\nRequest {request_id}: Done")
                        break  # stop streaming
                except json.JSONDecodeError as e:
                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                except IndexError:
                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
                except Exception as e:
                    print(f"\nRequest {request_id}: Error parsing stream - {e}")

            # Compute the decode speed
            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
    except Exception as e:
        print(f"\nRequest {request_id}: Exception - {e}")
async def main(prompt_id):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_event_stream(session, prompt_id)]
        await asyncio.gather(*tasks)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--question_id", type=int, default=0, required=False)
    args = parser.parse_args()
    output_file = "ktransformer_test_results.txt"
    asyncio.run(main(args.question_id))
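# Example invocation (file name hypothetical): pick one of the four prompts in
# prompt_list by index and stream the reply from the server at SERVER_URL:
#   python stream_single_request.py --question_id 2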
import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL is taken from --api_url in the __main__ block below
decodesz_list = [128]
ktansformer_prompt1024="""在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
在遥远的翡翠森林里,住着各种各样的神奇生物。其中,有一只名叫露露的小狐狸,她与其他狐狸不同,天生长着一双晶莹剔透的翅膀。然而,这双翅膀却从未带她飞翔过。
一天,森林里传来一个惊人的消息:藏在森林深处的魔法泉水干涸了,所有生物赖以生存的泉水即将枯竭。他们说,只有传说中的“天空之羽”才能唤醒泉水,让它重新流淌。然而,“天空之羽”藏在一座高耸入云的山峰上,没有任何动物能抵达那里。
露露听到这个消息后,决定亲自去寻找“天空之羽”,即便她的翅膀无法飞翔,她也要尝试。最终,露露来到了传说中的高峰脚下,根本无法攀爬。她望着天空,心里充满了不甘:“如果我能飞起来,就不会被这座山挡住了……”
正当她感到迷茫时,一只年迈的白鹰出现在她面前。
“孩子,你为什么到这里来?”白鹰用苍老但慈祥的声音问道。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。
露露将森林的困境告诉了白鹰,并说自己愿意付出一切,只要能拯救森林。
白鹰沉思了一会儿,缓缓说道:“你的翅膀并不是没有力量,而是你一直害怕它们不能飞翔。相信自己,勇敢跳下去。”
露露听后,心跳加速,她望着万丈深渊,犹豫不决就在那一瞬间,她竟然真的飞了起来!露露兴奋极了,她终于看到了“天空之羽”——一根散发着金光的羽毛,轻盈地悬浮在空中。露露小心翼翼地将“天空之羽”叼住,振翅返回森林。
当她将羽毛放入干涸的泉水中时,一道金光闪耀。整个森林恢复了生机,花草重新绽放,动物们欢欣鼓舞。从那以后,露露成为了森林的英雄,她是翱翔天空的勇士。她让所有动物都明白:只要相信自己,勇敢前行,就能实现自己的梦想。
请简述这个故事的内涵 写10000个字。想。
请简述这个故事的内涵 故事的内涵这个故事的内涵写10000个字"""
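# The block above defines ktansformer_prompt1024: a Chinese story (a winged fox named
# Lulu saving her forest), repeated together with a "写10000个字" ("write 10,000
# characters") instruction so that it acts as a roughly 1024-token prefill prompt;
# --prompt_lens 2048 below simply doubles it.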
async def fetch_event_stream(session, request_id, prompt):
    try:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt}
            ],
            "model": "DeepSeek-V3",
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True  # enable streaming output
        }
        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
        }

        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:
            print(f"Request {request_id}: Connected, status {response.status}")
            if response.status != 200:
                print(f"Request {request_id}: Error, status {response.status}")
                return

            output_text = ""          # all tokens of the current response
            total_tokens = 0          # total number of tokens received
            decode_start_time = None  # time when the decode phase started
            decode_end_time = None    # time when the decode phase ended

            async for line in response.content:
                try:
                    decoded_line = line.decode("utf-8").strip()
                    # Skip empty lines and anything that is not an SSE data line
                    if not decoded_line or not decoded_line.startswith("data: "):
                        continue
                    decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
                    # Make sure there is JSON left to parse
                    if not decoded_line:
                        continue

                    response_data = json.loads(decoded_line)  # parse the JSON payload
                    # Make sure choices is present
                    choices = response_data.get("choices", [])
                    if not choices:
                        continue

                    delta = choices[0].get("delta", {})
                    token = delta.get("content", "")
                    if token:
                        if decode_start_time is None:
                            decode_start_time = time.time()  # record when decoding started
                        output_text += token               # append the token
                        sys.stdout.write(str(request_id))  # tag the output with the request id
                        sys.stdout.write(token)            # print the token directly
                        sys.stdout.flush()                 # flush so the token shows up immediately
                        total_tokens += 1                  # count the token
                        decode_end_time = time.time()      # update the decode end time on every token

                    # Check whether the stream is finished
                    finish_reason = choices[0].get("finish_reason", None)
                    if finish_reason:
                        # print(f"\nRequest {request_id}: Done")
                        break  # stop streaming
                except json.JSONDecodeError as e:
                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                except IndexError:
                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
                except Exception as e:
                    print(f"\nRequest {request_id}: Error parsing stream - {e}")

            # Compute the decode speed
            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
    except Exception as e:
        print(f"\nRequest {request_id}: Exception - {e}")
async def main(concurrent_requests, prompt):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_event_stream(session, i, prompt) for i in range(concurrent_requests)]
        await asyncio.gather(*tasks)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
    args = parser.parse_args()
    SERVER_URL = args.api_url
    if args.prompt_lens == 1024:
        prompt = ktansformer_prompt1024
    elif args.prompt_lens == 2048:
        prompt = ktansformer_prompt1024 * 2
    else:
        raise ValueError("--prompt_lens must be 1024 or 2048")
    asyncio.run(main(args.concurrent, prompt))
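# Example invocation (file name hypothetical): send 4 concurrent streaming requests,
# each with the ~1024-token prompt doubled to ~2048 tokens for the prefill stage:
#   python stream_concurrent_requests.py --concurrent 4 --prompt_lens 2048 --api_url http://localhost:10002/v1/chat/completions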
@@ -18,9 +18,26 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
 from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+import socket

 warm_uped = False

+def get_free_ports(n: int, continue_prot: list):
+    sockets = []
+    ports = []
+    for _ in range(n):
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(("", 0))
+        port = s.getsockname()[1]
+        if port in continue_prot:
+            s.close()
+            continue
+        ports.append(port)
+        sockets.append(s)
+    for s in sockets:
+        s.close()
+    return ports
+
 def get_compute_capability(device:torch.device = None):
     if torch.cuda.is_available():
         if device is None:
@@ -110,7 +127,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     module.load()

 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
+                         mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -202,11 +219,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     chunk_start = 0
     while chunk_start < seq_length:
-        chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+        chunk_end = min(chunk_start + chunk_size, seq_length)
         if past_key_values != None:
             past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
         logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
-        chunk_start += chunk_prefill_size
+        chunk_start += chunk_size
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
...
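A minimal sketch of how the new get_free_ports helper above might be called (the values are illustrative; note that the probe sockets are closed before the ports are returned, so a returned port is not reserved and may be taken by another process, and fewer than n ports can come back when a probed port appears in the exclusion list):

    ports = get_free_ports(2, [10002])  # ask for two free TCP ports, skipping 10002
    print(ports)                        # e.g. [41731, 52214]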
@@ -3,7 +3,7 @@
 import os
 # insert the path of the project
 import sys
-sys.path.insert(0, "/home/azure/ktransformers")
+# sys.path.insert(0, "/home/azure/ktransformers")
 import argparse
 import torch
 from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
...
@@ -6,4 +6,4 @@ packaging
 cpufeature
 protobuf
 tiktoken
-blobfile
\ No newline at end of file
+blobfile
@@ -35,6 +35,8 @@ try:
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
 except ImportError:
     MUSA_HOME=None
+
+with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"

 class CpuInstructInfo:
     CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -212,7 +214,7 @@ class VersionInfo:
         cpu_instruct = self.get_cpu_instruct()
         backend_version = ""
         if CUDA_HOME is not None:
-            backend_version = f""
+            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
         elif MUSA_HOME is not None:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
         elif ROCM_HOME is not None:
@@ -274,11 +276,10 @@ PLAT_TO_CMAKE = {
 class CMakeExtension(Extension):
-    def __init__(self, name: str, sourcedir: str = "") -> None:
+    def __init__(self, name: str, sourcedir: str) -> None:
         super().__init__(name, sources=[])
-        self.sourcedir = os.fspath(
-            Path(sourcedir).resolve() / "ktransformers" / "ktransformers_ext")
+        print(name, sourcedir)
+        self.sourcedir = sourcedir

 class CMakeBuild(BuildExtension):
@@ -342,16 +343,17 @@ class CMakeBuild(BuildExtension):
                       f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]
         if self.compiler.compiler_type != "msvc":
             if not cmake_generator or cmake_generator == "Ninja":
-                try:
-                    import ninja
-
-                    ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
-                    cmake_args += [
-                        "-GNinja",
-                        f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
-                    ]
-                except ImportError:
-                    pass
+                pass
+                # try:
+                #     import ninja
+
+                #     ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
+                #     cmake_args += [
+                #         "-GNinja",
+                #         f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
+                #     ]
+                # except ImportError:
+                #     pass
             else:
                 # Single config generators are handled "normally"
@@ -387,10 +389,12 @@ class CMakeBuild(BuildExtension):
             build_args += [f"--parallel={cpu_count}"]
         print("CMake args:", cmake_args)
         build_temp = Path(ext.sourcedir) / "build"
+        print("build_temp:", build_temp)
         if not build_temp.exists():
             build_temp.mkdir(parents=True)
         result = subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True
+            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True, text=True
         )
         print("Standard output:", result.stdout)
         print("Standard error:", result.stderr)
@@ -400,9 +404,9 @@
 if CUDA_HOME is not None or ROCM_HOME is not None:
     ops_module = CUDAExtension('KTransformersOps', [
-        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
-        'ktransformers/ktransformers_ext/cuda/binding.cpp',
-        'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
+        'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu',
+        'csrc/ktransformers_ext/cuda/binding.cpp',
+        'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
         ],
         extra_compile_args={
             'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
@@ -415,7 +419,7 @@ if CUDA_HOME is not None or ROCM_HOME is not None:
         }
     )
 elif MUSA_HOME is not None:
-    SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
+    SimplePorting(cuda_dir_path="csrc/ktransformers_ext/cuda", mapping_rule={
         # Common rules
         "at::cuda": "at::musa",
         "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
@@ -423,10 +427,10 @@ elif MUSA_HOME is not None:
         "nv_bfloat16": "mt_bfloat16",
     }).run()
     ops_module = MUSAExtension('KTransformersOps', [
-        'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
-        'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
+        'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
+        'csrc/ktransformers_ext/cuda_musa/binding.cpp',
         # TODO: Add Marlin support for MUSA.
-        # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
+        # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
         ],
         extra_compile_args={
             'cxx': ['force_mcc'],
@@ -440,12 +444,30 @@ elif MUSA_HOME is not None:
 else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

+ext_modules = [
+    CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+    ops_module,
+    CUDAExtension(
+        'vLLMMarlin', [
+            'csrc/custom_marlin/binding.cpp',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'],
+            'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
+        },
+    )
+]
+
+if with_balance:
+    print("using balance_serve")
+    ext_modules.append(
+        CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
+    )
+
 setup(
     name=VersionInfo.PACKAGE_NAME,
     version=VersionInfo().get_package_version(),
     cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
-    ext_modules=[
-        CMakeExtension("cpuinfer_ext"),
-        ops_module,
-    ]
+    ext_modules=ext_modules
 )
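Given the new with_balance flag above, the balance_serve CMake extension is only built when the USE_BALANCE_SERVE environment variable is set to "1"; presumably a source build that wants this backend is invoked along the lines of USE_BALANCE_SERVE=1 pip install . (the exact install workflow is not shown in this diff).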
Subproject commit fd94393fb5b8ba8bae9c0bd6ab1c2a429d81ac76
This source diff could not be displayed because it is too large.
// __ _____ _____ _____
// __| | __| | | | JSON for Modern C++
// | | |__ | | | | | | version 3.11.3
// |_____|_____|_____|_|___| https://github.com/nlohmann/json
//
// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>
// SPDX-License-Identifier: MIT
#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_
#define INCLUDE_NLOHMANN_JSON_FWD_HPP_
#include <cstdint> // int64_t, uint64_t
#include <map> // map
#include <memory> // allocator
#include <string> // string
#include <vector> // vector
// #include <nlohmann/detail/abi_macros.hpp>
// __ _____ _____ _____
// __| | __| | | | JSON for Modern C++
// | | |__ | | | | | | version 3.11.3
// |_____|_____|_____|_|___| https://github.com/nlohmann/json
//
// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>
// SPDX-License-Identifier: MIT
// This file contains all macro definitions affecting or depending on the ABI
#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK
#if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH)
#if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3
#warning "Already included a different version of the library!"
#endif
#endif
#endif
#define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum)
#define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum)
#define NLOHMANN_JSON_VERSION_PATCH 3 // NOLINT(modernize-macro-to-enum)
#ifndef JSON_DIAGNOSTICS
#define JSON_DIAGNOSTICS 0
#endif
#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
#define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0
#endif
#if JSON_DIAGNOSTICS
#define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag
#else
#define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS
#endif
#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
#define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp
#else
#define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON
#endif
#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION
#define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0
#endif
// Construct the namespace ABI tags component
#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b
#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \
NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b)
#define NLOHMANN_JSON_ABI_TAGS \
NLOHMANN_JSON_ABI_TAGS_CONCAT( \
NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \
NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON)
// Construct the namespace version component
#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \
_v ## major ## _ ## minor ## _ ## patch
#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \
NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch)
#if NLOHMANN_JSON_NAMESPACE_NO_VERSION
#define NLOHMANN_JSON_NAMESPACE_VERSION
#else
#define NLOHMANN_JSON_NAMESPACE_VERSION \
NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \
NLOHMANN_JSON_VERSION_MINOR, \
NLOHMANN_JSON_VERSION_PATCH)
#endif
// Combine namespace components
#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b
#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \
NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b)
#ifndef NLOHMANN_JSON_NAMESPACE
#define NLOHMANN_JSON_NAMESPACE \
nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \
NLOHMANN_JSON_ABI_TAGS, \
NLOHMANN_JSON_NAMESPACE_VERSION)
#endif
#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN
#define NLOHMANN_JSON_NAMESPACE_BEGIN \
namespace nlohmann \
{ \
inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \
NLOHMANN_JSON_ABI_TAGS, \
NLOHMANN_JSON_NAMESPACE_VERSION) \
{
#endif
#ifndef NLOHMANN_JSON_NAMESPACE_END
#define NLOHMANN_JSON_NAMESPACE_END \
} /* namespace (inline namespace) NOLINT(readability/namespace) */ \
} // namespace nlohmann
#endif
/*!
@brief namespace for Niels Lohmann
@see https://github.com/nlohmann
@since version 1.0.0
*/
NLOHMANN_JSON_NAMESPACE_BEGIN
/*!
@brief default JSONSerializer template argument
This serializer ignores the template arguments and uses ADL
([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl))
for serialization.
*/
template<typename T = void, typename SFINAE = void>
struct adl_serializer;
/// a class to store JSON values
/// @sa https://json.nlohmann.me/api/basic_json/
template<template<typename U, typename V, typename... Args> class ObjectType =
std::map,
template<typename U, typename... Args> class ArrayType = std::vector,
class StringType = std::string, class BooleanType = bool,
class NumberIntegerType = std::int64_t,
class NumberUnsignedType = std::uint64_t,
class NumberFloatType = double,
template<typename U> class AllocatorType = std::allocator,
template<typename T, typename SFINAE = void> class JSONSerializer =
adl_serializer,
class BinaryType = std::vector<std::uint8_t>, // cppcheck-suppress syntaxError
class CustomBaseClass = void>
class basic_json;
/// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document
/// @sa https://json.nlohmann.me/api/json_pointer/
template<typename RefStringType>
class json_pointer;
/*!
@brief default specialization
@sa https://json.nlohmann.me/api/json/
*/
using json = basic_json<>;
/// @brief a minimal map-like container that preserves insertion order
/// @sa https://json.nlohmann.me/api/ordered_map/
template<class Key, class T, class IgnoredLess, class Allocator>
struct ordered_map;
/// @brief specialization that maintains the insertion order of object keys
/// @sa https://json.nlohmann.me/api/ordered_json/
using ordered_json = basic_json<nlohmann::ordered_map>;
NLOHMANN_JSON_NAMESPACE_END
#endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_
Subproject commit f13cdd052eeae5e89decc11bf03697d0f78b15bc
Subproject commit 48bcf39a661a13be22666ac64db8a7f886f2637e
Subproject commit 953a09abc39096da9e216b6eb0002c681cdc1199