"vscode:/vscode.git/clone" did not exist on "887c2b4575772f8f70e0bc55dd701ae274d3ab32"
Commit 18c42e67 authored by chenxl's avatar chenxl
Browse files

Initial commit

parents
from time import time
from typing import Optional,List
from uuid import uuid4
from ktransformers.server.models.assistants.assistants import Assistant
from ktransformers.server.schemas.assistants.assistants import AssistantCreate,AssistantObject,AssistantModify
from ktransformers.server.utils.sql_utils import SQLUtil
from ktransformers.server.config.log import logger
from ktransformers.server.schemas.base import Order
class AssistantDatabaseManager:
    """CRUD helper around the `assistants` table.

    Converts between the SQLAlchemy `Assistant` row model and the pydantic
    `AssistantObject` schema.
    """

    def __init__(self) -> None:
        self.sql_util = SQLUtil()

    def create_assistant_object(self, assistant: AssistantCreate) -> AssistantObject:
        """Build a new AssistantObject (id/object/created_at filled in) without touching the DB."""
        return AssistantObject(
            **assistant.model_dump(mode='json'),
            id=str(uuid4()),
            object='assistant',
            created_at=int(time()),
        )

    def db_count_assistants(self) -> int:
        """Return the total number of stored assistants."""
        with self.sql_util.get_db() as db:
            return db.query(Assistant).count()

    def db_create_assistant(self, assistant: AssistantCreate) -> AssistantObject:
        """Create an assistant object and persist it via its own sync_db()."""
        ass_obj = self.create_assistant_object(assistant)
        ass_obj.sync_db()
        return ass_obj

    def db_list_assistants(self, limit: Optional[int], order: Order) -> List[AssistantObject]:
        """List assistants ordered by creation time; a None `limit` means no limit."""
        with self.sql_util.get_db() as db:
            query = db.query(Assistant).order_by(
                order.to_sqlalchemy_order()(Assistant.created_at))
            if limit is not None:
                # Fetch explicitly so both branches yield a materialized list
                # (previously the limited branch returned a lazy Query).
                db_assistants = query.limit(limit).all()
            else:
                db_assistants = query.all()
            return [AssistantObject.model_validate(a.__dict__) for a in db_assistants]

    def db_get_assistant_by_id(self, assistant_id: str) -> Optional[AssistantObject]:
        """Fetch one assistant by id, or None (logged) if it does not exist."""
        with self.sql_util.get_db() as db:
            db_assistant = db.query(Assistant).filter(
                Assistant.id == assistant_id).first()
            if db_assistant is None:
                # Fixed: previously interpolated the builtin `str` instead of the id.
                logger.debug(f"no assistant with id {assistant_id}")
                return None
            return AssistantObject.model_validate(db_assistant.__dict__)

    def db_update_assistant_by_id(self, assistant_id: str, assistant: AssistantModify) -> AssistantObject:
        """Apply the fields of `assistant` to the stored row; return the updated object."""
        with self.sql_util.get_db() as db:
            db_assistant = db.query(Assistant).filter(
                Assistant.id == assistant_id).first()
            self.sql_util.db_update_commit_refresh(db, db_assistant, assistant)
            return AssistantObject.model_validate(db_assistant.__dict__)

    def db_delete_assistant_by_id(self, assistant_id: str) -> None:
        """Delete the assistant row with the given id."""
        with self.sql_util.get_db() as db:
            db_assistant = db.query(Assistant).filter(
                Assistant.id == assistant_id).first()
            db.delete(db_assistant)
            db.commit()
