Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
f462744f
Commit
f462744f
authored
Jul 11, 2024
by
Timothy J. Baek
Browse files
refac
parent
9ab97b83
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
12 deletions
+33
-12
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+6
-4
backend/apps/openai/main.py
backend/apps/openai/main.py
+2
-0
backend/apps/webui/main.py
backend/apps/webui/main.py
+9
-2
backend/main.py
backend/main.py
+16
-6
No files found.
backend/apps/ollama/main.py
View file @
f462744f
...
...
@@ -728,8 +728,10 @@ async def generate_chat_completion(
)
payload
=
{
**
form_data
.
model_dump
(
exclude_none
=
True
),
**
form_data
.
model_dump
(
exclude_none
=
True
,
exclude
=
[
"metadata"
]
),
}
if
"metadata"
in
payload
:
del
payload
[
"metadata"
]
model_id
=
form_data
.
model
model_info
=
Models
.
get_model_by_id
(
model_id
)
...
...
@@ -894,9 +896,9 @@ async def generate_openai_chat_completion(
):
form_data
=
OpenAIChatCompletionForm
(
**
form_data
)
payload
=
{
**
form_data
.
model_dump
(
exclude_none
=
True
),
}
payload
=
{
**
form_data
}
if
"metadata"
in
payload
:
del
payload
[
"metadata"
]
model_id
=
form_data
.
model
model_info
=
Models
.
get_model_by_id
(
model_id
)
...
...
backend/apps/openai/main.py
View file @
f462744f
...
...
@@ -357,6 +357,8 @@ async def generate_chat_completion(
):
idx
=
0
payload
=
{
**
form_data
}
if
"metadata"
in
payload
:
del
payload
[
"metadata"
]
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
...
...
backend/apps/webui/main.py
View file @
f462744f
...
...
@@ -20,7 +20,6 @@ from apps.webui.routers import (
)
from
apps.webui.models.functions
import
Functions
from
apps.webui.models.models
import
Models
from
apps.webui.utils
import
load_function_module_by_id
from
utils.misc
import
stream_message_template
...
...
@@ -53,7 +52,7 @@ import uuid
import
time
import
json
from
typing
import
Iterator
,
Generator
from
typing
import
Iterator
,
Generator
,
Optional
from
pydantic
import
BaseModel
app
=
FastAPI
()
...
...
@@ -193,6 +192,14 @@ async def generate_function_chat_completion(form_data, user):
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
metadata
=
None
if
"metadata"
in
form_data
:
metadata
=
form_data
[
"metadata"
]
del
form_data
[
"metadata"
]
if
metadata
:
print
(
metadata
)
if
model_info
:
if
model_info
.
base_model_id
:
form_data
[
"model"
]
=
model_info
.
base_model_id
...
...
backend/main.py
View file @
f462744f
...
...
@@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content
=
{
"detail"
:
str
(
e
)},
)
# `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
task
=
None
if
"task"
in
body
:
task
=
body
[
"task"
]
del
body
[
"task"
]
# Extract session_id, chat_id and message_id from the request body
session_id
=
None
if
"session_id"
in
body
:
...
...
@@ -632,6 +638,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
message_id
=
body
[
"id"
]
del
body
[
"id"
]
__event_emitter__
=
await
get_event_emitter
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
...
...
@@ -691,6 +699,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
body
[
"metadata"
]
=
{
"session_id"
:
session_id
,
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"task"
:
task
,
}
modified_body_bytes
=
json
.
dumps
(
body
).
encode
(
"utf-8"
)
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
...
...
@@ -811,9 +826,6 @@ def filter_pipeline(payload, user):
if
"detail"
in
res
:
raise
Exception
(
r
.
status_code
,
res
[
"detail"
])
if
"pipeline"
not
in
app
.
state
.
MODELS
[
model_id
]
and
"task"
in
payload
:
del
payload
[
"task"
]
return
payload
...
...
@@ -1024,11 +1036,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
status_code
=
status
.
HTTP_404_NOT_FOUND
,
detail
=
"Model not found"
,
)
model
=
app
.
state
.
MODELS
[
model_id
]
pipe
=
model
.
get
(
"pipe"
)
if
pipe
:
if
model
.
get
(
"pipe"
):
return
await
generate_function_chat_completion
(
form_data
,
user
=
user
)
if
model
[
"owned_by"
]
==
"ollama"
:
return
await
generate_ollama_chat_completion
(
form_data
,
user
=
user
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment