Unverified Commit 96d75d53 authored by wang jiahao, committed by GitHub

Merge pull request #835 from BITcyman/fix-openai_chat_completion

[fix] support openai chat completion api
parents 63b1c852 299c4dca
@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
 router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@@ -58,7 +60,11 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = OllamaGenerationStreamResponse(
                     model=config.model_name,
                     created_at=str(datetime.now()),
@@ -123,7 +129,11 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
             eval_count = 0  # count the number of generated tokens
             tokens = []
-            async for token in interface.inference(prompt, id):
+            async for res in interface.inference(prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = OllamaChatCompletionStreamResponse(
                     model=config.model_name,
                     created_at=str(datetime.now()),
...
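With this change, `interface.inference()` no longer yields bare token strings: it yields `(token, finish_reason)` tuples and, once generation is done, a single `RawUsage` object, so every consumer has to branch on the result type the way the two hunks above do. A minimal sketch of that consumer contract, assuming the `RawUsage` model added by this PR is importable; `fake_inference` is a hypothetical stand-in for `interface.inference(...)`, only the branching mirrors the diff:

```python
# Sketch of the new consumer contract; fake_inference is a hypothetical
# stand-in for interface.inference(...), only the branching mirrors the diff.
import asyncio
from ktransformers.server.schemas.endpoints.chat import RawUsage

async def fake_inference():
    yield "Hello", None
    yield " world", "stop"
    yield RawUsage(tokenize_time=0.01, prefill_time=0.12, decode_time=0.20,
                   prefill_count=5, decode_count=2)

async def consume():
    async for res in fake_inference():
        if isinstance(res, RawUsage):
            # final item: per-request usage statistics
            print(f"\nprompt={res.prefill_count} completion={res.decode_count}")
        else:
            token, finish_reason = res
            print(token, end="", flush=True)

asyncio.run(consume())
```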
@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage
 
 router = APIRouter()
 
 @router.get('/models', tags=['openai'])
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
 
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-                chunk.set_token(token)
-                yield chunk
-        return chat_stream_response(request,inner())
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will yield the usage of this run
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+        return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
-        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content=content,
+                role="assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
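Because the non-streaming branch now returns an `openai.types.chat.ChatCompletion` and the streaming branch emits standard chunk objects (ending with a usage-only chunk that has empty `choices`), the endpoint can be exercised with the official `openai` client. A sketch, assuming the server listens on `http://localhost:10002/v1` and that the key passed here matches the server's configured `api_key`; base URL, port, and model name are assumptions to adapt to your deployment:

```python
# Hypothetical client-side check of the updated endpoint; base_url, api_key and
# model are assumptions -- match them to your ktransformers server configuration.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:10002/v1", api_key="your-api-key")

stream = client.chat.completions.create(
    model="DeepSeek-V3",                      # assumed model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices:                          # ordinary delta chunks
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:                          # final chunk: empty choices, usage attached
        print("\nusage:", chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
```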
@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter()
@@ -17,10 +18,13 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +32,10 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
             comp.append_token(token)
         return comp
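A quick way to smoke-test the updated streaming contract from the client side is to read the SSE lines this endpoint emits; each one is exactly the `data:{...}` string yielded in `inner()` above. The URL, port, and payload fields below are assumptions, so adapt them to your ktransformers server configuration:

```python
# Hypothetical smoke test for the streaming completions endpoint; the URL and
# payload fields are assumptions -- adapt them to your server settings.
import json
import requests

resp = requests.post(
    "http://localhost:10002/v1/completions",
    json={"model": "DeepSeek-V3", "prompt": "Hello", "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if not line or not line.startswith(b"data:"):
        continue                                   # skip keep-alives and blanks
    payload = json.loads(line[len(b"data:"):])
    delta = payload["choices"][0]["delta"]
    print(delta.get("content", ""), end="")
```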
@@ -142,7 +142,7 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)
-        async for token in self.interface.inference(local_messages,self.thread.id):
+        async for token, finish_reason in self.interface.inference(local_messages,self.thread.id):
             if self.run.status == RunObject.Status.cancelling:
                 logger.warn(f'Run {self.run.id} cancelling')
                 break
...
@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 warm_uped = False
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+        # finally, report the raw usage of this inference run
+        yield RawUsage(
+            tokenize_time = self.profiler.get_timer_sec('tokenize'),
+            prefill_time = self.profiler.get_timer_sec('prefill'),
+            decode_time = self.profiler.get_timer_sec('decode'),
+            prefill_count = self.profiler.get_counter('prefill'),
+            decode_count = self.profiler.get_counter('decode'),
+        )
\ No newline at end of file
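The `RawUsage` record yielded above carries the profiler's timers and counters for the whole request, so consumers can derive throughput without touching the profiler themselves. A small helper sketch (not part of this PR) that turns it into human-readable numbers:

```python
# Hypothetical helper (not part of this PR): derive throughput figures from the
# RawUsage object that KTransformersInterface.inference() now yields last.
from ktransformers.server.schemas.endpoints.chat import RawUsage

def summarize(usage: RawUsage) -> dict:
    return {
        "prefill_tok_per_s": usage.prefill_count / usage.prefill_time if usage.prefill_time else 0.0,
        "decode_tok_per_s": usage.decode_count / usage.decode_time if usage.decode_time else 0.0,
        "total_tokens": usage.prefill_count + usage.decode_count,
        "tokenize_time_s": usage.tokenize_time,
    }
```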
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
@@ -348,10 +348,17 @@
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+        else:  # for-else: reached when generation exhausts max_new_tokens
+            yield self.streamer.end(), None
+            yield "", "length"
 
     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
@@ -391,20 +398,20 @@
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="",flush=True)
-            yield think
+            yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
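After these hunks, `TransformersInterface.inference()` yields `(text, finish_reason)` pairs where `finish_reason` stays `None` for ordinary tokens and becomes `"stop"` (EOS or `<|im_end|>`) or `"length"` (max_new_tokens exhausted) on the final yields, matching the two members of the removed `FinishReason` enum. A condensed sketch of how a caller can rely on that contract; the generator below is a mock, only the finish_reason semantics mirror the diff:

```python
# Sketch of the (text, finish_reason) contract established above; the generator
# is a mock, only the finish_reason semantics mirror the diff.
def mock_inference():
    yield "The", None
    yield " answer", None
    yield "", "length"      # or "stop" when an EOS token was produced

def collect(gen):
    content, finish_reason = "", None
    for text, reason in gen:
        content += text
        if reason is not None:
            finish_reason = reason
    return content, finish_reason

print(collect(mock_inference()))   # ('The answer', 'length')
```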
@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
...
 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum
 
 from pydantic import BaseModel
 from ktransformers.server.schemas.base import Object
+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice
 
 class Role(Enum):
     system = 'system'
     user = 'user'
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
-class FinishReason(Enum):
-    stop = 'stop'
-    length = 'length'
-
-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class Usage(BaseModel):
-    completion_tokens:int
-    prompt_tokens:int
-    total_tokens:int
-
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
+    created: int
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None
 
-class ChatCompletionBase(Object):
-    created:int
-    model:str = 'not implmented'
-    system_fingerprint:str = 'not implmented'
-    usage: Optional[Usage] = None
-
-class ChatCompletionObject(ChatCompletionBase):
-    choices:List[Choice] = []
-
-    def append_token(self,token:str):
-        if len(self.choices) == 0:
-            self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
-        self.choices[0].message.content += token
-
-class ChatCompletionChunk(ChatCompletionBase):
-    choices:List[DeltaChoice] = []
-
-    def set_token(self,token:str):
-        self.choices = [
-            DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
-        ]
-
     def to_stream_reply(self):
         return f"data: {self.model_dump_json()}\n\n"
+
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int