from time import time
from typing import Optional
from uuid import uuid4
from ktransformers.server.models.assistants.messages import Message
from ktransformers.server.schemas.assistants.messages import MessageCore, MessageCreate, MessageObject
from ktransformers.server.schemas.base import Order,ObjectID
from ktransformers.server.utils.sql_utils import SQLUtil
class MessageDatabaseManager:
    """CRUD helper around the `messages` table."""

    def __init__(self) -> None:
        self.sql_util = SQLUtil()

    @staticmethod
    def create_db_message_by_core(message: MessageCore):
        """Build a DB Message row from a MessageCore; id and created_at are generated here."""
        message_dict = message.model_dump(mode="json")
        return Message(**message_dict, id=str(uuid4()), created_at=int(time()))

    def create_db_message(self, message: MessageCreate):
        """Build a DB Message row from an API create payload."""
        return MessageDatabaseManager.create_db_message_by_core(message.to_core())

    def db_add_message(self, message: Message):
        """Persist an already-constructed Message row."""
        with self.sql_util.get_db() as db:
            # NOTE(review): db_add_commit_refresh presumably also adds `message`;
            # re-adding the same instance is a no-op in SQLAlchemy, so this is
            # harmless but likely redundant — confirm against SQLUtil.
            db.add(message)
            self.sql_util.db_add_commit_refresh(db, message)

    def db_create_message(self, thread_id: str, message: MessageCreate, status: MessageObject.Status):
        """Create and persist a message in `thread_id` with the given status."""
        db_message = self.create_db_message(message)
        db_message.status = status.value
        db_message.thread_id = thread_id
        self.db_add_message(db_message)
        return MessageObject.model_validate(db_message.__dict__)

    @staticmethod
    def create_message_object(thread_id: ObjectID, run_id: ObjectID, message: MessageCreate):
        """Build an in-progress MessageObject for a run, without persisting it."""
        core = message.to_core()
        return MessageObject(
            **core.model_dump(mode='json'),
            id=str(uuid4()),
            object='thread.message',
            created_at=int(time()),
            thread_id=thread_id,
            run_id=run_id,
            status=MessageObject.Status.in_progress,
        )

    def db_sync_message(self, message: MessageObject):
        """Upsert the current state of `message` into the DB (merge by id)."""
        db_message = Message(
            **message.model_dump(mode="json"),
        )
        with self.sql_util.get_db() as db:
            self.sql_util.db_merge_commit(db, db_message)

    def db_list_messages_of_thread(
            self, thread_id: str, limit: Optional[int] = None, order: Order = Order.DESC):
        """List messages of a thread ordered by creation time; a None `limit` means all."""
        with self.sql_util.get_db() as db:
            query = (
                db.query(Message)
                .filter(Message.thread_id == thread_id)
                .order_by(order.to_sqlalchemy_order()(Message.created_at))
            )
            if limit is not None:
                messages = query.limit(limit)
            else:
                messages = query.all()
            message_list = [MessageObject.model_validate(m.__dict__) for m in messages]
            return message_list

    def db_get_message_by_id(self, thread_id: ObjectID, message_id: ObjectID) -> MessageObject:
        """Fetch a message, verifying it belongs to `thread_id`.

        Raises ValueError when the message is missing or belongs to another
        thread. (Previously a bare `assert`, which is stripped under -O and
        crashed with AttributeError on a missing row.)
        """
        with self.sql_util.get_db() as db:
            message = db.query(Message).filter(
                Message.id == message_id).first()
            if message is None:
                raise ValueError(f"no message with id {message_id}")
            if message.thread_id != thread_id:
                raise ValueError(f"message {message_id} does not belong to thread {thread_id}")
            return MessageObject.model_validate(message.__dict__)

    def db_delete_message_by_id(self, thread_id: ObjectID, message_id: ObjectID):
        """Delete a message after verifying it belongs to `thread_id`."""
        with self.sql_util.get_db() as db:
            message = db.query(Message).filter(
                Message.id == message_id).first()
            if message is None:
                raise ValueError(f"no message with id {message_id}")
            if message.thread_id != thread_id:
                raise ValueError(f"message {message_id} does not belong to thread {thread_id}")
            db.delete(message)
            db.commit()
from time import time
from uuid import uuid4
from ktransformers.server.models.assistants.runs import Run
from ktransformers.server.schemas.assistants.runs import RunCreate,RunObject
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.sql_utils import SQLUtil
class RunsDatabaseManager:
    """CRUD helper around the `runs` table."""

    def __init__(self) -> None:
        self.sql_util = SQLUtil()

    def create_run_object(self, thread_id: ObjectID, run: RunCreate) -> RunObject:
        """Build a queued RunObject for `thread_id` without persisting it."""
        fields = run.model_dump(mode='json', exclude={"stream"})
        new_run = RunObject(
            **fields,
            id=str(uuid4()),
            object='run',
            created_at=int(time()),
            thread_id=thread_id,
            status=RunObject.Status.queued,
        )
        new_run.set_compute_save(0)
        return new_run

    def db_create_run(self, thread_id: str, run: RunCreate):
        """Persist a new queued run and return it as a RunObject."""
        row = Run(
            **run.model_dump(mode="json", exclude={"stream"}),
            id=str(uuid4()),
            created_at=int(time()),
            status="queued",
            thread_id=thread_id,
        )
        with self.sql_util.get_db() as db:
            self.sql_util.db_add_commit_refresh(db, row)
            validated = RunObject.model_validate(row.__dict__)
        validated.set_compute_save(0)
        return validated

    def db_sync_run(self, run: RunObject) -> None:
        """Write the current state of `run` back to the DB (merge by id)."""
        with self.sql_util.get_db() as db:
            self.sql_util.db_merge_commit(db, Run(**run.model_dump(mode='json')))

    def db_get_run(self, run_id: ObjectID) -> RunObject:
        """Fetch one run by id."""
        with self.sql_util.get_db() as db:
            row = db.query(Run).filter(Run.id == run_id).first()
            return RunObject.model_validate(row.__dict__)
