Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
ffa16821
Unverified
Commit
ffa16821
authored
Mar 24, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Mar 24, 2024
Browse files
Merge pull request #1237 from lainedfles/debug_print
Migrate to python logging module with env var control.
parents
a1faa307
371dfc11
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
174 additions
and
82 deletions
+174
-82
backend/apps/audio/main.py
backend/apps/audio/main.py
+8
-4
backend/apps/images/main.py
backend/apps/images/main.py
+7
-3
backend/apps/litellm/main.py
backend/apps/litellm/main.py
+7
-2
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+37
-34
backend/apps/openai/main.py
backend/apps/openai/main.py
+13
-8
backend/apps/rag/main.py
backend/apps/rag/main.py
+12
-9
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+7
-3
backend/apps/web/internal/db.py
backend/apps/web/internal/db.py
+5
-2
backend/apps/web/models/auths.py
backend/apps/web/models/auths.py
+7
-2
backend/apps/web/models/documents.py
backend/apps/web/models/documents.py
+7
-2
backend/apps/web/models/tags.py
backend/apps/web/models/tags.py
+9
-4
backend/apps/web/routers/chats.py
backend/apps/web/routers/chats.py
+7
-2
backend/apps/web/routers/users.py
backend/apps/web/routers/users.py
+6
-1
backend/config.py
backend/config.py
+34
-4
backend/main.py
backend/main.py
+8
-2
No files found.
backend/apps/audio/main.py
View file @
ffa16821
import
os
import
os
import
logging
from
fastapi
import
(
from
fastapi
import
(
FastAPI
,
FastAPI
,
Request
,
Request
,
...
@@ -21,7 +22,10 @@ from utils.utils import (
...
@@ -21,7 +22,10 @@ from utils.utils import (
)
)
from
utils.misc
import
calculate_sha256
from
utils.misc
import
calculate_sha256
from
config
import
CACHE_DIR
,
UPLOAD_DIR
,
WHISPER_MODEL
,
WHISPER_MODEL_DIR
from
config
import
SRC_LOG_LEVELS
,
CACHE_DIR
,
UPLOAD_DIR
,
WHISPER_MODEL
,
WHISPER_MODEL_DIR
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"AUDIO"
])
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
...
@@ -38,7 +42,7 @@ def transcribe(
...
@@ -38,7 +42,7 @@ def transcribe(
file
:
UploadFile
=
File
(...),
file
:
UploadFile
=
File
(...),
user
=
Depends
(
get_current_user
),
user
=
Depends
(
get_current_user
),
):
):
print
(
file
.
content_type
)
log
.
info
(
f
"file.content_type:
{
file
.
content_type
}
"
)
if
file
.
content_type
not
in
[
"audio/mpeg"
,
"audio/wav"
]:
if
file
.
content_type
not
in
[
"audio/mpeg"
,
"audio/wav"
]:
raise
HTTPException
(
raise
HTTPException
(
...
@@ -62,7 +66,7 @@ def transcribe(
...
@@ -62,7 +66,7 @@ def transcribe(
)
)
segments
,
info
=
model
.
transcribe
(
file_path
,
beam_size
=
5
)
segments
,
info
=
model
.
transcribe
(
file_path
,
beam_size
=
5
)
print
(
log
.
info
(
"Detected language '%s' with probability %f"
"Detected language '%s' with probability %f"
%
(
info
.
language
,
info
.
language_probability
)
%
(
info
.
language
,
info
.
language_probability
)
)
)
...
@@ -72,7 +76,7 @@ def transcribe(
...
@@ -72,7 +76,7 @@ def transcribe(
return
{
"text"
:
transcript
.
strip
()}
return
{
"text"
:
transcript
.
strip
()}
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
...
...
backend/apps/images/main.py
View file @
ffa16821
...
@@ -27,10 +27,14 @@ from pathlib import Path
...
@@ -27,10 +27,14 @@ from pathlib import Path
import
uuid
import
uuid
import
base64
import
base64
import
json
import
json
import
logging
from
config
import
CACHE_DIR
,
AUTOMATIC1111_BASE_URL
,
COMFYUI_BASE_URL
from
config
import
SRC_LOG_LEVELS
,
CACHE_DIR
,
AUTOMATIC1111_BASE_URL
,
COMFYUI_BASE_URL
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"IMAGES"
])
IMAGE_CACHE_DIR
=
Path
(
CACHE_DIR
).
joinpath
(
"./image/generations/"
)
IMAGE_CACHE_DIR
=
Path
(
CACHE_DIR
).
joinpath
(
"./image/generations/"
)
IMAGE_CACHE_DIR
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
IMAGE_CACHE_DIR
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
...
@@ -304,7 +308,7 @@ def save_b64_image(b64_str):
...
@@ -304,7 +308,7 @@ def save_b64_image(b64_str):
return
image_id
return
image_id
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error saving image:
{
e
}
"
)
log
.
error
(
f
"Error saving image:
{
e
}
"
)
return
None
return
None
...
@@ -431,7 +435,7 @@ def generate_image(
...
@@ -431,7 +435,7 @@ def generate_image(
res
=
r
.
json
()
res
=
r
.
json
()
print
(
res
)
log
.
debug
(
f
"res:
{
res
}
"
)
images
=
[]
images
=
[]
...
...
backend/apps/litellm/main.py
View file @
ffa16821
import
logging
from
litellm.proxy.proxy_server
import
ProxyConfig
,
initialize
from
litellm.proxy.proxy_server
import
ProxyConfig
,
initialize
from
litellm.proxy.proxy_server
import
app
from
litellm.proxy.proxy_server
import
app
...
@@ -9,7 +11,10 @@ from starlette.responses import StreamingResponse
...
@@ -9,7 +11,10 @@ from starlette.responses import StreamingResponse
import
json
import
json
from
utils.utils
import
get_http_authorization_cred
,
get_current_user
from
utils.utils
import
get_http_authorization_cred
,
get_current_user
from
config
import
ENV
from
config
import
SRC_LOG_LEVELS
,
ENV
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"LITELLM"
])
from
config
import
(
from
config
import
(
...
@@ -49,7 +54,7 @@ async def auth_middleware(request: Request, call_next):
...
@@ -49,7 +54,7 @@ async def auth_middleware(request: Request, call_next):
try
:
try
:
user
=
get_current_user
(
get_http_authorization_cred
(
auth_header
))
user
=
get_current_user
(
get_http_authorization_cred
(
auth_header
))
print
(
user
)
log
.
debug
(
f
"user:
{
user
}
"
)
request
.
state
.
user
=
user
request
.
state
.
user
=
user
except
Exception
as
e
:
except
Exception
as
e
:
return
JSONResponse
(
status_code
=
400
,
content
=
{
"detail"
:
str
(
e
)})
return
JSONResponse
(
status_code
=
400
,
content
=
{
"detail"
:
str
(
e
)})
...
...
backend/apps/ollama/main.py
View file @
ffa16821
...
@@ -23,6 +23,7 @@ import json
...
@@ -23,6 +23,7 @@ import json
import
uuid
import
uuid
import
aiohttp
import
aiohttp
import
asyncio
import
asyncio
import
logging
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
from
typing
import
Optional
,
List
,
Union
from
typing
import
Optional
,
List
,
Union
...
@@ -30,11 +31,13 @@ from typing import Optional, List, Union
...
@@ -30,11 +31,13 @@ from typing import Optional, List, Union
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
utils.utils
import
decode_token
,
get_current_user
,
get_admin_user
from
utils.utils
import
decode_token
,
get_current_user
,
get_admin_user
from
utils.misc
import
calculate_sha256
from
config
import
OLLAMA_BASE_URLS
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_LIST
,
UPLOAD_DIR
from
config
import
SRC_LOG_LEVELS
,
OLLAMA_BASE_URLS
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_LIST
,
UPLOAD_DIR
from
utils.misc
import
calculate_sha256
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"OLLAMA"
])
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
...
@@ -85,7 +88,7 @@ class UrlUpdateForm(BaseModel):
...
@@ -85,7 +88,7 @@ class UrlUpdateForm(BaseModel):
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
OLLAMA_BASE_URLS
=
form_data
.
urls
app
.
state
.
OLLAMA_BASE_URLS
=
form_data
.
urls
print
(
app
.
state
.
OLLAMA_BASE_URLS
)
log
.
info
(
f
"
app.state.OLLAMA_BASE_URLS
:
{
app
.
state
.
OLLAMA_BASE_URLS
}
"
)
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
...
@@ -106,7 +109,7 @@ async def fetch_url(url):
...
@@ -106,7 +109,7 @@ async def fetch_url(url):
return
await
response
.
json
()
return
await
response
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
# Handle connection error here
# Handle connection error here
print
(
f
"Connection error:
{
e
}
"
)
log
.
error
(
f
"Connection error:
{
e
}
"
)
return
None
return
None
...
@@ -130,7 +133,7 @@ def merge_models_lists(model_lists):
...
@@ -130,7 +133,7 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
async
def
get_all_models
():
print
(
"get_all_models"
)
log
.
info
(
"get_all_models
()
"
)
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
await
asyncio
.
gather
(
*
tasks
)
...
@@ -171,7 +174,7 @@ async def get_ollama_tags(
...
@@ -171,7 +174,7 @@ async def get_ollama_tags(
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -217,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
...
@@ -217,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -243,7 +246,7 @@ async def pull_model(
...
@@ -243,7 +246,7 @@ async def pull_model(
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
):
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -292,7 +295,7 @@ async def pull_model(
...
@@ -292,7 +295,7 @@ async def pull_model(
return
await
run_in_threadpool
(
get_request
)
return
await
run_in_threadpool
(
get_request
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -331,7 +334,7 @@ async def push_model(
...
@@ -331,7 +334,7 @@ async def push_model(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
debug
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -363,7 +366,7 @@ async def push_model(
...
@@ -363,7 +366,7 @@ async def push_model(
try
:
try
:
return
await
run_in_threadpool
(
get_request
)
return
await
run_in_threadpool
(
get_request
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -391,9 +394,9 @@ class CreateModelForm(BaseModel):
...
@@ -391,9 +394,9 @@ class CreateModelForm(BaseModel):
async
def
create_model
(
async
def
create_model
(
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
):
print
(
form_data
)
log
.
debug
(
f
"form_data:
{
form_data
}
"
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -415,7 +418,7 @@ async def create_model(
...
@@ -415,7 +418,7 @@ async def create_model(
r
.
raise_for_status
()
r
.
raise_for_status
()
print
(
r
)
log
.
debug
(
f
"r:
{
r
}
"
)
return
StreamingResponse
(
return
StreamingResponse
(
stream_content
(),
stream_content
(),
...
@@ -428,7 +431,7 @@ async def create_model(
...
@@ -428,7 +431,7 @@ async def create_model(
try
:
try
:
return
await
run_in_threadpool
(
get_request
)
return
await
run_in_threadpool
(
get_request
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -466,7 +469,7 @@ async def copy_model(
...
@@ -466,7 +469,7 @@ async def copy_model(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -476,11 +479,11 @@ async def copy_model(
...
@@ -476,11 +479,11 @@ async def copy_model(
)
)
r
.
raise_for_status
()
r
.
raise_for_status
()
print
(
r
.
text
)
log
.
debug
(
f
"r.text:
{
r
.
text
}
"
)
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -513,7 +516,7 @@ async def delete_model(
...
@@ -513,7 +516,7 @@ async def delete_model(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -523,11 +526,11 @@ async def delete_model(
...
@@ -523,11 +526,11 @@ async def delete_model(
)
)
r
.
raise_for_status
()
r
.
raise_for_status
()
print
(
r
.
text
)
log
.
debug
(
f
"r.text:
{
r
.
text
}
"
)
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -553,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
...
@@ -553,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
])
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
])
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -565,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
...
@@ -565,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -605,7 +608,7 @@ async def generate_embeddings(
...
@@ -605,7 +608,7 @@ async def generate_embeddings(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -617,7 +620,7 @@ async def generate_embeddings(
...
@@ -617,7 +620,7 @@ async def generate_embeddings(
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -665,7 +668,7 @@ async def generate_completion(
...
@@ -665,7 +668,7 @@ async def generate_completion(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -686,7 +689,7 @@ async def generate_completion(
...
@@ -686,7 +689,7 @@ async def generate_completion(
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
@@ -763,11 +766,11 @@ async def generate_chat_completion(
...
@@ -763,11 +766,11 @@ async def generate_chat_completion(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
prin
t
(
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
())
log
.
debug
(
"form_data.model_dump_json(exclude_none=True).encode(): {0} "
.
forma
t
(
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
())
)
def
get_request
():
def
get_request
():
nonlocal
form_data
nonlocal
form_data
...
@@ -786,7 +789,7 @@ async def generate_chat_completion(
...
@@ -786,7 +789,7 @@ async def generate_chat_completion(
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
@@ -809,7 +812,7 @@ async def generate_chat_completion(
...
@@ -809,7 +812,7 @@ async def generate_chat_completion(
headers
=
dict
(
r
.
headers
),
headers
=
dict
(
r
.
headers
),
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
e
raise
e
try
:
try
:
...
@@ -863,7 +866,7 @@ async def generate_openai_chat_completion(
...
@@ -863,7 +866,7 @@ async def generate_openai_chat_completion(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -886,7 +889,7 @@ async def generate_openai_chat_completion(
...
@@ -886,7 +889,7 @@ async def generate_openai_chat_completion(
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
@@ -1184,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
...
@@ -1184,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
...
backend/apps/openai/main.py
View file @
ffa16821
...
@@ -6,6 +6,7 @@ import requests
...
@@ -6,6 +6,7 @@ import requests
import
aiohttp
import
aiohttp
import
asyncio
import
asyncio
import
json
import
json
import
logging
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
...
@@ -19,6 +20,7 @@ from utils.utils import (
...
@@ -19,6 +20,7 @@ from utils.utils import (
get_admin_user
,
get_admin_user
,
)
)
from
config
import
(
from
config
import
(
SRC_LOG_LEVELS
,
OPENAI_API_BASE_URLS
,
OPENAI_API_BASE_URLS
,
OPENAI_API_KEYS
,
OPENAI_API_KEYS
,
CACHE_DIR
,
CACHE_DIR
,
...
@@ -31,6 +33,9 @@ from typing import List, Optional
...
@@ -31,6 +33,9 @@ from typing import List, Optional
import
hashlib
import
hashlib
from
pathlib
import
Path
from
pathlib
import
Path
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"OPENAI"
])
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
CORSMiddleware
,
CORSMiddleware
,
...
@@ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
...
@@ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return
FileResponse
(
file_path
)
return
FileResponse
(
file_path
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -160,7 +165,7 @@ async def fetch_url(url, key):
...
@@ -160,7 +165,7 @@ async def fetch_url(url, key):
return
await
response
.
json
()
return
await
response
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
# Handle connection error here
# Handle connection error here
print
(
f
"Connection error:
{
e
}
"
)
log
.
error
(
f
"Connection error:
{
e
}
"
)
return
None
return
None
...
@@ -182,7 +187,7 @@ def merge_models_lists(model_lists):
...
@@ -182,7 +187,7 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
async
def
get_all_models
():
print
(
"get_all_models"
)
log
.
info
(
"get_all_models
()
"
)
if
len
(
app
.
state
.
OPENAI_API_KEYS
)
==
1
and
app
.
state
.
OPENAI_API_KEYS
[
0
]
==
""
:
if
len
(
app
.
state
.
OPENAI_API_KEYS
)
==
1
and
app
.
state
.
OPENAI_API_KEYS
[
0
]
==
""
:
models
=
{
"data"
:
[]}
models
=
{
"data"
:
[]}
...
@@ -208,7 +213,7 @@ async def get_all_models():
...
@@ -208,7 +213,7 @@ async def get_all_models():
)
)
}
}
print
(
models
)
log
.
info
(
f
"models:
{
models
}
"
)
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
return
models
return
models
...
@@ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
...
@@ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return
response_data
return
response_data
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if
body
.
get
(
"model"
)
==
"gpt-4-vision-preview"
:
if
body
.
get
(
"model"
)
==
"gpt-4-vision-preview"
:
if
"max_tokens"
not
in
body
:
if
"max_tokens"
not
in
body
:
body
[
"max_tokens"
]
=
4000
body
[
"max_tokens"
]
=
4000
print
(
"Modified body_dict:"
,
body
)
log
.
debug
(
"Modified body_dict:"
,
body
)
# Fix for ChatGPT calls failing because the num_ctx key is in body
# Fix for ChatGPT calls failing because the num_ctx key is in body
if
"num_ctx"
in
body
:
if
"num_ctx"
in
body
:
...
@@ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
# Convert the modified body back to JSON
# Convert the modified body back to JSON
body
=
json
.
dumps
(
body
)
body
=
json
.
dumps
(
body
)
except
json
.
JSONDecodeError
as
e
:
except
json
.
JSONDecodeError
as
e
:
print
(
"Error loading request body into a dictionary:"
,
e
)
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
url
=
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]
url
=
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
OPENAI_API_KEYS
[
idx
]
key
=
app
.
state
.
OPENAI_API_KEYS
[
idx
]
...
@@ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
response_data
=
r
.
json
()
response_data
=
r
.
json
()
return
response_data
return
response_data
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
...
backend/apps/rag/main.py
View file @
ffa16821
...
@@ -8,7 +8,7 @@ from fastapi import (
...
@@ -8,7 +8,7 @@ from fastapi import (
Form
,
Form
,
)
)
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
import
os
,
shutil
import
os
,
shutil
,
logging
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
from
typing
import
List
...
@@ -54,6 +54,7 @@ from utils.misc import (
...
@@ -54,6 +54,7 @@ from utils.misc import (
)
)
from
utils.utils
import
get_current_user
,
get_admin_user
from
utils.utils
import
get_current_user
,
get_admin_user
from
config
import
(
from
config
import
(
SRC_LOG_LEVELS
,
UPLOAD_DIR
,
UPLOAD_DIR
,
DOCS_DIR
,
DOCS_DIR
,
RAG_EMBEDDING_MODEL
,
RAG_EMBEDDING_MODEL
,
...
@@ -66,6 +67,9 @@ from config import (
...
@@ -66,6 +67,9 @@ from config import (
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
#
#
# if RAG_EMBEDDING_MODEL:
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# sentence_transformer_ef = SentenceTransformer(
...
@@ -110,7 +114,6 @@ class CollectionNameForm(BaseModel):
...
@@ -110,7 +114,6 @@ class CollectionNameForm(BaseModel):
class
StoreWebForm
(
CollectionNameForm
):
class
StoreWebForm
(
CollectionNameForm
):
url
:
str
url
:
str
@
app
.
get
(
"/"
)
@
app
.
get
(
"/"
)
async
def
get_status
():
async
def
get_status
():
return
{
return
{
...
@@ -241,7 +244,7 @@ def query_doc_handler(
...
@@ -241,7 +244,7 @@ def query_doc_handler(
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
...
@@ -285,7 +288,7 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
...
@@ -285,7 +288,7 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
"filename"
:
form_data
.
url
,
"filename"
:
form_data
.
url
,
}
}
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
...
@@ -433,7 +436,7 @@ def store_doc(
...
@@ -433,7 +436,7 @@ def store_doc(
):
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
print
(
file
.
content_type
)
log
.
info
(
f
"file.content_type:
{
file
.
content_type
}
"
)
try
:
try
:
filename
=
file
.
filename
filename
=
file
.
filename
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
filename
}
"
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
filename
}
"
...
@@ -464,7 +467,7 @@ def store_doc(
...
@@ -464,7 +467,7 @@ def store_doc(
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
if
"No pandoc was found"
in
str
(
e
):
if
"No pandoc was found"
in
str
(
e
):
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
...
@@ -560,7 +563,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
...
@@ -560,7 +563,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
True
return
True
...
@@ -581,11 +584,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
...
@@ -581,11 +584,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
elif
os
.
path
.
isdir
(
file_path
):
elif
os
.
path
.
isdir
(
file_path
):
shutil
.
rmtree
(
file_path
)
shutil
.
rmtree
(
file_path
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"Failed to delete %s. Reason: %s"
%
(
file_path
,
e
))
log
.
error
(
"Failed to delete %s. Reason: %s"
%
(
file_path
,
e
))
try
:
try
:
CHROMA_CLIENT
.
reset
()
CHROMA_CLIENT
.
reset
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
True
return
True
backend/apps/rag/utils.py
View file @
ffa16821
import
re
import
re
import
logging
from
typing
import
List
from
typing
import
List
from
config
import
CHROMA_CLIENT
from
config
import
SRC_LOG_LEVELS
,
CHROMA_CLIENT
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
def
query_doc
(
collection_name
:
str
,
query
:
str
,
k
:
int
,
embedding_function
):
def
query_doc
(
collection_name
:
str
,
query
:
str
,
k
:
int
,
embedding_function
):
...
@@ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str):
...
@@ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str):
def
rag_messages
(
docs
,
messages
,
template
,
k
,
embedding_function
):
def
rag_messages
(
docs
,
messages
,
template
,
k
,
embedding_function
):
print
(
docs
)
log
.
debug
(
f
"docs:
{
docs
}
"
)
last_user_message_idx
=
None
last_user_message_idx
=
None
for
i
in
range
(
len
(
messages
)
-
1
,
-
1
,
-
1
):
for
i
in
range
(
len
(
messages
)
-
1
,
-
1
,
-
1
):
...
@@ -147,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
...
@@ -147,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
embedding_function
=
embedding_function
,
embedding_function
=
embedding_function
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
context
=
None
context
=
None
relevant_contexts
.
append
(
context
)
relevant_contexts
.
append
(
context
)
...
...
backend/apps/web/internal/db.py
View file @
ffa16821
from
peewee
import
*
from
peewee
import
*
from
config
import
DATA_DIR
from
config
import
SRC_LOG_LEVELS
,
DATA_DIR
import
os
import
os
import
logging
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
# Check if the file exists
# Check if the file exists
if
os
.
path
.
exists
(
f
"
{
DATA_DIR
}
/ollama.db"
):
if
os
.
path
.
exists
(
f
"
{
DATA_DIR
}
/ollama.db"
):
# Rename the file
# Rename the file
os
.
rename
(
f
"
{
DATA_DIR
}
/ollama.db"
,
f
"
{
DATA_DIR
}
/webui.db"
)
os
.
rename
(
f
"
{
DATA_DIR
}
/ollama.db"
,
f
"
{
DATA_DIR
}
/webui.db"
)
print
(
"File renamed successfully."
)
log
.
info
(
"File renamed successfully."
)
else
:
else
:
pass
pass
...
...
backend/apps/web/models/auths.py
View file @
ffa16821
...
@@ -2,6 +2,7 @@ from pydantic import BaseModel
...
@@ -2,6 +2,7 @@ from pydantic import BaseModel
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
uuid
import
uuid
import
logging
from
peewee
import
*
from
peewee
import
*
from
apps.web.models.users
import
UserModel
,
Users
from
apps.web.models.users
import
UserModel
,
Users
...
@@ -9,6 +10,10 @@ from utils.utils import verify_password
...
@@ -9,6 +10,10 @@ from utils.utils import verify_password
from
apps.web.internal.db
import
DB
from
apps.web.internal.db
import
DB
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
####################
####################
# DB MODEL
# DB MODEL
####################
####################
...
@@ -86,7 +91,7 @@ class AuthsTable:
...
@@ -86,7 +91,7 @@ class AuthsTable:
def
insert_new_auth
(
def
insert_new_auth
(
self
,
email
:
str
,
password
:
str
,
name
:
str
,
role
:
str
=
"pending"
self
,
email
:
str
,
password
:
str
,
name
:
str
,
role
:
str
=
"pending"
)
->
Optional
[
UserModel
]:
)
->
Optional
[
UserModel
]:
print
(
"insert_new_auth"
)
log
.
info
(
"insert_new_auth"
)
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
...
@@ -103,7 +108,7 @@ class AuthsTable:
...
@@ -103,7 +108,7 @@ class AuthsTable:
return
None
return
None
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
print
(
"authenticate_user
"
,
email
)
log
.
info
(
f
"authenticate_user
:
{
email
}
"
)
try
:
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
if
auth
:
if
auth
:
...
...
backend/apps/web/models/documents.py
View file @
ffa16821
...
@@ -3,6 +3,7 @@ from peewee import *
...
@@ -3,6 +3,7 @@ from peewee import *
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
logging
from
utils.utils
import
decode_token
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
utils.misc
import
get_gravatar_url
...
@@ -11,6 +12,10 @@ from apps.web.internal.db import DB
...
@@ -11,6 +12,10 @@ from apps.web.internal.db import DB
import
json
import
json
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
####################
####################
# Documents DB Schema
# Documents DB Schema
####################
####################
...
@@ -118,7 +123,7 @@ class DocumentsTable:
...
@@ -118,7 +123,7 @@ class DocumentsTable:
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
return
DocumentModel
(
**
model_to_dict
(
doc
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
None
return
None
def
update_doc_content_by_name
(
def
update_doc_content_by_name
(
...
@@ -138,7 +143,7 @@ class DocumentsTable:
...
@@ -138,7 +143,7 @@ class DocumentsTable:
doc
=
Document
.
get
(
Document
.
name
==
name
)
doc
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
return
DocumentModel
(
**
model_to_dict
(
doc
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
None
return
None
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
...
...
backend/apps/web/models/tags.py
View file @
ffa16821
...
@@ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict
...
@@ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict
import
json
import
json
import
uuid
import
uuid
import
time
import
time
import
logging
from
apps.web.internal.db
import
DB
from
apps.web.internal.db
import
DB
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
####################
####################
# Tag DB Schema
# Tag DB Schema
####################
####################
...
@@ -173,7 +178,7 @@ class TagTable:
...
@@ -173,7 +178,7 @@ class TagTable:
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
print
(
res
)
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
if
tag_count
==
0
:
...
@@ -185,7 +190,7 @@ class TagTable:
...
@@ -185,7 +190,7 @@ class TagTable:
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"delete_tag
"
,
e
)
log
.
error
(
f
"delete_tag
:
{
e
}
"
)
return
False
return
False
def
delete_tag_by_tag_name_and_chat_id_and_user_id
(
def
delete_tag_by_tag_name_and_chat_id_and_user_id
(
...
@@ -198,7 +203,7 @@ class TagTable:
...
@@ -198,7 +203,7 @@ class TagTable:
&
(
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
print
(
res
)
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
if
tag_count
==
0
:
...
@@ -210,7 +215,7 @@ class TagTable:
...
@@ -210,7 +215,7 @@ class TagTable:
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"delete_tag
"
,
e
)
log
.
error
(
f
"delete_tag
:
{
e
}
"
)
return
False
return
False
def
delete_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
...
...
backend/apps/web/routers/chats.py
View file @
ffa16821
...
@@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
...
@@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
from
fastapi
import
APIRouter
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
import
json
import
json
import
logging
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
apps.web.models.chats
import
(
from
apps.web.models.chats
import
(
...
@@ -27,6 +28,10 @@ from apps.web.models.tags import (
...
@@ -27,6 +28,10 @@ from apps.web.models.tags import (
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
router
=
APIRouter
()
router
=
APIRouter
()
############################
############################
...
@@ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
...
@@ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat
=
Chats
.
insert_new_chat
(
user
.
id
,
form_data
)
chat
=
Chats
.
insert_new_chat
(
user
.
id
,
form_data
)
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
)
)
...
@@ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)):
...
@@ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)):
tags
=
Tags
.
get_tags_by_user_id
(
user
.
id
)
tags
=
Tags
.
get_tags_by_user_id
(
user
.
id
)
return
tags
return
tags
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
)
)
...
...
backend/apps/web/routers/users.py
View file @
ffa16821
...
@@ -7,6 +7,7 @@ from fastapi import APIRouter
...
@@ -7,6 +7,7 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
import
time
import
time
import
uuid
import
uuid
import
logging
from
apps.web.models.users
import
UserModel
,
UserUpdateForm
,
UserRoleUpdateForm
,
Users
from
apps.web.models.users
import
UserModel
,
UserUpdateForm
,
UserRoleUpdateForm
,
Users
from
apps.web.models.auths
import
Auths
from
apps.web.models.auths
import
Auths
...
@@ -14,6 +15,10 @@ from apps.web.models.auths import Auths
...
@@ -14,6 +15,10 @@ from apps.web.models.auths import Auths
from
utils.utils
import
get_current_user
,
get_password_hash
,
get_admin_user
from
utils.utils
import
get_current_user
,
get_password_hash
,
get_admin_user
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
router
=
APIRouter
()
router
=
APIRouter
()
############################
############################
...
@@ -83,7 +88,7 @@ async def update_user_by_id(
...
@@ -83,7 +88,7 @@ async def update_user_by_id(
if
form_data
.
password
:
if
form_data
.
password
:
hashed
=
get_password_hash
(
form_data
.
password
)
hashed
=
get_password_hash
(
form_data
.
password
)
print
(
hashed
)
log
.
debug
(
f
"hashed:
{
hashed
}
"
)
Auths
.
update_user_password_by_id
(
user_id
,
hashed
)
Auths
.
update_user_password_by_id
(
user_id
,
hashed
)
Auths
.
update_email_by_id
(
user_id
,
form_data
.
email
.
lower
())
Auths
.
update_email_by_id
(
user_id
,
form_data
.
email
.
lower
())
...
...
backend/config.py
View file @
ffa16821
import
os
import
os
import
sys
import
logging
import
chromadb
import
chromadb
from
chromadb
import
Settings
from
chromadb
import
Settings
from
base64
import
b64encode
from
base64
import
b64encode
...
@@ -21,7 +23,7 @@ try:
...
@@ -21,7 +23,7 @@ try:
load_dotenv
(
find_dotenv
(
"../.env"
))
load_dotenv
(
find_dotenv
(
"../.env"
))
except
ImportError
:
except
ImportError
:
pr
in
t
(
"dotenv not installed, skipping..."
)
log
.
warn
in
g
(
"dotenv not installed, skipping..."
)
WEBUI_NAME
=
"Open WebUI"
WEBUI_NAME
=
"Open WebUI"
shutil
.
copyfile
(
"../build/favicon.png"
,
"./static/favicon.png"
)
shutil
.
copyfile
(
"../build/favicon.png"
,
"./static/favicon.png"
)
...
@@ -100,6 +102,34 @@ for version in soup.find_all("h2"):
...
@@ -100,6 +102,34 @@ for version in soup.find_all("h2"):
CHANGELOG
=
changelog_json
CHANGELOG
=
changelog_json
####################################
# LOGGING
####################################
log_levels
=
[
"CRITICAL"
,
"ERROR"
,
"WARNING"
,
"INFO"
,
"DEBUG"
]
GLOBAL_LOG_LEVEL
=
os
.
environ
.
get
(
"GLOBAL_LOG_LEVEL"
,
""
).
upper
()
if
GLOBAL_LOG_LEVEL
in
log_levels
:
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
GLOBAL_LOG_LEVEL
,
force
=
True
)
else
:
GLOBAL_LOG_LEVEL
=
"INFO"
log
=
logging
.
getLogger
(
__name__
)
log
.
info
(
f
"GLOBAL_LOG_LEVEL:
{
GLOBAL_LOG_LEVEL
}
"
)
log_sources
=
[
"AUDIO"
,
"CONFIG"
,
"DB"
,
"IMAGES"
,
"LITELLM"
,
"MAIN"
,
"MODELS"
,
"OLLAMA"
,
"OPENAI"
,
"RAG"
]
SRC_LOG_LEVELS
=
{}
for
source
in
log_sources
:
log_env_var
=
source
+
"_LOG_LEVEL"
SRC_LOG_LEVELS
[
source
]
=
os
.
environ
.
get
(
log_env_var
,
""
).
upper
()
if
SRC_LOG_LEVELS
[
source
]
not
in
log_levels
:
SRC_LOG_LEVELS
[
source
]
=
GLOBAL_LOG_LEVEL
log
.
info
(
f
"
{
log_env_var
}
:
{
SRC_LOG_LEVELS
[
source
]
}
"
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"CONFIG"
])
####################################
####################################
# CUSTOM_NAME
# CUSTOM_NAME
####################################
####################################
...
@@ -125,7 +155,7 @@ if CUSTOM_NAME:
...
@@ -125,7 +155,7 @@ if CUSTOM_NAME:
WEBUI_NAME
=
data
[
"name"
]
WEBUI_NAME
=
data
[
"name"
]
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
pass
pass
...
@@ -194,9 +224,9 @@ def create_config_file(file_path):
...
@@ -194,9 +224,9 @@ def create_config_file(file_path):
LITELLM_CONFIG_PATH
=
f
"
{
DATA_DIR
}
/litellm/config.yaml"
LITELLM_CONFIG_PATH
=
f
"
{
DATA_DIR
}
/litellm/config.yaml"
if
not
os
.
path
.
exists
(
LITELLM_CONFIG_PATH
):
if
not
os
.
path
.
exists
(
LITELLM_CONFIG_PATH
):
print
(
"Config file doesn't exist. Creating..."
)
log
.
info
(
"Config file doesn't exist. Creating..."
)
create_config_file
(
LITELLM_CONFIG_PATH
)
create_config_file
(
LITELLM_CONFIG_PATH
)
print
(
"Config file created successfully."
)
log
.
info
(
"Config file created successfully."
)
####################################
####################################
...
...
backend/main.py
View file @
ffa16821
...
@@ -4,6 +4,7 @@ import markdown
...
@@ -4,6 +4,7 @@ import markdown
import
time
import
time
import
os
import
os
import
sys
import
sys
import
logging
import
requests
import
requests
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
...
@@ -38,10 +39,15 @@ from config import (
...
@@ -38,10 +39,15 @@ from config import (
FRONTEND_BUILD_DIR
,
FRONTEND_BUILD_DIR
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_LIST
,
MODEL_FILTER_LIST
,
GLOBAL_LOG_LEVEL
,
SRC_LOG_LEVELS
,
WEBHOOK_URL
,
WEBHOOK_URL
,
)
)
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
GLOBAL_LOG_LEVEL
)
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MAIN"
])
class
SPAStaticFiles
(
StaticFiles
):
class
SPAStaticFiles
(
StaticFiles
):
async
def
get_response
(
self
,
path
:
str
,
scope
):
async
def
get_response
(
self
,
path
:
str
,
scope
):
...
@@ -70,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
...
@@ -70,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
if
request
.
method
==
"POST"
and
(
if
request
.
method
==
"POST"
and
(
"/api/chat"
in
request
.
url
.
path
or
"/chat/completions"
in
request
.
url
.
path
"/api/chat"
in
request
.
url
.
path
or
"/chat/completions"
in
request
.
url
.
path
):
):
print
(
request
.
url
.
path
)
log
.
debug
(
f
"request.url.path:
{
request
.
url
.
path
}
"
)
# Read the original request body
# Read the original request body
body
=
await
request
.
body
()
body
=
await
request
.
body
()
...
@@ -93,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
...
@@ -93,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
)
)
del
data
[
"docs"
]
del
data
[
"docs"
]
print
(
data
[
"
messages"
]
)
log
.
debug
(
f
"data['messages']:
{
data
[
'
messages
'
]
}
"
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment