"tests/python/common/transforms/test_transform.py" did not exist on "a1472bcff6baf52303ba76a0230658a52d3be2d4"
Unverified commit d050d865, authored by Yuhao Tsui, committed by GitHub

Update completions.py

parent 1bcfce8c
@@ -47,7 +47,10 @@ class OllamaGenerationStreamResponse(BaseModel):
     done: bool = Field(...)
 
 class OllamaGenerationResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    response: str
+    done: bool
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
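
The previously stubbed OllamaGenerationResponse now carries the four fields of Ollama's documented non-streaming /api/generate reply. A minimal sketch of how it serializes, assuming Pydantic v2 (the diff already calls model_dump_json, so v2 is a safe bet); the model name below is a placeholder:

```python
from datetime import datetime
from pydantic import BaseModel

class OllamaGenerationResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool

r = OllamaGenerationResponse(
    model="llama2",                  # placeholder model name
    created_at=str(datetime.now()),
    response="The sky is blue because...",
    done=True,
)
print(r.model_dump_json())
# {"model":"llama2","created_at":"...","response":"The sky is blue because...","done":true}
```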
@@ -75,8 +78,17 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
                 yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError
+        complete_response = ""
+        async for token in interface.inference(input.prompt, id):
+            complete_response += token
+        response = OllamaGenerationResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            response=complete_response,
+            done=True
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
 class OllamaChatCompletionMessage(BaseModel):
     role: str
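
The new non-streaming branch drains the token stream into a single string before building one response object, instead of yielding JSON lines as the streaming path does. A self-contained sketch of that buffering pattern; the inference generator below is a stand-in for interface.inference, which is not shown in this diff:

```python
import asyncio

async def inference(prompt: str, id: int):
    # Stand-in for interface.inference(): yields tokens one at a time.
    for token in ["The ", "sky ", "is ", "blue."]:
        yield token

async def main():
    complete_response = ""
    async for token in inference("Why is the sky blue?", 0):
        complete_response += token
    print(complete_response)  # -> "The sky is blue."

asyncio.run(main())
```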
@@ -100,10 +112,17 @@ class OllamaChatCompletionStreamResponse(BaseModel):
     eval_count: Optional[int] = Field(None, description="Number of tokens generated")
     eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 class OllamaChatCompletionResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
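
All six metric fields are declared Optional with a None default, so a chat response validates even when no measurements are available and the JSON simply carries nulls. A trimmed sketch, again assuming Pydantic v2; ChatMetrics is a cut-down stand-in for the model above:

```python
from typing import Optional
from pydantic import BaseModel, Field

class ChatMetrics(BaseModel):  # trimmed stand-in for OllamaChatCompletionResponse
    model: str
    done: bool
    eval_count: Optional[int] = Field(None, description="Number of tokens generated")

m = ChatMetrics(model="llama2", done=True)   # metrics omitted -> null in JSON
print(m.model_dump_json())                   # {"model":"llama2","done":true,"eval_count":null}
print(m.model_dump_json(exclude_none=True))  # {"model":"llama2","done":true}
```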
@@ -154,8 +173,35 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
                 yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError("Non-streaming chat is not implemented.")
+        start_time = time()
+        complete_response = ""
+        eval_count = 0
+        async for token in interface.inference(prompt, id):
+            complete_response += token
+            eval_count += 1
+        end_time = time()
+        total_duration = int((end_time - start_time) * 1_000_000_000)
+        prompt_eval_count = len(prompt.split())
+        eval_duration = total_duration
+        prompt_eval_duration = 0
+        load_duration = 0
+        response = OllamaChatCompletionResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            message={"role": "assistant", "content": complete_response},
+            done=True,
+            total_duration=total_duration,
+            load_duration=load_duration,
+            prompt_eval_count=prompt_eval_count,
+            prompt_eval_duration=prompt_eval_duration,
+            eval_count=eval_count,
+            eval_duration=eval_duration
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
     name: str
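
The timing math converts wall-clock seconds from time() (the hunk assumes a `from time import time` import elsewhere in the file) into nanoseconds to match Ollama's units, and approximates prompt_eval_count by whitespace-splitting, which counts words rather than model tokens. Note also that eval_duration is set equal to total_duration, so prompt evaluation and model loading are not separated out. A worked example with made-up timestamps:

```python
start_time = 1_700_000_000.00   # hypothetical time() samples, in seconds
end_time   = 1_700_000_001.25   # pretend generation took 1.25 s

total_duration = int((end_time - start_time) * 1_000_000_000)
print(total_duration)           # 1250000000 -> 1.25 s expressed in nanoseconds

prompt = "Why is the sky blue?"
print(len(prompt.split()))      # 5 -- words, not model tokens
```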
@@ -214,4 +260,4 @@ async def show(request: Request, input: OllamaShowRequest):
             quantization_level=" "
         ),
         model_info=OllamaModelInfo()
-    )
\ No newline at end of file
+    )
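
With both else branches implemented, the endpoints can be smoke-tested without parsing a stream. A hypothetical client sketch: the host, port, and /api prefix are assumptions rather than taken from this repo, and "stream": False is what should route requests into the new non-streaming branches:

```python
import requests

base = "http://localhost:8000/api"   # host, port, and prefix are assumptions

r = requests.post(f"{base}/generate", json={
    "model": "llama2", "prompt": "Why is the sky blue?", "stream": False,
})
print(r.json()["response"])

r = requests.post(f"{base}/chat", json={
    "model": "llama2", "stream": False,
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
})
print(r.json()["message"]["content"])
```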