from time import time
from typing import Optional,List
from uuid import uuid4
from ktransformers.server.models.assistants.messages import Message
from ktransformers.server.models.assistants.threads import Thread
from ktransformers.server.schemas.assistants.threads import ThreadCreate,ThreadObject
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.schemas.conversation import ThreadPreview
from ktransformers.server.utils.sql_utils import SQLUtil
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.config.log import logger
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
class ThreadsDatabaseManager:
    """CRUD helper around the `threads` table."""

    def __init__(self) -> None:
        self.sql_util = SQLUtil()
        self.message_manager = MessageDatabaseManager()
        # NOTE(review): keeps the historical misspelling ("maanager") in case
        # external code references this attribute.
        self.assistant_maanager = AssistantDatabaseManager()

    def db_create_thread(self, thread: ThreadCreate):
        """Create a thread (and its initial messages) and return the ThreadObject.

        If the thread's metadata carries an `assistant_id`, the new thread is
        also registered as a related thread of that assistant.
        """
        thread_id = str(uuid4())
        db_messages = []
        with self.sql_util.get_db() as db:
            if thread.messages is not None:
                logger.debug("Creating messages first for thread")
                for message in thread.messages:
                    db_message: Message = MessageDatabaseManager.create_db_message_by_core(
                        message)
                    db_message.role = "user"
                    db_message.thread_id = thread_id
                    db.add(db_message)
                    db_messages.append(db_message)
            db_thread = Thread(
                # `exclude` must be a set for pydantic's model_dump.
                **thread.model_dump(exclude={"messages"}),
                # Fixed: reuse the id the messages above were attached to.
                # Previously a second uuid4() was generated here, so the
                # messages' thread_id never matched the thread's actual id.
                id=thread_id,
                created_at=int(time()),
                messages=db_messages,
            )
            self.sql_util.db_add_commit_refresh(db, db_thread)
            thread_obj = ThreadObject.model_validate(db_thread.__dict__)
            if 'assistant_id' in thread.meta_data:
                assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'])
                logger.info(
                    f'Append this related thread to assistant {assistant.id}')
                assistant.append_related_threads([thread_obj.id])
                # Fixed: sync_db() takes no arguments — it opens its own session.
                assistant.sync_db()
        return thread_obj

    def db_get_thread_by_id(self, thread_id: ObjectID):
        """Fetch one thread by id."""
        with self.sql_util.get_db() as db:
            db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
            return ThreadObject.model_validate(db_thread.__dict__)

    def db_list_threads(self, limit: Optional[int], order: Order) -> List[ThreadObject]:
        """List standalone threads; assistant-owned threads are filtered out."""
        with self.sql_util.get_db() as db:
            query = db.query(Thread).order_by(order.to_sqlalchemy_order()(
                Thread.created_at)).filter(~Thread.meta_data.contains('assistant_id'))
            if limit is not None:
                # Fetch explicitly so both branches yield a materialized list.
                db_threads = query.limit(limit).all()
            else:
                db_threads = query.all()
            return [ThreadObject.model_validate(row.__dict__) for row in db_threads]

    def db_list_threads_preview(self, limit: Optional[int], order: Order) -> List[ThreadPreview]:
        """Build per-thread previews: first message plus the replying assistant."""
        threads = self.db_list_threads(limit, order)
        previews = []
        for thread in threads:
            messages = self.message_manager.db_list_messages_of_thread(
                thread.id, limit=2, order=Order.ASC)
            if len(messages) == 2:
                message = messages[0]
                # The second message is expected to carry the assistant's id.
                assistant = self.assistant_maanager.db_get_assistant_by_id(
                    messages[1].assistant_id)
            else:
                message = None
                assistant = None
            previews.append(ThreadPreview(
                assistant=assistant, thread=thread, first_message=message))
        return previews

    def db_delete_thread_by_id(self, thread_id: ObjectID):
        """Delete the thread row itself."""
        with self.sql_util.get_db() as db:
            db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
            db.delete(db_thread)
            # TODO delete related messages and runs and other stuff or just gc
            db.commit()
from fastapi import HTTPException, status
def db_exception():
    """HTTP 500 wrapping an unspecified database failure."""
    return HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                         detail="DB Error")
def not_implemented(what):
    """HTTP 501 for a feature that is not implemented yet."""
    detail = f"{what} not implemented"
    return HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=detail)
def internal_server_error(what):
    """HTTP 500 carrying `what` (stringified) as the detail."""
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"{what}",
    )
def request_error(what):
    """HTTP 400 carrying `what` (stringified) as the detail."""
    return HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"{what}",
    )
import os
import re
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn.logging
import argparse
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import create_interface
from ktransformers.server.backend.args import default_args
from fastapi.openapi.utils import get_openapi
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ktransformers.server.api import router, post_db_creation_operations
from ktransformers.server.utils.sql_utils import Base, SQLUtil
from ktransformers.server.config.log import logger
def mount_app_routes(mount_app: FastAPI):
    """Create the DB schema, run post-creation hooks, and attach the API router.

    Must run before the app starts serving: the routers assume the tables exist.
    """
    sql_util = SQLUtil()
    logger.info("Creating SQL tables")
    # Idempotent: create_all only creates tables that are missing.
    Base.metadata.create_all(bind=sql_util.sqlalchemy_engine)
    post_db_creation_operations()
    mount_app.include_router(router)
def create_app():
    """Build the FastAPI app: optional CORS, API routes, optional static web UI."""
    cfg = Config()
    app = FastAPI()
    # Reuse `cfg` instead of constructing a second Config() instance.
    if cfg.web_cross_domain:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    mount_app_routes(app)
    if cfg.mount_web:
        mount_index_routes(app)
    return app
def update_web_port(config_file: str):
    """Point the website's config file at this server.

    Replaces every `host:port` occurrence (localhost or dotted IPv4) in the
    file with `localhost:<server_port>`.
    """
    ip_port_pattern = r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}"
    replacement = "localhost:" + str(Config().server_port)
    with open(config_file, "r", encoding="utf-8") as f_cfg:
        contents = f_cfg.read()
    rewritten = re.sub(ip_port_pattern, replacement, contents)
    with open(config_file, "w", encoding="utf-8") as f_cfg:
        f_cfg.write(rewritten)
