Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
a9a6ed8b
Unverified
Commit
a9a6ed8b
authored
Aug 02, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Aug 02, 2024
Browse files
Merge pull request #4237 from michaelpoluektov/refactor-webui-main
refactor: Simplify functions
parents
64b41655
e6c64282
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
284 additions
and
389 deletions
+284
-389
backend/apps/socket/main.py
backend/apps/socket/main.py
+2
-4
backend/apps/webui/main.py
backend/apps/webui/main.py
+225
-302
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+2
-10
backend/main.py
backend/main.py
+16
-41
backend/utils/misc.py
backend/utils/misc.py
+39
-32
No files found.
backend/apps/socket/main.py
View file @
a9a6ed8b
...
...
@@ -52,7 +52,6 @@ async def user_join(sid, data):
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
if
user
:
SESSION_POOL
[
sid
]
=
user
.
id
if
user
.
id
in
USER_POOL
:
USER_POOL
[
user
.
id
].
append
(
sid
)
...
...
@@ -80,7 +79,6 @@ def get_models_in_use():
@
sio
.
on
(
"usage"
)
async
def
usage
(
sid
,
data
):
model_id
=
data
[
"model"
]
# Cancel previous callback if there is one
...
...
@@ -139,7 +137,7 @@ async def disconnect(sid):
print
(
f
"Unknown session ID
{
sid
}
disconnected"
)
async
def
get_event_emitter
(
request_info
):
def
get_event_emitter
(
request_info
):
async
def
__event_emitter__
(
event_data
):
await
sio
.
emit
(
"chat-events"
,
...
...
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
return
__event_emitter__
async
def
get_event_call
(
request_info
):
def
get_event_call
(
request_info
):
async
def
__event_call__
(
event_data
):
response
=
await
sio
.
call
(
"chat-events"
,
...
...
backend/apps/webui/main.py
View file @
a9a6ed8b
This diff is collapsed.
Click to expand it.
backend/apps/webui/models/models.py
View file @
a9a6ed8b
import
json
import
logging
from
typing
import
Optional
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
sqlalchemy
import
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
import
time
...
...
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
)
->
Optional
[
ModelModel
]:
...
...
@@ -126,9 +123,7 @@ class ModelsTable:
}
)
try
:
with
get_db
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
commit
()
...
...
@@ -144,13 +139,11 @@ class ModelsTable:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
with
get_db
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
with
get_db
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
except
:
...
...
@@ -178,7 +171,6 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
with
get_db
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
...
...
backend/main.py
View file @
a9a6ed8b
...
...
@@ -13,8 +13,6 @@ import aiohttp
import
requests
import
mimetypes
import
shutil
import
os
import
uuid
import
inspect
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
...
...
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
apps.socket.main
import
sio
,
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.socket.main
import
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.ollama.main
import
(
app
as
ollama_app
,
get_all_models
as
get_ollama_models
,
...
...
@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content
=
{
"detail"
:
str
(
e
)},
)
# Extract valves from the request body
valves
=
None
if
"valves"
in
body
:
valves
=
body
[
"valves"
]
del
body
[
"valves"
]
# Extract session_id, chat_id and message_id from the request body
session_id
=
None
if
"session_id"
in
body
:
session_id
=
body
[
"session_id"
]
del
body
[
"session_id"
]
chat_id
=
None
if
"chat_id"
in
body
:
chat_id
=
body
[
"chat_id"
]
del
body
[
"chat_id"
]
message_id
=
None
if
"id"
in
body
:
message_id
=
body
[
"id"
]
del
body
[
"id"
]
__event_emitter__
=
await
get_event_emitter
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
__event_call__
=
await
get_event_call
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
metadata
=
{
"chat_id"
:
body
.
pop
(
"chat_id"
,
None
),
"message_id"
:
body
.
pop
(
"id"
,
None
),
"session_id"
:
body
.
pop
(
"session_id"
,
None
),
"valves"
:
body
.
pop
(
"valves"
,
None
),
}
__event_emitter__
=
get_event_emitter
(
metadata
)
__event_call__
=
get_event_call
(
metadata
)
# Initialize data_items to store additional data to be sent to the client
data_items
=
[]
...
...
@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
body
[
"metadata"
]
=
{
"session_id"
:
session_id
,
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"valves"
:
valves
,
}
body
[
"metadata"
]
=
metadata
modified_body_bytes
=
json
.
dumps
(
body
).
encode
(
"utf-8"
)
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
...
...
@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
status_code
=
r
.
status_code
,
content
=
res
,
)
except
:
except
Exception
:
pass
else
:
pass
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
...
...
@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
}
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
...
...
@@ -1334,14 +1309,14 @@ async def chat_completed(
)
model
=
app
.
state
.
MODELS
[
model_id
]
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"session_id"
:
data
[
"session_id"
],
}
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
...
...
backend/utils/misc.py
View file @
a9a6ed8b
from
pathlib
import
Path
import
hashlib
import
json
import
re
from
datetime
import
timedelta
from
typing
import
Optional
,
List
,
Tuple
...
...
@@ -8,37 +7,39 @@ import uuid
import
time
def
get_last_user_message_item
(
messages
:
List
[
dict
])
->
str
:
def
get_last_user_message_item
(
messages
:
List
[
dict
])
->
Optional
[
dict
]
:
for
message
in
reversed
(
messages
):
if
message
[
"role"
]
==
"user"
:
return
message
return
None
def
get_last_user_message
(
messages
:
List
[
dict
])
->
str
:
message
=
get_last_user_message_item
(
messages
)
if
message
is
not
None
:
if
isinstance
(
message
[
"content"
],
list
):
for
item
in
message
[
"content"
]:
if
item
[
"type"
]
==
"text"
:
return
item
[
"text"
]
def
get_content_from_message
(
message
:
dict
)
->
Optional
[
str
]:
if
isinstance
(
message
[
"content"
],
list
):
for
item
in
message
[
"content"
]:
if
item
[
"type"
]
==
"text"
:
return
item
[
"text"
]
else
:
return
message
[
"content"
]
return
None
def
get_last_assistant_message
(
messages
:
List
[
dict
])
->
str
:
def
get_last_user_message
(
messages
:
List
[
dict
])
->
Optional
[
str
]:
message
=
get_last_user_message_item
(
messages
)
if
message
is
None
:
return
None
return
get_content_from_message
(
message
)
def
get_last_assistant_message
(
messages
:
List
[
dict
])
->
Optional
[
str
]:
for
message
in
reversed
(
messages
):
if
message
[
"role"
]
==
"assistant"
:
if
isinstance
(
message
[
"content"
],
list
):
for
item
in
message
[
"content"
]:
if
item
[
"type"
]
==
"text"
:
return
item
[
"text"
]
return
message
[
"content"
]
return
get_content_from_message
(
message
)
return
None
def
get_system_message
(
messages
:
List
[
dict
])
->
dict
:
def
get_system_message
(
messages
:
List
[
dict
])
->
Optional
[
dict
]
:
for
message
in
messages
:
if
message
[
"role"
]
==
"system"
:
return
message
...
...
@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
return
[
message
for
message
in
messages
if
message
[
"role"
]
!=
"system"
]
def
pop_system_message
(
messages
:
List
[
dict
])
->
Tuple
[
dict
,
List
[
dict
]]:
def
pop_system_message
(
messages
:
List
[
dict
])
->
Tuple
[
Optional
[
dict
]
,
List
[
dict
]]:
return
get_system_message
(
messages
),
remove_system_message
(
messages
)
...
...
@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return
messages
def
stream
_message_template
(
model
:
str
,
message
:
str
):
def
openai_chat
_message_template
(
model
:
str
):
return
{
"id"
:
f
"
{
model
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"object"
:
"chat.completion.chunk"
,
"created"
:
int
(
time
.
time
()),
"model"
:
model
,
"choices"
:
[
{
"index"
:
0
,
"delta"
:
{
"content"
:
message
},
"logprobs"
:
None
,
"finish_reason"
:
None
,
}
],
"choices"
:
[{
"index"
:
0
,
"logprobs"
:
None
,
"finish_reason"
:
None
}],
}
def
openai_chat_chunk_message_template
(
model
:
str
,
message
:
str
):
template
=
openai_chat_message_template
(
model
)
template
[
"object"
]
=
"chat.completion.chunk"
template
[
"choices"
][
0
][
"delta"
]
=
{
"content"
:
message
}
return
template
def
openai_chat_completion_message_template
(
model
:
str
,
message
:
str
):
template
=
openai_chat_message_template
(
model
)
template
[
"object"
]
=
"chat.completion"
template
[
"choices"
][
0
][
"message"
]
=
{
"content"
:
message
,
"role"
:
"assistant"
}
template
[
"choices"
][
0
][
"finish_reason"
]
=
"stop"
def
get_gravatar_url
(
email
):
# Trim leading and trailing whitespace from
# an email address and force all characters
...
...
@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
tags
=
[]
folders
=
parts
[
index_docs
:
-
1
]
for
idx
,
part
in
enumerate
(
folders
):
for
idx
,
_
in
enumerate
(
folders
):
tags
.
append
(
"/"
.
join
(
folders
[:
idx
+
1
]))
return
tags
...
...
@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
value
=
param_match
.
group
(
1
)
try
:
if
param_type
==
int
:
if
param_type
is
int
:
value
=
int
(
value
)
elif
param_type
==
float
:
elif
param_type
is
float
:
value
=
float
(
value
)
elif
param_type
==
bool
:
elif
param_type
is
bool
:
value
=
value
.
lower
()
==
"true"
except
Exception
as
e
:
print
(
e
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment