Commit 18c42e67 authored by chenxl

Initial commit
from typing import List, Optional
from fastapi import APIRouter
from ktransformers.server.exceptions import not_implemented
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.utils.create_interface import get_thread_context_manager
router = APIRouter()
message_manager = MessageDatabaseManager()
@router.post("/{thread_id}/messages", tags=['openai'], response_model=MessageObject)
async def create_message(thread_id: str, msg: MessageCreate):
message = message_manager.db_create_message(
thread_id, msg, MessageObject.Status.in_progress)
ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
if ctx is not None:
ctx.put_user_message(message)
return message
@router.get("/{thread_id}/messages", tags=['openai'], response_model=List[MessageObject])
async def list_messages(
thread_id: str,
limit: Optional[int] = 20,
order: Order = Order.DESC,
after: Optional[str] = None,
before: Optional[str] = None,
run_id: Optional[str] = None,
):
return message_manager.db_list_messages_of_thread(thread_id, limit, order)
@router.get("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def retrieve_message(thread_id: ObjectID, message_id: ObjectID):
return message_manager.db_get_message_by_id(thread_id, message_id)
@router.post("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify):
#raise not_implemented('modify message not implemented')
raise not_implemented('modify message')
@router.delete("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_message(thread_id: ObjectID, message_id: ObjectID):
ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
if ctx is not None:
ctx.delete_user_message(message_id)
message_manager.db_delete_message_by_id(thread_id, message_id)
return DeleteResponse(id=message_id, object='thread.message.deleted')
from typing import List, Optional
from fastapi import APIRouter, Request
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.schemas.assistants.runs import RunCreate,RunObject,RunThreadCreate,RunModify,RunSubmit
from ktransformers.server.schemas.assistants.streaming import api_stream_response
from ktransformers.server.utils.create_interface import get_thread_context_manager
from ktransformers.server.schemas.base import Order
from ktransformers.server.config.log import logger
from ktransformers.server.exceptions import internal_server_error
router = APIRouter()
runs_manager = RunsDatabaseManager()
@router.post("/{thread_id}/runs",tags=['openai'])
async def create_run(request: Request, thread_id: str, run_create: RunCreate):
if run_create.stream:
async def inner():
run = runs_manager.db_create_run(thread_id, run_create)
yield run.stream_response_with_event(event=RunObject.Status.created)
ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)
async for event in ctx.work():
yield event
return api_stream_response(request, inner())
else:
run = runs_manager.db_create_run(thread_id, run_create)
ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)
async for event in ctx.work():
pass
return run
@router.post("/runs",tags=['openai'], response_model=RunObject)
async def create_thread_and_run(run_thread: RunThreadCreate):
raise NotImplementedError
@router.get("/{thread_id}/runs",tags=['openai'], response_model=List[RunObject])
async def list_runs(
thread_id: str,
limit: Optional[int] = 20,
order: Optional[Order] = Order.DESC,
after: Optional[str] = None,
before: Optional[str] = None,
):
raise NotImplementedError
@router.get("/{thread_id}/runs/{run_id}",tags=['openai'], response_model=RunObject)
async def retrieve_run(
thread_id: str,
run_id: str,
):
runobj= runs_manager.db_get_run(run_id)
assert runobj.thread_id == thread_id
return runobj
@router.post("/{thread_id}/runs/{run_id}",tags=['openai'], response_model=RunObject)
async def modify_run(
thread_id: str,
run_id: str,
run: RunModify,
):
raise NotImplementedError
@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=['openai'],response_model=RunObject)
async def submit_tool_outputs_to_run(thread_id: str, run_id: str, submit: RunSubmit):
raise NotImplementedError
@router.post("/{thread_id}/runs/{run_id}/cancel",tags=['openai'], response_model=RunObject)
async def cancel_run(thread_id: str, run_id: str):
ctx: ThreadContext = await get_thread_context_manager().get_context_by_thread_id(thread_id)
if ctx is not None:
if ctx.run is None:
logger.warn(f'Run {run_id} is expected to be in_progress, but the thread context has no run')
raise internal_server_error('context does not have a run')
if ctx.run.id == run_id:
logger.info(f'Cancelling thread: {thread_id} and run: {run_id}')
ctx.run.stream_response_with_event(RunObject.Status.cancelling)
return ctx.run
else:
run = runs_manager.db_get_run(run_id)
logger.info(f'Run {run_id} not in this thread context')
return run
else:
run = runs_manager.db_get_run(run_id)
logger.info(f'Run {run_id} not in context manager')
return run
from typing import List,Optional
from fastapi import APIRouter
from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager,Order,ObjectID
from ktransformers.server.schemas.assistants.threads import ThreadObject,ThreadCreate,ThreadModify
from ktransformers.server.schemas.base import DeleteResponse
from ktransformers.server.schemas.conversation import ThreadPreview
router = APIRouter(prefix='/threads')
threads_manager = ThreadsDatabaseManager()
@router.post("/",tags=['openai'], response_model=ThreadObject)
async def create_thread(thread: ThreadCreate):
return threads_manager.db_create_thread(thread)
@router.get("/", tags=['openai-ext'],response_model=List[ThreadPreview])
async def list_threads(limit: Optional[int] = 20, order: Order = Order.DESC):
return threads_manager.db_list_threads_preview(limit, order)
@router.get("/{thread_id}",tags=['openai'], response_model=ThreadObject)
async def retrieve_thread(thread_id: ObjectID):
return threads_manager.db_get_thread_by_id(thread_id)
@router.post("/{thread_id}",tags=['openai'], response_model=ThreadObject)
async def modify_thread(thread_id: ObjectID, thread: ThreadModify):
raise NotImplementedError
@router.delete("/{thread_id}",tags=['openai'], response_model=DeleteResponse)
async def delete_thread(thread_id: ObjectID):
threads_manager.db_delete_thread_by_id(thread_id=thread_id)
return DeleteResponse(id=thread_id, object='thread.deleted')
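# Example (not part of the original commit): a minimal end-to-end sketch of the assistants-style
# flow exposed by the threads, messages and runs routers above. The base URL and the /v1/threads
# mount prefix are assumptions about how the app wires these routers, and the request bodies
# assume the OpenAI-style fields the schemas are modeled after; the assistant id is a placeholder.
import requests

BASE = "http://localhost:9016/v1/threads"                                  # assumed mount point
thread = requests.post(f"{BASE}/", json={}).json()                         # create_thread
requests.post(f"{BASE}/{thread['id']}/messages",
              json={"role": "user", "content": "Hello"}).json()            # create_message
run = requests.post(f"{BASE}/{thread['id']}/runs",
                    json={"assistant_id": "asst-placeholder"}).json()      # create_run (non-streaming)
print(requests.get(f"{BASE}/{thread['id']}/messages").json())              # list_messages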
import json
from time import time
from uuid import uuid4
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject
from ktransformers.server.backend.base import BackendInterfaceBase
router = APIRouter()
@router.post('/chat/completions',tags=['openai'])
async def chat_completion(request:Request,create:ChatCompletionCreate):
id = str(uuid4())
interface: BackendInterfaceBase = get_interface()
# input_ids = interface.format_and_tokenize_input_ids(id,messages=create.get_tokenizer_messages())
input_message = [json.loads(m.model_dump_json()) for m in create.messages]
if create.stream:
async def inner():
chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
async for token in interface.inference(input_message,id):
chunk.set_token(token)
yield chunk
return chat_stream_response(request,inner())
else:
comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
async for token in interface.inference(input_message,id):
comp.append_token(token)
return comp
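# Example (not part of the original commit): consuming the streaming chat endpoint above with
# plain HTTP. The host, port and any prefix in front of /chat/completions are assumptions about
# how this router is mounted; the request body mirrors the OpenAI chat format that
# ChatCompletionCreate is modeled after.
import requests

resp = requests.post(
    "http://localhost:9016/v1/chat/completions",                           # assumed mount point
    json={"model": "anything", "stream": True,
          "messages": [{"role": "user", "content": "Hi"}]},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())    # each line is a server-sent event emitted by chat_stream_response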
from fastapi import APIRouter
from . import completions
router = APIRouter()
router.include_router(completions.router)
import json
from time import time
from uuid import uuid4
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
router = APIRouter()
@router.post("/completions",tags=['openai'])
async def create_completion(request:Request,create:CompletionCreate):
id = str(uuid4())
interface = get_interface()
print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
if create.stream:
async def inner():
async for token in interface.inference(create.prompt,id):
d = {'choices':[{'delta':{'content':token}}]}
yield f"data:{json.dumps(d)}\n\n"
d = {'choices':[{'delta':{'content':''},'finish_reason':'stop'}]}
yield f"data:{json.dumps(d)}\n\n"
return stream_response(request,inner())
else:
comp = CompletionObject(id=id,object='text_completion',created=int(time()))
async for token in interface.inference(create.prompt,id):
comp.append_token(token)
return comp
from fastapi import APIRouter
from .system import router as system_router
router = APIRouter()
router.include_router(system_router)
from fastapi import APIRouter
router = APIRouter()
@router.get('/system-info',tags=['web'])
def system_info():
raise NotImplementedError
from pydantic import BaseModel,Field
from typing import Optional
from ktransformers.server.config.config import Config
class ConfigArgs(BaseModel):
model_name : Optional[str] = Field(..., description="Model name")
model_dir: Optional[str] = Field(..., description="Path to model directory")
optimize_config_path: Optional[str] = Field('./KTransformers/optimize_config/DeepSeek-V2-Chat.json', description="Path of your optimize config json file")
gguf_path: Optional[str] = Field('/models/DeepSeek-Coder-V2-Instruct-GGUF/DeepSeek-Coder-V2-Instruct-Q4_K_M.gguf', description="Path of your gguf file")
class Config:
protected_namespaces = ()
paged : bool = Field(True,description='Whether to use paged attention kv cache')
# total_context: int = Field(16384, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
total_context: int = Field(2**18, description="Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the total to distribute dynamically over however many jobs are active at once")
max_batch_size: int = Field(20 if paged else 1, description="Max number of batches to run at once, assuming the sequences will fit within total_context")
max_chunk_size: int = Field(2048, description="Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new job is started, but at the expense of overall prompt ingestion speed")
max_new_tokens: int = Field(500, description="Max new tokens per completion. For this example applies to all jobs")
json_mode: bool = Field(False, description="Use LMFE to constrain the output to JSON format. See schema and details below")
healing: bool = Field(False, description="Demonstrate token healing")
ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe")
gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB')
length: Optional[int] = Field(None, description="Maximum sequence length")
rope_scale: Optional[float] = Field(None, description="RoPE scaling factor")
rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)")
no_flash_attn: bool = Field(False, description="Disable Flash Attention")
low_mem: bool = Field(
False,
description="Enable VRAM optimizations, potentially trading off speed",
)
experts_per_token: Optional[int] = Field(
None,
description="Override MoE model's default number of experts per token",
)
load_q4: bool = Field(False, description="Load weights in Q4 mode")
fast_safetensors: bool = Field(
False,
description="Optimized safetensors loading with direct I/O (experimental!)",
)
draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory")
no_draft_scale: bool = Field(
False,
description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it",
)
modes: bool = Field(False, description="List available modes and exit.")
mode: str = Field(
"llama",
description="Chat mode. Use llama for Llama 1/2 chat finetunes.",
)
username: str = Field("User", description="Username when using raw chat mode")
botname: str = Field("Chatbort", description="Bot name when using raw chat mode")
system_prompt: Optional[str] = Field(None, description="Use custom system prompt")
temperature: float = Field(0.95, description="Sampler temperature, default = 0.95 (1 to disable)")
smoothing_factor: float = Field(0.0, description="Smoothing Factor, default = 0.0 (0 to disable)")
dynamic_temperature: Optional[str] = Field(
None,
description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1",
)
top_k: int = Field(50, description="Sampler top-K, default = 50 (0 to disable)")
top_p: float = Field(0.8, description="Sampler top-P, default = 0.8 (0 to disable)")
top_a: float = Field(0.0, description="Sampler top-A, default = 0.0 (0 to disable)")
skew: float = Field(0.0, description="Skew sampling, default = 0.0 (0 to disable)")
typical: float = Field(
0.0,
description="Sampler typical threshold, default = 0.0 (0 to disable)",
)
repetition_penalty: float = Field(
1.01,
description="Sampler repetition penalty, default = 1.01 (1 to disable)",
)
frequency_penalty: float = Field(
0.0,
description="Sampler frequency penalty, default = 0.0 (0 to disable)",
)
presence_penalty: float = Field(
0.0,
description="Sampler presence penalty, default = 0.0 (0 to disable)",
)
max_response_tokens: int = Field(300, description="Max tokens per response, default = 300")
response_chunk: int = Field(250, description="Space to reserve in context for reply, default = 250")
no_code_formatting: bool = Field(False, description="Disable code formatting/syntax highlighting")
cache_8bit: bool = Field(False, description="Use 8-bit (FP8) cache")
cache_q4: bool = Field(True, description="Use Q4 cache")
ngram_decoding: bool = Field(False, description="Use n-gram speculative decoding")
print_timings: bool = Field(False, description="Output timings after each prompt")
amnesia: bool = Field(False, description="Forget context after every response")
# for transformers
batch_size :int = Field(1,description="Batch Size")
cache_lens:int = Field(4096, description="Cache lens for transformers static cache")
device:str = Field('cuda:2',description="device")
cfg = Config()
default_args = ConfigArgs(model_name=cfg.model_name,model_dir=cfg.model_path)
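# Example (not part of the original commit): constructing ConfigArgs with a few overrides
# instead of using default_args; values here are illustrative, not recommendations.
my_args = ConfigArgs(
    model_name=cfg.model_name,
    model_dir=cfg.model_path,
    device="cuda:0",            # run on the first GPU instead of the class default
    cache_lens=8192,            # larger static cache for longer conversations
    max_new_tokens=1000,
)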
from asyncio import Queue
from enum import Enum
import sys, os
from typing import AsyncIterator, Dict, List, Optional, Tuple
import torch
from ktransformers.server.config.log import logger
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager
from ktransformers.server.exceptions import request_error
from ktransformers.server.schemas.assistants.assistants import AssistantObject
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.utils.multi_timer import Profiler
from .args import ConfigArgs,default_args
class BackendInterfaceBase:
'''
Interface to inference frameworks, e.g. transformers, exllama.
Subclasses should implement __init__ and inference().
'''
args: ConfigArgs
profiler:Profiler = Profiler()
def __init__(self, args:ConfigArgs = default_args):
raise NotImplementedError
async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator[str]:
'''
inference() can be called directly, or by a ThreadContext.
local_messages:
When called by a ThreadContext, local_messages is produced by ThreadContext.get_local_messages().
Implementations should handle each possible local_messages format.
request_unique_id:
Unique id of the request; useful when using a cache.
return:
An async iterator of str, for streaming updates.
'''
raise NotImplementedError
def report_last_time_performance(self):
try:
tokenize_time = self.profiler.get_timer_sec('tokenize')
prefill_time = self.profiler.get_timer_sec('prefill')
decode_time = self.profiler.get_timer_sec('decode')
prefill_count = self.profiler.get_counter('prefill')
decode_count = self.profiler.get_counter('decode')
logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
except Exception:
logger.info('Performance statistics not recorded')
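# Example (not part of the original commit): a minimal echo backend, sketching the contract a
# BackendInterfaceBase subclass is expected to fulfil: provide __init__ and an async generator
# inference() that yields string deltas. Purely illustrative, not a real backend.
class EchoInterface(BackendInterfaceBase):
    def __init__(self, args: ConfigArgs = default_args):
        self.args = args

    async def inference(self, local_messages, request_unique_id: Optional[str] = None) -> AsyncIterator[str]:
        # local_messages may be a chat-message list (from ThreadContext.get_local_messages)
        # or a raw prompt string; handle both, mirroring TransformersInterface.inference.
        if isinstance(local_messages, str):
            text = local_messages
        else:
            text = str(local_messages[-1].get("content", ""))
        for word in text.split():
            yield word + " "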
class ThreadContext:
'''
A thread context holding assistant logics
'''
args: ConfigArgs
# Assistant Logic
assistant: Optional[AssistantObject] = None
related_threads : List[ThreadObject]
thread: ThreadObject
messages: List[MessageObject] = []
run: RunObject
interface: Optional[BackendInterfaceBase] = None
queue: Optional[Queue] = None
timer: Profiler = Profiler()
def __init__(self, run: RunObject,interface:BackendInterfaceBase, args: ConfigArgs = default_args) -> None:
self.args = args
self.thread_manager = ThreadsDatabaseManager()
self.message_manager = MessageDatabaseManager()
self.runs_manager = RunsDatabaseManager()
self.assistant_manager = AssistantDatabaseManager()
self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id)
self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id)
self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id,order=Order.ASC)
logger.debug(f"{len(self.messages)} messages loaded from database")
self.interface = interface
self.update_by_run(run,args)
def get_local_messages(self):
'''
Get local messages, used as the input to interface.inference().
This function is intended for message preprocessing, e.g. applying a chat template.
'''
raise NotImplementedError
def update_by_run(self,run:RunObject,args:ConfigArgs = default_args):
self.run = run
self.args = args
def put_user_message(self, message: MessageObject):
assert (
message.role.is_user() and message.thread_id == self.thread.id and message.status == MessageObject.Status.in_progress
)
self.messages.append(message)
def delete_user_message(self,message_id: ObjectID):
self.messages = [m for m in self.messages if m.id != message_id]
async def work(self)->AsyncIterator:
logger.debug('start working')
user_message = self.messages[-1]
if not user_message.role.is_user():
raise request_error('user must talk before LLM can talk')
user_message.status = MessageObject.Status.completed
user_message.sync_db()
local_messages = self.get_local_messages() # must get this before we insert the reply_message
response_str_count = 0
reply_message = self.message_manager.create_message_object(
self.thread.id,
self.run.id,
MessageCreate(role=Role.assistant, content=""),
)
reply_message.assistant_id = self.assistant.id
self.messages.append(reply_message)
yield reply_message.stream_response_with_event(MessageObject.Status.created)
yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
yield self.run.stream_response_with_event(RunObject.Status.in_progress)
async for token in self.interface.inference(local_messages,self.thread.id):
if self.run.status == RunObject.Status.cancelling:
logger.warn(f'Run {self.run.id} cancelling')
break
yield reply_message.append_message_delta(token)
response_str_count+=1
if self.run.status == RunObject.Status.cancelling:
yield self.run.stream_response_with_event(RunObject.Status.cancelled)
yield reply_message.stream_response_with_event(MessageObject.Status.incomplete)
elif self.run.status == RunObject.Status.in_progress:
yield self.run.stream_response_with_event(RunObject.Status.completed)
yield reply_message.stream_response_with_event(MessageObject.Status.completed)
else:
raise NotImplementedError(f'{self.run.status} should not appear here')
reply_message.sync_db()
self.run.sync_db()
from asyncio import Lock
from typing import Dict, Optional
from ktransformers.server.backend.base import ThreadContext, BackendInterfaceBase
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.server.backend.interfaces.transformers import TransformersThreadContext
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
class ThreadContextManager:
lock: Lock
threads_context: Dict[ObjectID, ThreadContext]
interface: BackendInterfaceBase
def __init__(self,interface) -> None:
logger.debug(f"Creating Context Manager")
self.lock = Lock()
self.threads_context = {}
self.interface = interface
pass
async def get_context_by_run_object(self, run: RunObject) -> ThreadContext:
async with self.lock:
logger.debug(f"keys {self.threads_context.keys()}")
if run.thread_id not in self.threads_context:
logger.debug(f"new inference context {run.thread_id}")
if isinstance(self.interface, ExllamaInterface):
new_context = ExllamaThreadContext(run, self.interface)
elif isinstance(self.interface, KTransformersInterface):
new_context = KTransformersThreadContext(run, self.interface)
elif isinstance(self.interface, TransformersInterface):
new_context = TransformersThreadContext(run, self.interface)
else:
raise NotImplementedError
self.threads_context[run.thread_id] = new_context
# self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
re = self.threads_context[run.thread_id]
re.update_by_run(run)
return re
async def get_context_by_thread_id(self, thread_id: ObjectID) -> Optional[ThreadContext]:
async with self.lock:
if thread_id in self.threads_context:
logger.debug(f'found context for thread {thread_id}')
return self.threads_context[thread_id]
else:
logger.debug(f'no context for thread {thread_id}')
return None
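# Example (not part of the original commit): how the manager above is typically wired. A single
# backend interface is shared, and one ThreadContext is built lazily per thread id;
# get_context_by_run_object is what the runs router calls. Instantiation here is illustrative.
# manager = ThreadContextManager(TransformersInterface(default_args))
# ctx = await manager.get_context_by_run_object(run)          # inside an async function
# same = await manager.get_context_by_thread_id(run.thread_id)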
import sys, os
from typing import AsyncIterator, Dict, Optional, Tuple
import torch
from ..args import ConfigArgs, default_args
from ..base import BackendInterfaceBase, ThreadContext
from ktransformers.server.schemas.assistants.runs import RunObject
from ..args import *
class ExllamaThreadContext(ThreadContext):
def __init__(self, run: RunObject, args: ConfigArgs = default_args) -> None:
super().__init__(run,args)
def get_interface(self):
return
def get_local_messages(self):
raise NotImplementedError
class ExllamaInterface(BackendInterfaceBase):
def __init__(self, args: ConfigArgs = ...):
raise NotImplementedError
def tokenize_prompt(self, prompt: str) -> torch.Tensor:
raise NotImplementedError
async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator:
raise NotImplementedError
import torch
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import TransformersInterface,ConfigArgs, TransformersThreadContext,default_args,TextStreamer
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
class KTransformersThreadContext(TransformersThreadContext):
pass
class KTransformersInterface(TransformersInterface):
def __init__(self,args:ConfigArgs= default_args):
self.args = args
torch.set_default_dtype(torch.bfloat16)
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir,device = args.device)
config=AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
if config.architectures[0] == "Qwen2MoeForCausalLM":
config._attn_implementation="flash_attention_2"
with torch.device("meta"):
self.model=custom_models[config.architectures[0]](config)
optimize_rule_path = default_optimize_rules[config.architectures[0]]
# print(optimize_config)
gguf_path = args.gguf_path
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.streamer = TextStreamer(self.tokenizer)
def decode_one_tokens(self):
if not hasattr(self, "cuda_graph_runner"):
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, return_dict=False, use_cache=True)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0,-1,:]
return self.logits_to_token(logits)
if self.use_static_cache:
mask = torch.ones((1,self.seq_length)).to(self.args.device)
logits = self.model(
self.current_ids,
cache_position=self.active_cache_position,
past_key_values=self.cache,
attention_mask=mask,
return_dict=False,
use_cache=True
)[0]
else:
logits = self.model(
self.current_ids,
return_dict=False
)[0]
logits = logits[0,-1,:]
return self.logits_to_token(logits)
from typing import Any, List, Optional, Set
from transformers import LlamaTokenizer,AutoTokenizer, AutoConfig, LlamaForCausalLM,GenerationConfig, StaticCache, AutoModelForCausalLM,BitsAndBytesConfig
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.utils.multi_timer import Profiler
import torch
import sys, os
from ..base import ThreadContext,BackendInterfaceBase
from ktransformers.server.config.log import logger
from ..args import ConfigArgs,default_args
# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
class TextStreamer:
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.decode_kwargs = decode_kwargs
# variables used in the streaming process
self.token_cache = []
self.print_len = 0
self.next_tokens_are_prompt = True
def reset(self):
self.token_cache = []
self.print_len = 0
def put(self, value)->Optional[str]:
"""
Receives tokens, decodes them, and returns printable text as soon as the tokens form entire words.
"""
if not isinstance(value,int):
raise ValueError("TextStreamer only supports batch size 1, and int type input")
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return None
# Add the new token to the cache and decodes the entire thing.
self.token_cache.append(value)
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)
# After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.reset()
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len :]
self.print_len += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = text[self.print_len : text.rfind(" ") + 1]
self.print_len += len(printable_text)
return printable_text
def end(self)->Optional[str]:
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
printable_text = text[self.print_len :]
self.reset()
else:
printable_text = ""
self.next_tokens_are_prompt = True
return printable_text
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
like all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
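# Example (not part of the original commit): feeding token ids one at a time through the
# TextStreamer above. put() buffers ids and only returns text once a whole word, a newline,
# or a CJK character can be emitted; end() flushes whatever remains. The model path is a placeholder.
# tokenizer = AutoTokenizer.from_pretrained("some/model")
# streamer = TextStreamer(tokenizer)
# for token_id in tokenizer.encode("Hello world!"):
#     piece = streamer.put(token_id)
#     if piece:
#         print(piece, end="")
# print(streamer.end())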
class TransformersThreadContext(ThreadContext):
def get_local_messages(self):
local_messages = []
for m in self.messages:
local_messages.append(
{'role':m.role.value,
'content':m.get_text_content()}
)
return local_messages
class TransformersInterface(BackendInterfaceBase):
use_static_cache : bool = True
model: Any
tokenizer: AutoTokenizer
cache: StaticCache
generated_ids:torch.Tensor
seq_length:int
streamer: TextStreamer
# thread_related
last_request_id: Optional[str] = None
ever_generated_ids: Set[int] = set()
def __init__(self, args:ConfigArgs = default_args):
self.args = args
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device,use_safetensors=True)
logger.info(f'{args.model_name} loaded from {args.model_dir} to {args.device}')
self.cache = StaticCache(config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype)
logger.info(f'StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}')
self.streamer = TextStreamer(self.tokenizer)
@property
def current_ids(self):
return self.generated_ids[:,self.seq_length-1].unsqueeze(1)
@property
def active_cache_position(self):
return torch.tensor([self.seq_length-1], device=self.args.device)
def tokenize_prompt(self,prompt:str):
input_ids = self.tokenizer.encode(prompt,return_tensors='pt').to(self.args.device)
return input_ids
def format_and_tokenize_input_ids(self,thread_id:ObjectID,messages:List):
for m in messages:
if m['role']=='system':
logger.warn(f'change {m["role"]} to user')
m['role'] = 'user'
new_messages = [messages[0]]
for m in messages[1:]:
if m['role'] == 'user' and new_messages[-1]['role']=='user':
logger.warn('merge two adjacent user messages')
new_messages[-1]['content']+=m['content']
else:
new_messages.append(m)
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
if (self.last_request_id is not None) and self.last_request_id == thread_id:
x = self.generated_ids[:,:self.seq_length]
y = input_ids[:,:self.seq_length]
# We can only hope that the input_ids are the same
unequal_mask = torch.ne(x,y)
unequal_positions = torch.nonzero(unequal_mask)
num_unequal_elements = unequal_mask.sum().item()
logger.warn(f'num_unequal_elements: {num_unequal_elements}')
input_ids = input_ids[:,self.seq_length:]
logger.debug(f'get input ids of shape {input_ids.shape}')
return input_ids
def append_new_tokens(self,new_tokens:int)->Optional[str]:
self.generated_ids[0,self.seq_length] = new_tokens
self.seq_length+=1
return self.streamer.put(new_tokens)
def logits_to_token(self,logits:torch.Tensor):
logits = logits/self.args.temperature
for token_idx in self.ever_generated_ids:
if logits[token_idx] < 0:
logits[token_idx] *= self.args.repetition_penalty
else:
logits[token_idx] /= self.args.repetition_penalty
probs = torch.nn.functional.softmax(logits, dim=-1)
sample = True
if sample:
last = torch.multinomial(probs, num_samples=1)
else:
_, last = torch.topk(probs, k=1, dim=-1)
last = last.item()
self.ever_generated_ids.add(last)
return last
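# Example (not part of the original commit): a tiny numeric illustration of the repetition-penalty
# rule in logits_to_token above; tokens already generated have their logits pushed away from
# selection (positive logits divided by the penalty, negative logits multiplied by it).
# The values here are made up.
# import torch
# logits = torch.tensor([2.0, -1.0, 0.5])
# penalty, seen = 1.5, {0, 1}
# for idx in seen:
#     logits[idx] = logits[idx] / penalty if logits[idx] > 0 else logits[idx] * penalty
# print(logits)   # tensor([ 1.3333, -1.5000,  0.5000])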
def decode_one_tokens(self):
if self.use_static_cache:
mask = torch.ones((1,self.seq_length)).to(self.args.device)
logits = self.model(
self.current_ids,
cache_position=self.active_cache_position,
past_key_values=self.cache,
attention_mask=mask,
return_dict=False,
use_cache=True
)[0]
else:
logits = self.model(
self.current_ids,
return_dict=False
)[0]
logits = logits[0,-1,:]
return self.logits_to_token(logits)
@torch.no_grad
def prefill(self,input_ids:torch.Tensor,is_new:bool):
input_ids_length = input_ids.shape[-1]
self.profiler.set_counter('prefill',input_ids_length)
logger.debug(f'input_ids: {input_ids.shape}')
if is_new:
self.cache.reset()
self.ever_generated_ids.clear()
former_seq_length = 0
self.seq_length = input_ids_length
self.generated_ids = torch.zeros(
self.args.batch_size, self.seq_length + self.args.max_new_tokens + 1, dtype=torch.int, device=self.args.device
)
else:
logger.debug(f'generate_ids: {self.generated_ids.shape}')
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens+1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length>0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids,new_generate_ids],dim=-1)
logger.debug(f'cache position: {former_seq_length} to {self.seq_length}')
cache_position = torch.arange(former_seq_length,self.seq_length, device=self.args.device)
self.generated_ids[:,cache_position] = input_ids.to(self.args.device).to(torch.int)
mask = torch.ones((1,self.seq_length)).to(self.args.device)
device = input_ids.device
if not(type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache,return_dict=False, use_cache=True,attention_mask=mask
)[0]
else:
logits = self.model(
inputs_embeds=inputs_embeds,return_dict=False
)[0]
next_token = self.logits_to_token(logits[0,-1,:])
yield self.append_new_tokens(next_token)
@torch.no_grad
def generate(self):
self.profiler.set_counter('decode',0)
for _ in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = self.decode_one_tokens()
self.profiler.inc('decode')
if next_token == self.tokenizer.eos_token_id:
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token)
yield self.streamer.end()
def check_is_new(self,thread_id:str):
if not self.use_static_cache:
return True
if self.last_request_id is None:
self.last_request_id = thread_id
return True
else:
if self.last_request_id==thread_id:
return False
else:
self.last_request_id = thread_id
return True
async def inference(self,local_messages,thread_id:str):
self.profiler.create_and_start_timer('tokenize')
if isinstance(local_messages,List):
input_ids = self.format_and_tokenize_input_ids(thread_id,local_messages)
elif isinstance(local_messages,str):
input_ids = self.tokenize_prompt(local_messages)
else:
raise ValueError('local_messages should be List or str')
self.profiler.pause_timer('tokenize')
self.profiler.create_and_start_timer('prefill')
for t in self.prefill(input_ids,self.check_is_new(thread_id)):
if t is not None:
print(t,end='')
yield t
self.profiler.pause_timer('prefill')
self.profiler.create_and_start_timer('decode')
for t in self.generate():
if t is not None:
print(t,end='')
yield t
print('')
self.profiler.pause_timer('decode')
self.report_last_time_performance()
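# Example (not part of the original commit): driving TransformersInterface directly, outside the
# FastAPI routers. default_args comes from the server config; the asyncio wrapper is only needed
# because inference() is an async generator. Model paths are taken from config.yaml.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        interface = TransformersInterface(default_args)
        messages = [{"role": "user", "content": "Hello"}]
        async for piece in interface.inference(messages, "demo-thread"):
            if piece:
                print(piece, end="")

    asyncio.run(_demo())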
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : unicornchan
Date : 2024-06-11 16:35:42
Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2024-07-27 01:55:42
'''
import os
import yaml
from ktransformers.server.config.singleton import Singleton
class Config(metaclass=Singleton):
"""Singleton pattern Config class, used to get all configurations.
"""
CONFIG_FILE_NAME = "config.yaml"
@staticmethod
def load() -> dict:
"""load config file
Returns:
dict: all configs
"""
base_path: str = os.path.dirname(
os.path.dirname(os.path.dirname(__file__)))
config_yaml: str = os.path.join(
base_path, "configs", Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}")
exit(-1)
with open(config_yaml, 'r', encoding="utf-8") as fp:
config = yaml.safe_load(fp)
return config
@staticmethod
def to_path(path: str) -> str:
"""
process file path
"""
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
real_path = path if os.path.isabs(
path) else os.path.join(base_path, path)
return real_path
def __init__(self):
cfg = Config.load()
self.base_path = os.path.dirname(
os.path.dirname(os.path.dirname(__file__)))
# log configs
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
self.log_file = cfg["log"]["file"]
self.log_level = cfg["log"]["level"]
self.backup_count = cfg["log"]["backup_count"]
# server configs
self.server: dict = cfg.get("server",{})
self.server_ip = self.server.get("ip", "0.0.0.0")
self.server_port = self.server.get("port", 9016)
# db configs
self.db_configs: dict = cfg.get("db", {})
self.db_type = self.db_configs.get("type", "")
self.db_host = os.path.join(self.base_path, self.db_configs.get("host", ""))
self.db_port = self.db_configs.get("port", "")
self.db_name = self.db_configs.get("database", "")
self.db_pool_size = self.db_configs.get("pool_size")
self.db_database = self.db_configs.get("database", "")
# user config
self.user_config: dict = cfg.get("user", {})
self.user_secret_key = self.user_config.get("secret_key", "")
self.user_algorithm = self.user_config.get("algorithm", "")
# model config
self.model:dict = cfg.get("model", {})
self.backend_type: str = self.model.get("type", "transformers")
self.model_path: str = self.model.get("path", "")
self.model_name: str = self.model.get("name", "")
self.model_device: str = self.model.get("device", "cuda:0")
self.gguf_path: str = self.model.get("gguf_path", "")
# web config
self.web: dict = cfg.get("web", {})
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
self.mount_web: bool = self.web.get("mount", False)
self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10)
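# Example (not part of the original commit): the shape of configs/config.yaml implied by the
# reads in __init__ above. Keys mirror the cfg.get(...) calls; every value is a placeholder.
# example_config = {
#     "log": {"dir": "logs", "file": "server.log", "level": "info", "backup_count": 7},
#     "server": {"ip": "0.0.0.0", "port": 9016},
#     "db": {"type": "sqlite", "host": "server/db", "port": "", "database": "server.db", "pool_size": 10},
#     "user": {"secret_key": "", "algorithm": ""},
#     "model": {"type": "transformers", "path": "/models/some-model", "name": "some-model",
#               "device": "cuda:0", "gguf_path": "/models/some-model-gguf"},
#     "web": {"open_cross_domain": True, "mount": False},
#     "ext": {"cpu_infer": 10},
# }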
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : unicornchan
Date : 2024-06-12 02:48:39
Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2024-07-27 01:55:50
'''
import codecs
import logging
import os
import re
import locale
from pathlib import Path
from logging.handlers import BaseRotatingHandler
import time
import colorlog
from ktransformers.server.config.config import Config
class DailyRotatingFileHandler(BaseRotatingHandler):
"""
Similar to 'logging.handlers.TimedRotatingFileHandler', with additional features:
- supports multiple processes
- supports daily rotation
"""
def __init__(self, filename, backupCount=0, encoding=None, delay=False, utc=False, **kwargs): # pylint: disable=unused-argument
self.backup_count = backupCount
self.utc = utc
self.suffix = "%Y-%m-%d"
self.base_log_path = Path(filename)
if not os.path.exists(self.base_log_path.parent):
os.makedirs(self.base_log_path.parent)
self.base_filename = self.base_log_path.name
self.current_filename = self._compute_fn()
self.current_log_path = self.base_log_path.with_name(
self.current_filename)
BaseRotatingHandler.__init__(self, filename, 'a', encoding, delay)
# pylint: disable=unused-argument, invalid-name
def shouldRollover(self, record):
"""
Determine whether to rotate the log. If the log filename corresponding to the current
time is not consistent with the currently opened log filename, then it is necessary
to rotate the log
Args:
record: record is not used, as we are just comparing times, but it is needed so
the method signatures are the same
"""
if self.current_filename != self._compute_fn():
return True
return False
def doRollover(self):
"""
roll over
"""
# close last log file
if self.stream:
self.stream.close()
self.stream = None # type: ignore
# gen new log file name
self.current_filename = self._compute_fn()
self.current_log_path = self.base_log_path.with_name(
self.current_filename)
if not self.delay:
self.stream = self._open() # type: ignore
self.delete_expired_files()
def _compute_fn(self):
"""
gen log file name
"""
return self.base_filename + "." + time.strftime(self.suffix, time.localtime())
def _open(self):
"""
open a new log file, create soft link
"""
if self.encoding is None:
stream = open(str(self.current_log_path), self.mode, encoding=locale.getpreferredencoding())
else:
stream = codecs.open(str(self.current_log_path), self.mode, self.encoding)
if self.base_log_path.exists():
try:
if not self.base_log_path.is_symlink() or os.readlink(self.base_log_path) != self.current_filename:
os.remove(self.base_log_path)
except OSError:
pass
try:
os.symlink(self.current_filename, str(self.base_log_path))
except OSError:
pass
return stream
def delete_expired_files(self):
"""
delete expired files every day
"""
if self.backup_count <= 0:
return
file_names = os.listdir(str(self.base_log_path.parent))
result = []
prefix = self.base_filename + "."
plen = len(prefix)
for file_name in file_names:
if file_name[:plen] == prefix:
suffix = file_name[plen:]
if re.match(r"^\d{4}-\d{2}-\d{2}(\.\w+)?$", suffix):
result.append(file_name)
if len(result) < self.backup_count:
result = []
else:
result.sort()
result = result[:len(result) - self.backup_count]
for file_name in result:
os.remove(str(self.base_log_path.with_name(file_name)))
class Logger(object):
"""
logger class
"""
level_relations = {
'debug': logging.DEBUG,
'info': logging.INFO,
'warn': logging.WARNING,
'error': logging.ERROR,
'crit': logging.CRITICAL
}
def __init__(self, level: str = 'info'):
fmt = '%(asctime)s %(levelname)s %(pathname)s[%(lineno)d] %(funcName)s: %(message)s'
cfg: Config = Config()
filename: str = os.path.join(cfg.log_dir, cfg.log_file)
backup_count: int = cfg.backup_count
th = DailyRotatingFileHandler(filename=filename, when='MIDNIGHT', backupCount=backup_count, encoding="utf-8")
th.setFormatter(logging.Formatter(fmt))
color_fmt = (
'%(log_color)s%(asctime)s %(levelname)s %(pathname)s[%(lineno)d]: %(message)s'
)
color_formatter = colorlog.ColoredFormatter(
color_fmt,
log_colors={
'DEBUG': 'cyan',
'INFO': 'green',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'bold_red'
}
)
sh = logging.StreamHandler()
sh.setFormatter(color_formatter)
self.logger = logging.getLogger(filename)
self.logger.setLevel(self.level_relations.get(level)) # type: ignore
self.logger.addHandler(th)
self.logger.addHandler(sh)
logger = Logger(level=Config().log_level).logger
#!/usr/bin/env python
# coding=utf-8
'''
Description : Implement singleton
Author : unicornchan
Date : 2024-06-11 17:08:36
Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2024-07-27 01:55:56
'''
import abc
class Singleton(abc.ABCMeta, type):
"""_summary_
Args:
abc.ABCMeta: Provide a mechanism for defining abstract methods and properties,
enforcing subclasses to implement these methods and properties.
type: Inherit from 'type' to make 'Singleton' a metaclass,
enabling the implementation of the Singleton
"""
_instances = {}
def __call__(cls, *args, **kwds):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwds)
return cls._instances[cls]
class AbstractSingleton(abc.ABC, metaclass=Singleton):
"""Provided an abstract Singleton base class, any class inheriting from
this base class will automatically become a Singleton class.
Args:
abc.ABC: Abstract base class, it cannot be instantiated, only inherited.
"""