def mount_index_routes(app: FastAPI):
    """Serve the pre-built website under /web, or abort if it was never built."""
    project_dir = os.path.dirname(os.path.dirname(__file__))
    web_dir = os.path.join(project_dir, "website/dist")
    if not os.path.exists(web_dir):
        # Check existence first: previously update_web_port() ran before this
        # check and raised FileNotFoundError on the missing config.js, so the
        # friendly message below was never shown. ("complile" typo also fixed.)
        err_str = f"No website resources in {web_dir}, please compile the website by npm first"
        logger.error(err_str)
        print(err_str)
        exit(1)
    web_config_file = os.path.join(web_dir, "config.js")
    update_web_port(web_config_file)
    app.mount("/web", StaticFiles(directory=web_dir), name="static")
def run_api(app, host, port, **kwargs):
    """Start uvicorn, enabling TLS when both key and cert files are supplied."""
    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")
    if ssl_keyfile and ssl_certfile:
        uvicorn.run(app, host=host, port=port,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile)
    else:
        # NOTE(review): debug log level is only set on the non-TLS path — confirm
        # whether that asymmetry is intentional.
        uvicorn.run(app, host=host, port=port, log_level='debug')
def custom_openapi(app):
    """Build the customized OpenAPI schema once and cache it on the app."""
    if app.openapi_schema:
        return app.openapi_schema
    schema = get_openapi(
        title="ktransformers server",
        version="1.0.0",
        summary="This is a server that provides a RESTful API for ktransformers.",
        description="We provided chat completion and openai assistant interfaces.",
        routes=app.routes,
    )
    schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"}
    app.openapi_schema = schema
    return schema
def main():
cfg = Config()
parser = argparse.ArgumentParser(prog='kvcache.ai',
description='Ktransformers')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=cfg.server_port)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
parser.add_argument("--web", type=bool, default=False)
parser.add_argument("--model_name", type=str, default=cfg.model_name)
parser.add_argument("--model_path", type=str, default=cfg.model_path)
parser.add_argument("--device", type=str, default=cfg.model_device)
parser.add_argument("--gguf_path", type=str, default=cfg.gguf_path)
parser.add_argument("--optimize_config_path", type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=cfg.cpu_infer)
parser.add_argument("--type", type=str, default=cfg.backend_type)
# 初始化消息
args = parser.parse_args()
cfg.model_name = args.model_name
cfg.model_path = args.model_path
cfg.model_device = args.device
cfg.mount_web = args.web
cfg.server_ip = args.host
cfg.server_port = args.port
cfg.cpu_infer = args.cpu_infer
cfg.backend_type = args.type
default_args.model_dir = args.model_path
default_args.device = args.device
default_args.gguf_path = args.gguf_path
default_args.optimize_config_path = args.optimize_config_path
app = create_app()
custom_openapi(app)
create_interface(config=cfg, default_args=default_args)
run_api(app=app,
host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,)
if __name__ == "__main__":
main()
from sqlalchemy import JSON, Column, Float, Integer, String, Text
from sqlalchemy.orm import relationship
from ktransformers.server.utils.sql_utils import Base
class Assistant(Base):
    """SQLAlchemy row model mirroring the OpenAI-style assistant object."""
    __tablename__ = "assistants"

    id = Column(String, primary_key=True, index=True)
    object = Column(String, default="assistant")  # fixed object tag
    created_at = Column(Integer)  # unix timestamp, seconds
    name = Column(String, nullable=True)
    description = Column(String, nullable=True)
    model = Column(String)
    instructions = Column(Text, nullable=True)
    tools = Column(JSON)
    tool_resources = Column(JSON)
    temperature = Column(Float, nullable=True)
    # Named meta_data because "metadata" is reserved on SQLAlchemy's declarative Base.
    meta_data = Column(JSON, nullable=True)
    top_p = Column(Float, nullable=True)
    response_format = Column(JSON, default="auto")
    build_status = Column(JSON, nullable=True)

    runs = relationship("Run", back_populates="assistant")
    messages = relationship("Message", back_populates="assistant")
from sqlalchemy import JSON, Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from ktransformers.server.utils.sql_utils import Base
class Message(Base):
    """SQLAlchemy row model for messages inside a thread."""
    __tablename__ = "messages"

    id = Column(String, primary_key=True, index=True)
    object = Column(String, default="thread.message")
    created_at = Column(Integer)  # unix timestamp, seconds
    thread_id = Column(String, ForeignKey("threads.id"))
    status = Column(String, default="in_progress")
    incomplete_details = Column(JSON, nullable=True)
    completed_at = Column(Integer, nullable=True)
    incomplete_at = Column(Integer, nullable=True)
    # NOTE(review): callers assign a plain string here (e.g. "user"); the JSON
    # column type looks accidental — confirm before changing (schema impact).
    role = Column(JSON)
    content = Column(JSON)
    assistant_id = Column(String, ForeignKey("assistants.id"), nullable=True)
    run_id = Column(String, ForeignKey("runs.id"), nullable=True)
    attachments = Column(JSON, nullable=True)
    # Named meta_data because "metadata" is reserved on the declarative Base.
    meta_data = Column(JSON, nullable=True)

    thread = relationship("Thread", back_populates="messages")
    assistant = relationship("Assistant", back_populates="messages")
    run = relationship("Run", back_populates="message")
