Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
a9a6ed8b
Unverified
Commit
a9a6ed8b
authored
Aug 02, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Aug 02, 2024
Browse files
Merge pull request #4237 from michaelpoluektov/refactor-webui-main
refactor: Simplify functions
parents
64b41655
e6c64282
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
284 additions
and
389 deletions
+284
-389
backend/apps/socket/main.py
backend/apps/socket/main.py
+2
-4
backend/apps/webui/main.py
backend/apps/webui/main.py
+225
-302
backend/apps/webui/models/models.py
backend/apps/webui/models/models.py
+2
-10
backend/main.py
backend/main.py
+16
-41
backend/utils/misc.py
backend/utils/misc.py
+39
-32
No files found.
backend/apps/socket/main.py
View file @
a9a6ed8b
...
@@ -52,7 +52,6 @@ async def user_join(sid, data):
...
@@ -52,7 +52,6 @@ async def user_join(sid, data):
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
if
user
:
if
user
:
SESSION_POOL
[
sid
]
=
user
.
id
SESSION_POOL
[
sid
]
=
user
.
id
if
user
.
id
in
USER_POOL
:
if
user
.
id
in
USER_POOL
:
USER_POOL
[
user
.
id
].
append
(
sid
)
USER_POOL
[
user
.
id
].
append
(
sid
)
...
@@ -80,7 +79,6 @@ def get_models_in_use():
...
@@ -80,7 +79,6 @@ def get_models_in_use():
@
sio
.
on
(
"usage"
)
@
sio
.
on
(
"usage"
)
async
def
usage
(
sid
,
data
):
async
def
usage
(
sid
,
data
):
model_id
=
data
[
"model"
]
model_id
=
data
[
"model"
]
# Cancel previous callback if there is one
# Cancel previous callback if there is one
...
@@ -139,7 +137,7 @@ async def disconnect(sid):
...
@@ -139,7 +137,7 @@ async def disconnect(sid):
print
(
f
"Unknown session ID
{
sid
}
disconnected"
)
print
(
f
"Unknown session ID
{
sid
}
disconnected"
)
async
def
get_event_emitter
(
request_info
):
def
get_event_emitter
(
request_info
):
async
def
__event_emitter__
(
event_data
):
async
def
__event_emitter__
(
event_data
):
await
sio
.
emit
(
await
sio
.
emit
(
"chat-events"
,
"chat-events"
,
...
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
...
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
return
__event_emitter__
return
__event_emitter__
async
def
get_event_call
(
request_info
):
def
get_event_call
(
request_info
):
async
def
__event_call__
(
event_data
):
async
def
__event_call__
(
event_data
):
response
=
await
sio
.
call
(
response
=
await
sio
.
call
(
"chat-events"
,
"chat-events"
,
...
...
backend/apps/webui/main.py
View file @
a9a6ed8b
This diff is collapsed.
Click to expand it.
backend/apps/webui/models/models.py
View file @
a9a6ed8b
import
json
import
logging
import
logging
from
typing
import
Optional
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
,
ConfigDict
from
pydantic
import
BaseModel
,
ConfigDict
from
sqlalchemy
import
String
,
Column
,
BigInteger
,
Text
from
sqlalchemy
import
Column
,
BigInteger
,
Text
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
apps.webui.internal.db
import
Base
,
JSONField
,
get_db
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
from
config
import
SRC_LOG_LEVELS
import
time
import
time
...
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
...
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class
ModelsTable
:
class
ModelsTable
:
def
insert_new_model
(
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
self
,
form_data
:
ModelForm
,
user_id
:
str
)
->
Optional
[
ModelModel
]:
)
->
Optional
[
ModelModel
]:
...
@@ -126,9 +123,7 @@ class ModelsTable:
...
@@ -126,9 +123,7 @@ class ModelsTable:
}
}
)
)
try
:
try
:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
result
=
Model
(
**
model
.
model_dump
())
result
=
Model
(
**
model
.
model_dump
())
db
.
add
(
result
)
db
.
add
(
result
)
db
.
commit
()
db
.
commit
()
...
@@ -144,13 +139,11 @@ class ModelsTable:
...
@@ -144,13 +139,11 @@ class ModelsTable:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
return
[
ModelModel
.
model_validate
(
model
)
for
model
in
db
.
query
(
Model
).
all
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
try
:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
model
=
db
.
get
(
Model
,
id
)
model
=
db
.
get
(
Model
,
id
)
return
ModelModel
.
model_validate
(
model
)
return
ModelModel
.
model_validate
(
model
)
except
:
except
:
...
@@ -178,7 +171,6 @@ class ModelsTable:
...
@@ -178,7 +171,6 @@ class ModelsTable:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
try
:
with
get_db
()
as
db
:
with
get_db
()
as
db
:
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
query
(
Model
).
filter_by
(
id
=
id
).
delete
()
db
.
commit
()
db
.
commit
()
...
...
backend/main.py
View file @
a9a6ed8b
...
@@ -13,8 +13,6 @@ import aiohttp
...
@@ -13,8 +13,6 @@ import aiohttp
import
requests
import
requests
import
mimetypes
import
mimetypes
import
shutil
import
shutil
import
os
import
uuid
import
inspect
import
inspect
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
UploadFile
,
File
,
Form
...
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
...
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
apps.socket.main
import
sio
,
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.socket.main
import
app
as
socket_app
,
get_event_emitter
,
get_event_call
from
apps.ollama.main
import
(
from
apps.ollama.main
import
(
app
as
ollama_app
,
app
as
ollama_app
,
get_all_models
as
get_ollama_models
,
get_all_models
as
get_ollama_models
,
...
@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content
=
{
"detail"
:
str
(
e
)},
content
=
{
"detail"
:
str
(
e
)},
)
)
# Extract valves from the request body
metadata
=
{
valves
=
None
"chat_id"
:
body
.
pop
(
"chat_id"
,
None
),
if
"valves"
in
body
:
"message_id"
:
body
.
pop
(
"id"
,
None
),
valves
=
body
[
"valves"
]
"session_id"
:
body
.
pop
(
"session_id"
,
None
),
del
body
[
"valves"
]
"valves"
:
body
.
pop
(
"valves"
,
None
),
}
# Extract session_id, chat_id and message_id from the request body
session_id
=
None
__event_emitter__
=
get_event_emitter
(
metadata
)
if
"session_id"
in
body
:
__event_call__
=
get_event_call
(
metadata
)
session_id
=
body
[
"session_id"
]
del
body
[
"session_id"
]
chat_id
=
None
if
"chat_id"
in
body
:
chat_id
=
body
[
"chat_id"
]
del
body
[
"chat_id"
]
message_id
=
None
if
"id"
in
body
:
message_id
=
body
[
"id"
]
del
body
[
"id"
]
__event_emitter__
=
await
get_event_emitter
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
__event_call__
=
await
get_event_call
(
{
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"session_id"
:
session_id
}
)
# Initialize data_items to store additional data to be sent to the client
# Initialize data_items to store additional data to be sent to the client
data_items
=
[]
data_items
=
[]
...
@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
...
@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if
len
(
citations
)
>
0
:
if
len
(
citations
)
>
0
:
data_items
.
append
({
"citations"
:
citations
})
data_items
.
append
({
"citations"
:
citations
})
body
[
"metadata"
]
=
{
body
[
"metadata"
]
=
metadata
"session_id"
:
session_id
,
"chat_id"
:
chat_id
,
"message_id"
:
message_id
,
"valves"
:
valves
,
}
modified_body_bytes
=
json
.
dumps
(
body
).
encode
(
"utf-8"
)
modified_body_bytes
=
json
.
dumps
(
body
).
encode
(
"utf-8"
)
# Replace the request body with the modified one
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
request
.
_body
=
modified_body_bytes
...
@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
...
@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
status_code
=
r
.
status_code
,
status_code
=
r
.
status_code
,
content
=
res
,
content
=
res
,
)
)
except
:
except
Exception
:
pass
pass
else
:
else
:
pass
pass
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
...
@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
...
@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
}
}
)
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
...
@@ -1334,14 +1309,14 @@ async def chat_completed(
...
@@ -1334,14 +1309,14 @@ async def chat_completed(
)
)
model
=
app
.
state
.
MODELS
[
model_id
]
model
=
app
.
state
.
MODELS
[
model_id
]
__event_emitter__
=
await
get_event_emitter
(
__event_emitter__
=
get_event_emitter
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
"session_id"
:
data
[
"session_id"
],
"session_id"
:
data
[
"session_id"
],
}
}
)
)
__event_call__
=
await
get_event_call
(
__event_call__
=
get_event_call
(
{
{
"chat_id"
:
data
[
"chat_id"
],
"chat_id"
:
data
[
"chat_id"
],
"message_id"
:
data
[
"id"
],
"message_id"
:
data
[
"id"
],
...
...
backend/utils/misc.py
View file @
a9a6ed8b
from
pathlib
import
Path
from
pathlib
import
Path
import
hashlib
import
hashlib
import
json
import
re
import
re
from
datetime
import
timedelta
from
datetime
import
timedelta
from
typing
import
Optional
,
List
,
Tuple
from
typing
import
Optional
,
List
,
Tuple
...
@@ -8,37 +7,39 @@ import uuid
...
@@ -8,37 +7,39 @@ import uuid
import
time
import
time
def
get_last_user_message_item
(
messages
:
List
[
dict
])
->
str
:
def
get_last_user_message_item
(
messages
:
List
[
dict
])
->
Optional
[
dict
]
:
for
message
in
reversed
(
messages
):
for
message
in
reversed
(
messages
):
if
message
[
"role"
]
==
"user"
:
if
message
[
"role"
]
==
"user"
:
return
message
return
message
return
None
return
None
def
get_last_user_message
(
messages
:
List
[
dict
])
->
str
:
def
get_content_from_message
(
message
:
dict
)
->
Optional
[
str
]:
message
=
get_last_user_message_item
(
messages
)
if
isinstance
(
message
[
"content"
],
list
):
for
item
in
message
[
"content"
]:
if
message
is
not
None
:
if
item
[
"type"
]
==
"text"
:
if
isinstance
(
message
[
"content"
],
list
):
return
item
[
"text"
]
for
item
in
message
[
"content"
]:
else
:
if
item
[
"type"
]
==
"text"
:
return
item
[
"text"
]
return
message
[
"content"
]
return
message
[
"content"
]
return
None
return
None
def
get_last_assistant_message
(
messages
:
List
[
dict
])
->
str
:
def
get_last_user_message
(
messages
:
List
[
dict
])
->
Optional
[
str
]:
message
=
get_last_user_message_item
(
messages
)
if
message
is
None
:
return
None
return
get_content_from_message
(
message
)
def
get_last_assistant_message
(
messages
:
List
[
dict
])
->
Optional
[
str
]:
for
message
in
reversed
(
messages
):
for
message
in
reversed
(
messages
):
if
message
[
"role"
]
==
"assistant"
:
if
message
[
"role"
]
==
"assistant"
:
if
isinstance
(
message
[
"content"
],
list
):
return
get_content_from_message
(
message
)
for
item
in
message
[
"content"
]:
if
item
[
"type"
]
==
"text"
:
return
item
[
"text"
]
return
message
[
"content"
]
return
None
return
None
def
get_system_message
(
messages
:
List
[
dict
])
->
dict
:
def
get_system_message
(
messages
:
List
[
dict
])
->
Optional
[
dict
]
:
for
message
in
messages
:
for
message
in
messages
:
if
message
[
"role"
]
==
"system"
:
if
message
[
"role"
]
==
"system"
:
return
message
return
message
...
@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
...
@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
return
[
message
for
message
in
messages
if
message
[
"role"
]
!=
"system"
]
return
[
message
for
message
in
messages
if
message
[
"role"
]
!=
"system"
]
def
pop_system_message
(
messages
:
List
[
dict
])
->
Tuple
[
dict
,
List
[
dict
]]:
def
pop_system_message
(
messages
:
List
[
dict
])
->
Tuple
[
Optional
[
dict
]
,
List
[
dict
]]:
return
get_system_message
(
messages
),
remove_system_message
(
messages
)
return
get_system_message
(
messages
),
remove_system_message
(
messages
)
...
@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
...
@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return
messages
return
messages
def
stream
_message_template
(
model
:
str
,
message
:
str
):
def
openai_chat
_message_template
(
model
:
str
):
return
{
return
{
"id"
:
f
"
{
model
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"id"
:
f
"
{
model
}
-
{
str
(
uuid
.
uuid4
())
}
"
,
"object"
:
"chat.completion.chunk"
,
"created"
:
int
(
time
.
time
()),
"created"
:
int
(
time
.
time
()),
"model"
:
model
,
"model"
:
model
,
"choices"
:
[
"choices"
:
[{
"index"
:
0
,
"logprobs"
:
None
,
"finish_reason"
:
None
}],
{
"index"
:
0
,
"delta"
:
{
"content"
:
message
},
"logprobs"
:
None
,
"finish_reason"
:
None
,
}
],
}
}
def
openai_chat_chunk_message_template
(
model
:
str
,
message
:
str
):
template
=
openai_chat_message_template
(
model
)
template
[
"object"
]
=
"chat.completion.chunk"
template
[
"choices"
][
0
][
"delta"
]
=
{
"content"
:
message
}
return
template
def
openai_chat_completion_message_template
(
model
:
str
,
message
:
str
):
template
=
openai_chat_message_template
(
model
)
template
[
"object"
]
=
"chat.completion"
template
[
"choices"
][
0
][
"message"
]
=
{
"content"
:
message
,
"role"
:
"assistant"
}
template
[
"choices"
][
0
][
"finish_reason"
]
=
"stop"
def
get_gravatar_url
(
email
):
def
get_gravatar_url
(
email
):
# Trim leading and trailing whitespace from
# Trim leading and trailing whitespace from
# an email address and force all characters
# an email address and force all characters
...
@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
...
@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
tags
=
[]
tags
=
[]
folders
=
parts
[
index_docs
:
-
1
]
folders
=
parts
[
index_docs
:
-
1
]
for
idx
,
part
in
enumerate
(
folders
):
for
idx
,
_
in
enumerate
(
folders
):
tags
.
append
(
"/"
.
join
(
folders
[:
idx
+
1
]))
tags
.
append
(
"/"
.
join
(
folders
[:
idx
+
1
]))
return
tags
return
tags
...
@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
...
@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
value
=
param_match
.
group
(
1
)
value
=
param_match
.
group
(
1
)
try
:
try
:
if
param_type
==
int
:
if
param_type
is
int
:
value
=
int
(
value
)
value
=
int
(
value
)
elif
param_type
==
float
:
elif
param_type
is
float
:
value
=
float
(
value
)
value
=
float
(
value
)
elif
param_type
==
bool
:
elif
param_type
is
bool
:
value
=
value
.
lower
()
==
"true"
value
=
value
.
lower
()
==
"true"
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment