chat.py 1.78 KB
Newer Older
chenxl's avatar
chenxl committed
1
2
3
4
5
6
7
import json
from time import time
from uuid import uuid4
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
ceerrep's avatar
ceerrep committed
8
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
chenxl's avatar
chenxl committed
9
from ktransformers.server.backend.base import BackendInterfaceBase
10
from ktransformers.server.config.config import Config
chenxl's avatar
chenxl committed
11
12
13

router = APIRouter()

14
15
@router.get('/models', tags=['openai'])
async def list_models():
    """List the models exposed by this server.

    Returns a single-element list whose entry carries the configured
    model name under both the ``id`` and ``name`` keys.
    """
    # Config() is instantiated separately for each key.
    model_id = Config().model_name
    display_name = Config().model_name
    return [dict(id=model_id, name=display_name)]
17
18
19


@router.post('/chat/completions', tags=['openai'])
async def chat_completion(request: Request, create: ChatCompletionCreate):
    """Handle a chat-completion request, streamed or non-streamed.

    Args:
        request: Incoming HTTP request; passed on to ``chat_stream_response``
            when streaming.
        create: Parsed request body supplying the messages, the ``stream``
            flag and the sampling settings (``temperature``, ``top_p``,
            ``repetition_penalty``).

    Returns:
        The result of ``chat_stream_response`` when ``create.stream`` is
        truthy, otherwise a ``ChatCompletionObject`` with every generated
        token appended.
    """
    completion_id = str(uuid4())
    backend: BackendInterfaceBase = get_interface()

    # Round-trip each message through its JSON dump so the backend receives
    # plain JSON-compatible data rather than schema objects.
    plain_messages = [json.loads(msg.model_dump_json()) for msg in create.messages]

    def generate_tokens():
        # Single place that starts inference; each branch calls it lazily,
        # right before it begins consuming tokens.
        return backend.inference(
            plain_messages,
            completion_id,
            create.temperature,
            create.top_p,
            create.repetition_penalty,
        )

    if not create.stream:
        completion = ChatCompletionObject(
            id=completion_id, object='chat.completion', created=int(time()),
        )
        # NOTE(review): usage figures are hard-coded constants, not measured
        # token counts.
        completion.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
        async for piece in generate_tokens():
            completion.append_token(piece)
        return completion

    async def stream_chunks():
        # One chunk object is created up front, then updated via set_token
        # and re-yielded for every token.
        chunk = ChatCompletionChunk(
            id=completion_id, object='chat.completion.chunk', created=int(time()),
        )
        async for piece in generate_tokens():
            chunk.set_token(piece)
            yield chunk

    return chat_stream_response(request, stream_chunks())