from sqlalchemy import JSON, Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from ktransformers.server.utils.sql_utils import Base
class RunStep(Base):
    """SQLAlchemy row model for a single step within a run."""
    __tablename__ = "run_steps"
    # todo

    id = Column(String, primary_key=True, index=True)
    object = Column(String, default="thread.run.step")
    created_at = Column(Integer)  # unix timestamp, seconds
    assistant_id = Column(String, ForeignKey("assistants.id"))
    thread_id = Column(String, ForeignKey("threads.id"))
    run_id = Column(String, ForeignKey("runs.id"))
    type = Column(String)
    status = Column(String)
    step_details = Column(JSON)
    last_error = Column(JSON, nullable=True)
    # Lifecycle timestamps, all unix seconds.
    expires_at = Column(Integer, nullable=True)
    cancelled_at = Column(Integer, nullable=True)
    failed_at = Column(Integer, nullable=True)
    completed_at = Column(Integer, nullable=True)
    meta_data = Column(JSON, nullable=True)
    usage = Column(JSON, nullable=True)

    # NOTE(review): these use back_populates="run_steps", but Assistant, Thread
    # and Run (as defined in this codebase) declare no `run_steps` relationship —
    # SQLAlchemy mapper configuration would fail if this model is ever mapped;
    # confirm whether RunStep is actually registered/used anywhere.
    assistant = relationship("Assistant", back_populates="run_steps")
    thread = relationship("Thread", back_populates="run_steps")
    run = relationship("Run", back_populates="run_steps")
from sqlalchemy import JSON, Column, Float, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
from ktransformers.server.utils.sql_utils import Base
class Run(Base):
    """SQLAlchemy row model for one run of an assistant over a thread."""
    __tablename__ = "runs"

    id = Column(String, primary_key=True, index=True)
    object = Column(String, default="thread.run")
    created_at = Column(Integer)  # unix timestamp, seconds
    thread_id = Column(String, ForeignKey("threads.id"))
    assistant_id = Column(String, ForeignKey("assistants.id"))
    status = Column(String)
    required_action = Column(JSON, nullable=True)
    last_error = Column(JSON, nullable=True)
    # Lifecycle timestamps, all unix seconds.
    expires_at = Column(Integer, nullable=True)
    started_at = Column(Integer, nullable=True)
    cancelled_at = Column(Integer, nullable=True)
    failed_at = Column(Integer, nullable=True)
    completed_at = Column(Integer, nullable=True)
    incomplete_details = Column(JSON, nullable=True)
    # get from assistant
    model = Column(String)
    instructions = Column(Text, nullable=True)
    tools = Column(JSON)
    meta_data = Column(JSON, nullable=True)
    usage = Column(JSON, nullable=True)
    temperature = Column(Float, nullable=True)
    top_p = Column(Float, nullable=True)
    # NOTE(review): misspelling of "max_prompt_tokens"; renaming would change
    # the DB column (and any schema fields mapped to it), so it is kept as-is.
    max_propmp_tokens = Column(Integer, nullable=True)
    truncation_strategy = Column(JSON)
    tool_choice = Column(JSON)
    response_format = Column(JSON, default="auto")

    thread = relationship("Thread", back_populates="runs")
    assistant = relationship("Assistant", back_populates="runs")
    message = relationship("Message", back_populates="run")
from sqlalchemy import JSON, Column, Integer, String
from sqlalchemy.orm import relationship
from ktransformers.server.utils.sql_utils import Base
class Thread(Base):
    """SQLAlchemy row model for a conversation thread."""
    __tablename__ = "threads"

    id = Column(String, primary_key=True, index=True)
    object = Column(String, default="thread")
    created_at = Column(Integer)  # unix timestamp, seconds
    tool_resources = Column(JSON, nullable=True)
    # Named meta_data because "metadata" is reserved on the declarative Base.
    meta_data = Column(JSON, nullable=True)

    runs = relationship("Run", back_populates="thread")
    messages = relationship("Message", back_populates="thread")
