Unverified Commit 6e4da83d, authored by wang jiahao and committed by GitHub

Merge pull request #978 from cyhasuka/main

Feat: Support Non-streaming chat in Ollama backend
parents b0551323 877aec85
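For context, here is a minimal sketch of how a client can exercise the new non-streaming path. It assumes the router is mounted at Ollama's standard `/api` prefix and that the server listens on `localhost:11434`; host, port, and model name are hypothetical and should be adjusted for your deployment.

```python
import requests  # assumes the `requests` package is installed

# Non-streaming chat: with "stream": False the server now returns the
# whole reply as one JSON object instead of newline-delimited chunks.
resp = requests.post(
    "http://localhost:11434/api/chat",
    json={
        "model": "DeepSeek-V2-Lite",  # hypothetical model name
        "messages": [{"role": "user", "content": "Why is the sky blue?"}],
        "stream": False,
    },
)
body = resp.json()
print(body["message"]["content"])       # complete assistant reply
print(body["eval_count"], "tokens generated")
```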
```diff
@@ -49,7 +49,10 @@ class OllamaGenerationStreamResponse(BaseModel):
     done: bool = Field(...)
 
 class OllamaGenerationResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    response: str
+    done: bool
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
```
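With its fields filled in, `OllamaGenerationResponse` serializes to the shape Ollama documents for a non-streaming `/api/generate` reply. A standalone sketch with illustrative values only (this is not output from a real server):

```python
from datetime import datetime
from pydantic import BaseModel

class OllamaGenerationResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool

example = OllamaGenerationResponse(
    model="DeepSeek-V2-Lite",  # hypothetical model name
    created_at=str(datetime.now()),
    response="The sky is blue because...",
    done=True,
)
print(example.model_dump_json())
```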
```diff
@@ -81,7 +84,20 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError
+        complete_response = ""
+        async for res in interface.inference(input.prompt, id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                complete_response += token
+        response = OllamaGenerationResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            response=complete_response,
+            done=True
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
 class OllamaChatCompletionMessage(BaseModel):
```
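The non-streaming branch above drains the same async generator the streaming path uses and concatenates the tokens before building a single response. A self-contained sketch of that accumulation pattern, with a stub standing in for `interface.inference` (the stub omits the `RawUsage` items the real generator can yield):

```python
import asyncio

async def fake_inference(prompt: str):
    # Stub for interface.inference: yields (token, finish_reason) tuples,
    # the same shape the handler above destructures.
    for token in ["The ", "sky ", "is ", "blue."]:
        yield token, None

async def generate_non_streaming(prompt: str) -> str:
    complete_response = ""
    async for token, finish_reason in fake_inference(prompt):
        complete_response += token  # accumulate instead of yielding chunks
    return complete_response

print(asyncio.run(generate_non_streaming("Why is the sky blue?")))
```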
```diff
@@ -106,10 +122,17 @@ class OllamaChatCompletionStreamResponse(BaseModel):
     eval_count: Optional[int] = Field(None, description="Number of tokens generated")
     eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 class OllamaChatCompletionResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
```
```diff
@@ -164,7 +187,38 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError("Non-streaming chat is not implemented.")
+        start_time = time()
+        complete_response = ""
+        eval_count = 0
+        async for res in interface.inference(prompt, id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                complete_response += token
+                eval_count += 1
+        end_time = time()
+        total_duration = int((end_time - start_time) * 1_000_000_000)
+        prompt_eval_count = len(prompt.split())
+        eval_duration = total_duration
+        prompt_eval_duration = 0
+        load_duration = 0
+        response = OllamaChatCompletionResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            message={"role": "assistant", "content": complete_response},
+            done=True,
+            total_duration=total_duration,
+            load_duration=load_duration,
+            prompt_eval_count=prompt_eval_count,
+            prompt_eval_duration=prompt_eval_duration,
+            eval_count=eval_count,
+            eval_duration=eval_duration
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
```
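Two accounting details in the chat branch are worth noting: durations are reported in nanoseconds, matching Ollama's API, and `prompt_eval_count` is approximated by whitespace-splitting the prompt rather than running the actual tokenizer. A small sketch of both conventions:

```python
from time import time

start_time = time()
# ... inference would run here ...
total_duration = int((time() - start_time) * 1_000_000_000)  # seconds -> ns

prompt = "Why is the sky blue?"
prompt_eval_count = len(prompt.split())  # rough proxy, not a true token count

print(f"{total_duration} ns total, ~{prompt_eval_count} prompt tokens")
```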