torch >= 2.3.0,<=2.3.1
transformers == 4.43.2
fastapi >= 0.111.0
langchain >= 0.2.0
blessed >= 1.20.0
accelerate >= 0.31.0
sentencepiece >= 0.1.97
setuptools
build
ninja
wheel
colorlog
fire
\ No newline at end of file
from enum import Enum
from time import time
from typing import AsyncIterable, Callable, Dict, List, Optional, Union
from asyncio import Lock, Queue
from fastapi import logger
from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator
import torch
from ktransformers.server.config.config import Config
from ktransformers.server.models.assistants.assistants import Assistant
from ktransformers.server.models.assistants.threads import Thread
from ktransformers.server.schemas.assistants.messages import Role
from ktransformers.server.schemas.assistants.runs import RunObject,RunStreamResponse,ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.base import Metadata,MetadataField,ObjectID
from ktransformers.server.schemas.assistants.tool import Tool,CodeInterpreter,FileSearch,RelatedThreads,FuntionTool,ToolResource,CodeInterpreterResource,FileSearchResource,RelatedThreadsResource,ToolType
from ktransformers.server.utils.sql_utils import SQLUtil
class AssistantBase(BaseModel):
    """Fields shared by all assistant schemas (create / modify / stored object)."""

    name: Optional[str] = Field(None, description='The name of the assistant.')
    description: Optional[str] = Field(None, description='The description of the assistant.')
    instructions: Optional[str] = Field(None, description='Instructions which is added in front of the input of LLM')

    tools: List[Tool] = Field([], max_length=128)

    @field_validator('tools', mode='before')
    def validate_tools(cls, value):
        """Coerce raw dicts into the concrete Tool subclass named by their 'type' key."""
        re = []  # NOTE: local name shadows the stdlib `re` module inside this validator
        if not isinstance(value, list):
            raise ValueError('Invalid type for tools')
        for tool in value:
            if 'type' not in tool:
                raise ValueError('Invalid type for tools')
            if tool['type'] == 'code_interpreter':
                re.append(CodeInterpreter(**tool))
            elif tool['type'] == 'file_search':
                re.append(FileSearch(**tool))
            elif tool['type'] == 'related_threads':
                re.append(RelatedThreads(**tool))
            elif tool['type'] == 'function':
                re.append(FuntionTool(**tool))
            else:
                raise ValueError('Invalid type for tools')
        return re

    tool_resources: List[ToolResource] = Field([], max_length=128)

    @field_validator('tool_resources', mode='before')
    def validate_tool_resources(cls, value):
        """Infer the concrete ToolResource subclass from the key each dict carries."""
        re = []
        if not isinstance(value, list):
            raise ValueError('Invalid type for tool resources')
        for tool_re in value:
            if 'file_ids' in tool_re:
                re.append(CodeInterpreterResource(**tool_re))
            elif 'vector_stores' in tool_re:
                re.append(FileSearchResource(**tool_re))
            elif 'thread_ids' in tool_re:
                re.append(RelatedThreadsResource(**tool_re))
            else:
                raise ValueError('Invalid type for tool resources')
        return re

    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    def convert_meta_data(cls, values):
        # NOTE(review): mirrors 'meta_data' into 'metadata' — presumably the
        # Metadata field is aliased to 'metadata'; confirm in schemas.base.
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    # Sampling parameters, same ranges as the OpenAI API.
    temperature: Optional[float] = Field(ge=0.0, le=2.0, default=1)
    top_p: Optional[float] = Field(ge=0.0, le=1.0, default=1)
    response_format: Union[str, Dict[str, str]] = "auto"
class AssistantCreate(AssistantBase):
    """Payload for creating an assistant; the model name is mandatory here."""
    model: str
class AssistantBuildStatus(BaseModel):
    """Progress tracker for building an assistant's knowledge (parse -> prefill -> dump)."""

    class Status(Enum):
        not_build = "not_build"
        in_queue = "in_queue"
        parsing = "parsing"
        prefilling = "prefilling"
        dumping = "dumping"
        completed = "completed"
        paused = "paused"

    # Internal coordination primitives; private attrs are excluded from the schema.
    _lock: Lock = PrivateAttr(default_factory=Lock)
    _queue: Optional[Queue] = PrivateAttr(None)

    status: Status = Field(default=Status.not_build)
    # File-parsing progress counters.
    total_file_count: int = Field(default=0)
    parsed_file_count: int = Field(default=0)
    # Prefill progress counters.
    prefilling_current: int = Field(default=0)
    prefilling_total: int = Field(default=0)
    build_started_time: Optional[int] = Field(default=None)
    build_completed_time: Optional[int] = Field(default=None)
    # in megabytes
    assistant_usage: int = Field(default=0, description='')
    assistant_total_usage: int = Field(default=0)
    disk_free_space: int = Field(default=0)
    disk_total_space: int = Field(default=0)

    def to_stream_reply(self) -> str:
        """Format this status as a server-sent-event frame."""
        return f"event: assistant.build.status\ndata: {self.model_dump_json()}\n\n"
class AssistantObject(AssistantBase, ObjectWithCreatedTime):
    """A stored assistant, plus build-status tracking and related-thread helpers."""

    model: Optional[str] = Field(
        default=Config().model_name)
    # Cached ThreadObject list; excluded from serialization.
    related_threads_objects: Optional[List] = Field(None, exclude=True)
    # Lazily-computed token tensor of the instructions.
    _encoded_instruction: Optional[torch.Tensor] = PrivateAttr(default=None)
    # default_factory guarantees each instance gets its own fresh status object
    # (previously a single shared `default=AssistantBuildStatus()` instance).
    build_status: AssistantBuildStatus = Field(default_factory=AssistantBuildStatus)

    def as_api_response(self):
        """Serialize for the public API, hiding the internal build progress."""
        return self.model_dump(exclude={'build_status'})

    def get_related_threads_ids(self) -> List[ObjectID]:
        """Collect thread ids from every related-threads tool resource."""
        re = []
        for tool, tool_re in zip(self.tools, self.tool_resources):
            if tool.type == ToolType.RELATED_THREADS:
                re += tool_re.thread_ids or []
        return re

    def get_related_threads_objects(self) -> List:
        """Load (and cache) all threads whose metadata ties them to this assistant."""
        sql_utils = SQLUtil()
        if self.related_threads_objects is None:
            with sql_utils.get_db() as db:
                db_threads = db.query(Thread).all()
            self.related_threads_objects = [
                tool for tool in [ThreadObject.model_validate(tool.__dict__) for tool in db_threads]
                if tool.is_related_threads and tool.meta_data['assistant_id'] == self.id
            ]
        return self.related_threads_objects

    def append_related_threads(self, thread_ids: List[ObjectID]):
        """Add thread ids to the existing related-threads tool, or create one."""
        for tool, tool_re in zip(self.tools, self.tool_resources):
            if tool.type == ToolType.RELATED_THREADS:
                tool_re.thread_ids += thread_ids
                return
        self.tools.append(RelatedThreads(type=ToolType.RELATED_THREADS))
        self.tool_resources.append(
            RelatedThreadsResource(thread_ids=thread_ids))

    async def update_build_status(self, events: AsyncIterable) -> AsyncIterable:
        """Translate run/progress events into build-status snapshots (async generator)."""
        async for event in events:
            if isinstance(event, RunStreamResponse):
                if event.event == RunObject.Status.completed:
                    self.build_status.status = AssistantBuildStatus.Status.completed
                    self.build_status.build_completed_time = int(time())
                    self.sync_db()
                    yield self.build_status.model_copy()
            elif isinstance(event, dict):
                if 'stage' in event:
                    if event['stage'] == 'prefill':
                        self.build_status.status = AssistantBuildStatus.Status.prefilling
                        self.build_status.prefilling_current = event['curr_progress']
                        self.build_status.prefilling_total = event['max_progress']
                    if event['stage'] == 'parse':
                        self.build_status.status = AssistantBuildStatus.Status.parsing
                        self.build_status.parsed_file_count = event['curr_progress']
                        self.build_status.total_file_count = event['max_progress']
                    yield self.build_status.model_copy()

    def get_build_status(self) -> AssistantBuildStatus:
        """Return the current build-status holder (mutated in place elsewhere)."""
        return self.build_status

    def sync_db(self) -> None:
        """Upsert this assistant into the DB via merge-commit."""
        sql_utils = SQLUtil()
        db_assistant = Assistant(
            **self.model_dump(mode='json'),
        )
        with sql_utils.get_db() as db:
            sql_utils.db_merge_commit(db, db_assistant)

    def get_encoded_instruction(self, encode_fn: Callable) -> torch.Tensor:
        """Encode (once) and return the instruction tokens."""
        if self._encoded_instruction is None:
            # NOTE(review): `logger` in this module comes from
            # `from fastapi import logger`, which binds the fastapi.logger
            # *module*, not a Logger — `logger.info` likely fails at runtime;
            # it should probably import from ktransformers.server.config.log.
            logger.info(f'encoding assistant instruction: {self.instructions}')
            self._encoded_instruction = encode_fn(self.instructions, Role.user)
        return self._encoded_instruction
class AssistantModify(AssistantBase):
    """Payload for modifying an assistant; every field, including model, is optional."""
    model: Optional[str] = None
# Non API Backend
from enum import Enum
from typing import ForwardRef, List, Optional, Union,Callable
import torch
from pydantic import BaseModel, PrivateAttr, model_validator
from ktransformers.server.exceptions import not_implemented
from ktransformers.server.config.log import logger
from ktransformers.server.models.assistants.messages import Message
from ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.tool import Field,CodeInterpreter,FileSearch
from ktransformers.server.utils.sql_utils import SQLUtil
class IncompleteDetails(BaseModel):
    """Why a message ended up incomplete."""
    reason: str


class ContentType(Enum):
    """Discriminator for the content-part variants."""
    image_file = "image_file"
    image_url = "image_url"
    text = "text"


class ContentObject(BaseModel):
    """Base for a single content part; `type` discriminates the subclass."""
    type: ContentType


class ImageFile(BaseModel):
    """Reference to an uploaded image file."""
    file_id: str
    detail: str


class ImageFileObject(ContentObject):
    """Content part wrapping an uploaded image file."""
    image_file: ImageFile


class ImageUrl(BaseModel):
    """Reference to an external image by URL."""
    url: str
    detail: str


class ImageUrlObject(ContentObject):
    """Content part wrapping an external image URL."""
    image_url: ImageUrl


class Annotation(BaseModel):
    # NOTE(review): placeholder field — annotations appear unimplemented.
    todo: str


class Text(BaseModel):
    """A text fragment plus its (currently unused) annotations."""
    value: str
    annotations: List[Annotation] = Field(default=[])


class TextObject(ContentObject):
    """Text content part with streaming bookkeeping (bookkeeping fields excluded from dumps)."""
    text: Text
    delta_index: int = Field(default=0, exclude=True)
    # NOTE(review): the two fields below are never read in this file —
    # presumably leftovers from a token-filtering implementation; confirm.
    special_tokens_on: bool = Field(default=False, exclude=True)
    last_two: str = Field(default='', exclude=True)

    def filter_append(self, text: str):
        """Append streamed text to the value; always reports the text as kept."""
        self.text.value += text
        self.delta_index += 1
        return True


# A content part is exactly one of the three variants above.
Content = Union[ImageFileObject, ImageUrlObject, TextObject]


class Attachment(BaseModel):
    """A file attached to a message, with the tools allowed to read it."""
    file_id: Optional[str] = Field(default=None)
    tools: Optional[List[Union[CodeInterpreter, FileSearch]]] = Field(default=None)
class Role(Enum):
    """Who authored a message: the end user or the assistant."""
    user = "user"
    assistant = "assistant"

    def is_user(self) -> bool:
        """True when this role is the end user."""
        return self is Role.user
class MessageCore(BaseModel):
    """The content-bearing core of a message, independent of storage/lifecycle fields."""
    role: Role
    content: List[Content]
    attachments: Optional[List[Attachment]]
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        # NOTE(review): mirrors 'meta_data' into 'metadata' — presumably the
        # Metadata field is aliased to 'metadata'; confirm in schemas.base.
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values
class MessageBase(MessageCore):
    """MessageCore plus storage and lifecycle fields."""

    class Status(Enum):
        created = "created"  # only used for stream
        in_progress = "in_progress"
        incomplete = "incomplete"
        completed = "completed"

    thread_id: str
    status: Status
    incomplete_details: Optional[IncompleteDetails] = None
    completed_at: Optional[int] = None
    incomplete_at: Optional[int] = None
    assistant_id: Optional[str] = None
    # No default: run_id must always be supplied (may be explicitly None).
    run_id: Optional[str]
# Forward reference so MessageObject can annotate its return type before the
# stream-response wrapper class is defined below.
MessageStreamResponse = ForwardRef('MessageStreamResponse')


class MessageObject(MessageBase, ObjectWithCreatedTime):
    """A stored message, plus helpers for token encoding and stream replies."""

    # Lazily-computed token tensor of the message content (plus attachments).
    _encoded_content: Optional[torch.Tensor] = PrivateAttr(default=None)

    def get_text_content(self) -> str:
        """Concatenate all text parts; raises for any non-text content part."""
        text_content = ""
        for content in self.content:
            if content.type == ContentType.text:
                text_content += content.text.value
            else:
                # NOTE(review): not_implemented() returns an HTTPException;
                # raising it here relies on FastAPI catching it upstream.
                raise not_implemented("Content other than text")
        return text_content

    async def get_encoded_content(self, encode_fn: Callable):
        """Async generator: yields None first, then the (cached) encoded tensor.

        On first use, encodes the text content and appends the encoding of
        every attached file along the last dimension.
        """
        if self._encoded_content is None:
            logger.info(f'encoding {self.role.value} message({self.status.value}): {self.get_text_content()}')
            self._encoded_content = encode_fn(self.get_text_content(), self.role)
            for f in self.get_attached_files():
                logger.info(f'encoding file: {f.filename}')
                self._encoded_content = torch.cat([self._encoded_content, encode_fn(await f.get_str(), self.role)], dim=-1)
        yield None
        yield self._encoded_content

    def get_attached_files(self):
        """Return attached file objects; implementation is patched in elsewhere."""
        raise NotImplementedError  # should be replaced

    def append_message_delta(self, text: str):
        """Append streamed text to this message; implementation is patched in elsewhere."""
        raise NotImplementedError  # should be replaced

    def sync_db(self):
        """Upsert this message into the DB via merge-commit."""
        # raise NotImplementedError # should be replaced
        sql_utils = SQLUtil()
        db_message = Message(
            **self.model_dump(mode="json"),
        )
        with sql_utils.get_db() as db:
            sql_utils.db_merge_commit(db, db_message)

    def stream_response_with_event(self, event: MessageBase.Status) -> MessageStreamResponse:
        """Wrap self in a stream response; a 'created' event flips local status to in_progress."""
        match event:
            case MessageObject.Status.created:
                self.status = MessageObject.Status.in_progress
            case _:
                self.status = event
        return MessageStreamResponse(message=self, event=event)
class MessageStreamResponse(BaseModel):
    """Pairs a message snapshot with the lifecycle event that produced it."""
    message: MessageObject
    event: MessageObject.Status

    def to_stream_reply(self):
        """Format this event as a server-sent-event frame."""
        return f"event: thread.message.{self.event.value}\ndata: {self.message.model_dump_json()}\n\n"
class MessageCreate(BaseModel):
    """API payload for creating a message; content may be a bare string."""
    role: Role = Field(default=Role.user)
    # Annotation normalized from `Union[str | List[Content]]` (a nested union
    # spelling of the same type).
    content: Union[str, List[Content]]
    attachments: Optional[List[Attachment]] = None
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        # NOTE(review): mirrors 'meta_data' into 'metadata' — presumably the
        # Metadata field is aliased to 'metadata'; confirm in schemas.base.
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    def to_core(self) -> MessageCore:
        """Normalize to MessageCore: a bare string becomes a single TextObject part."""
        core = MessageCore(
            role=self.role,
            content=[],
            attachments=self.attachments,
            meta_data=self.meta_data,
        )
        if isinstance(self.content, str):
            core.content = [TextObject(type="text", text=Text(value=self.content, annotations=[]))]
        elif isinstance(self.content, list):
            core.content = self.content
        else:
            raise ValueError("Invalid content type")
        return core
class MessageModify(BaseModel):
    """API payload for modifying a message; only metadata can change."""
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        # NOTE(review): mirrors 'meta_data' into 'metadata' — presumably the
        # Metadata field is aliased to 'metadata'; confirm in schemas.base.
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment