chenpangpang / open-webui · Commit ac294a74

Unverified commit, authored Mar 24, 2024 by Timothy Jaeryang Baek and committed by GitHub on Mar 24, 2024.

Merge pull request #1277 from open-webui/dev

0.1.115

Parents: 2fa94956, 4c959891 · Changes: 53
Showing 20 changed files with 845 additions and 292 deletions (+845 −292):

| File | Additions | Deletions |
| --- | --- | --- |
| .gitignore | +1 | −1 |
| CHANGELOG.md | +18 | −0 |
| backend/apps/audio/main.py | +8 | −4 |
| backend/apps/images/main.py | +107 | −13 |
| backend/apps/images/utils/comfyui.py | +228 | −0 |
| backend/apps/litellm/main.py | +7 | −2 |
| backend/apps/ollama/main.py | +277 | −37 |
| backend/apps/openai/main.py | +13 | −8 |
| backend/apps/rag/main.py | +93 | −42 |
| backend/apps/rag/utils.py | +9 | −3 |
| backend/apps/web/internal/db.py | +5 | −2 |
| backend/apps/web/models/auths.py | +7 | −2 |
| backend/apps/web/models/chats.py | +0 | −14 |
| backend/apps/web/models/documents.py | +7 | −2 |
| backend/apps/web/models/tags.py | +9 | −4 |
| backend/apps/web/routers/chats.py | +7 | −2 |
| backend/apps/web/routers/users.py | +6 | −1 |
| backend/apps/web/routers/utils.py | +0 | −149 |
| backend/config.py | +35 | −4 |
| backend/main.py | +8 | −2 |
.gitignore — view file @ ac294a74

```diff
@@ -166,7 +166,7 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/

 # Logs
 logs
```
CHANGELOG.md — view file @ ac294a74

```diff
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.1.115] - 2024-03-24
+
+### Added
+
+- **🔍 Custom Model Selector**: Easily find and select custom models with the new search filter feature.
+- **🛑 Cancel Model Download**: Added the ability to cancel model downloads.
+- **🎨 Image Generation ComfyUI**: Image generation now supports ComfyUI.
+- **🌟 Updated Light Theme**: Updated the light theme for a fresh look.
+- **🌍 Additional Language Support**: Now supporting Bulgarian, Italian, Portuguese, Japanese, and Dutch.
+
+### Fixed
+
+- **🔧 Fixed Broken Experimental GGUF Upload**: Resolved issues with experimental GGUF upload functionality.
+
+### Changed
+
+- **🔄 Vector Storage Reset Button**: Moved the reset vector storage button to document settings.
+
 ## [0.1.114] - 2024-03-20
 
 ### Added
```
backend/apps/audio/main.py — view file @ ac294a74

```diff
 import os
+import logging
 
 from fastapi import (
     FastAPI,
     Request,
@@ -21,7 +22,10 @@ from utils.utils import (
 )
 from utils.misc import calculate_sha256
 
-from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
+from config import SRC_LOG_LEVELS, CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["AUDIO"])
 
 app = FastAPI()
 app.add_middleware(
@@ -38,7 +42,7 @@ def transcribe(
     file: UploadFile = File(...),
     user=Depends(get_current_user),
 ):
-    print(file.content_type)
+    log.info(f"file.content_type: {file.content_type}")
 
     if file.content_type not in ["audio/mpeg", "audio/wav"]:
         raise HTTPException(
@@ -62,7 +66,7 @@ def transcribe(
         )
 
         segments, info = model.transcribe(file_path, beam_size=5)
-        print(
+        log.info(
             "Detected language '%s' with probability %f"
             % (info.language, info.language_probability)
         )
@@ -72,7 +76,7 @@ def transcribe(
         return {"text": transcript.strip()}
     except Exception as e:
-        print(e)
+        log.exception(e)
 
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
```
backend/apps/images/main.py — view file @ ac294a74

```diff
@@ -18,6 +18,8 @@ from utils.utils import (
     get_current_user,
     get_admin_user,
 )
+from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
+
 from utils.misc import calculate_sha256
 from typing import Optional
 from pydantic import BaseModel
@@ -25,9 +27,13 @@ from pathlib import Path
 import uuid
 import base64
 import json
+import logging
 
-from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
+from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["IMAGES"])
 
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
 IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@@ -49,6 +55,8 @@ app.state.MODEL = ""
 
 app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+
 app.state.IMAGE_SIZE = "512x512"
 app.state.IMAGE_STEPS = 50
@@ -71,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
 
-class UrlUpdateForm(BaseModel):
-    url: str
+class EngineUrlUpdateForm(BaseModel):
+    AUTOMATIC1111_BASE_URL: Optional[str] = None
+    COMFYUI_BASE_URL: Optional[str] = None
 
 @app.get("/url")
-async def get_automatic1111_url(user=Depends(get_admin_user)):
-    return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
+async def get_engine_url(user=Depends(get_admin_user)):
+    return {
+        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
+    }
 
 @app.post("/url/update")
-async def update_automatic1111_url(
-    form_data: UrlUpdateForm, user=Depends(get_admin_user)
-):
-    if form_data.url == "":
+async def update_engine_url(
+    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
+):
+    if form_data.AUTOMATIC1111_BASE_URL == None:
         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
-        url = form_data.url.strip("/")
+        url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
         try:
             r = requests.head(url)
             app.state.AUTOMATIC1111_BASE_URL = url
         except Exception as e:
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
+    if form_data.COMFYUI_BASE_URL == None:
+        app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+    else:
+        url = form_data.COMFYUI_BASE_URL.strip("/")
+        try:
+            r = requests.head(url)
+            app.state.COMFYUI_BASE_URL = url
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
+
     return {
         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
         "status": True,
     }
@@ -186,6 +210,18 @@ def get_models(user=Depends(get_current_user)):
             {"id": "dall-e-2", "name": "DALL·E 2"},
             {"id": "dall-e-3", "name": "DALL·E 3"},
         ]
+    elif app.state.ENGINE == "comfyui":
+        r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
+        info = r.json()
+
+        return list(
+            map(
+                lambda model: {"id": model, "name": model},
+                info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
+            )
+        )
     else:
         r = requests.get(
             url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
@@ -207,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)):
     try:
         if app.state.ENGINE == "openai":
             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
+        elif app.state.ENGINE == "comfyui":
+            return {"model": app.state.MODEL if app.state.MODEL else ""}
         else:
             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
             options = r.json()
@@ -221,10 +259,12 @@ class UpdateModelForm(BaseModel):
 def set_model_handler(model: str):
     if app.state.ENGINE == "openai":
         app.state.MODEL = model
         return app.state.MODEL
+    if app.state.ENGINE == "comfyui":
+        app.state.MODEL = model
+        return app.state.MODEL
     else:
         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
         options = r.json()
@@ -266,6 +306,23 @@ def save_b64_image(b64_str):
         with open(file_path, "wb") as f:
             f.write(img_data)
         return image_id
     except Exception as e:
         log.error(f"Error saving image: {e}")
         return None
 
+def save_url_image(url):
+    image_id = str(uuid.uuid4())
+    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
+
+    try:
+        r = requests.get(url)
+        r.raise_for_status()
+
+        with open(file_path, "wb") as image_file:
+            image_file.write(r.content)
+
+        return image_id
+    except Exception as e:
+        print(f"Error saving image: {e}")
@@ -278,6 +335,8 @@ def generate_image(
     user=Depends(get_current_user),
 ):
+    width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+
     r = None
     try:
         if app.state.ENGINE == "openai":
@@ -315,12 +374,47 @@ def generate_image(
             return images
 
+        elif app.state.ENGINE == "comfyui":
+
+            data = {
+                "prompt": form_data.prompt,
+                "width": width,
+                "height": height,
+                "n": form_data.n,
+            }
+
+            if app.state.IMAGE_STEPS != None:
+                data["steps"] = app.state.IMAGE_STEPS
+
+            if form_data.negative_prompt != None:
+                data["negative_prompt"] = form_data.negative_prompt
+
+            data = ImageGenerationPayload(**data)
+
+            res = comfyui_generate_image(
+                app.state.MODEL,
+                data,
+                user.id,
+                app.state.COMFYUI_BASE_URL,
+            )
+            print(res)
+
+            images = []
+
+            for image in res["data"]:
+                image_id = save_url_image(image["url"])
+                images.append({"url": f"/cache/image/generations/{image_id}.png"})
+                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
+
+                with open(file_body_path, "w") as f:
+                    json.dump(data.model_dump(exclude_none=True), f)
+
+            print(images)
+            return images
         else:
             if form_data.model:
                 set_model_handler(form_data.model)
 
-            width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
-
             data = {
                 "prompt": form_data.prompt,
                 "batch_size": form_data.n,
@@ -341,7 +435,7 @@ def generate_image(
             res = r.json()
 
-            print(res)
+            log.debug(f"res: {res}")
 
             images = []
```
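The reworked `/url/update` route now takes both engine base URLs in a single payload, and either field may be omitted (`None`) to reset it to the value from the environment. As a hedged sketch of what an admin call might look like (the host, token, and the `/images/api/v1` mount path are assumptions for illustration, not values from this commit):

```python
import requests

BASE = "http://localhost:8080/images/api/v1"  # assumed mount point of the images app
headers = {"Authorization": "Bearer <admin-token>"}  # placeholder token

# Point image generation at a local ComfyUI server; leaving
# AUTOMATIC1111_BASE_URL out of the payload resets it to the configured default.
r = requests.post(
    f"{BASE}/url/update",
    headers=headers,
    json={"COMFYUI_BASE_URL": "http://localhost:8188"},
)
print(r.json())
# expected shape per the handler above:
# {"AUTOMATIC1111_BASE_URL": "...", "COMFYUI_BASE_URL": "http://localhost:8188", "status": True}
```

Note that the handler probes each submitted URL with `requests.head()` before accepting it, so an unreachable server is rejected with a 400 rather than stored.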
backend/apps/images/utils/comfyui.py — new file (0 → 100644), view file @ ac294a74

```python
import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random

from pydantic import BaseModel

from typing import Optional

COMFYUI_DEFAULT_PROMPT = """
{
  "3": {
    "inputs": {
      "seed": 0,
      "steps": 20,
      "cfg": 8,
      "sampler_name": "euler",
      "scheduler": "normal",
      "denoise": 1,
      "model": ["4", 0],
      "positive": ["6", 0],
      "negative": ["7", 0],
      "latent_image": ["5", 0]
    },
    "class_type": "KSampler",
    "_meta": {"title": "KSampler"}
  },
  "4": {
    "inputs": {"ckpt_name": "model.safetensors"},
    "class_type": "CheckpointLoaderSimple",
    "_meta": {"title": "Load Checkpoint"}
  },
  "5": {
    "inputs": {"width": 512, "height": 512, "batch_size": 1},
    "class_type": "EmptyLatentImage",
    "_meta": {"title": "Empty Latent Image"}
  },
  "6": {
    "inputs": {"text": "Prompt", "clip": ["4", 1]},
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "7": {
    "inputs": {"text": "Negative Prompt", "clip": ["4", 1]},
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "8": {
    "inputs": {"samples": ["3", 0], "vae": ["4", 2]},
    "class_type": "VAEDecode",
    "_meta": {"title": "VAE Decode"}
  },
  "9": {
    "inputs": {"filename_prefix": "ComfyUI", "images": ["8", 0]},
    "class_type": "SaveImage",
    "_meta": {"title": "Save Image"}
  }
}
"""


def queue_prompt(prompt, client_id, base_url):
    print("queue_prompt")
    p = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(p).encode("utf-8")
    req = urllib.request.Request(f"{base_url}/prompt", data=data)
    return json.loads(urllib.request.urlopen(req).read())


def get_image(filename, subfolder, folder_type, base_url):
    print("get_image")
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
        return response.read()


def get_image_url(filename, subfolder, folder_type, base_url):
    print("get_image")
    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    url_values = urllib.parse.urlencode(data)
    return f"{base_url}/view?{url_values}"


def get_history(prompt_id, base_url):
    print("get_history")
    with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
        return json.loads(response.read())


def get_images(ws, prompt, client_id, base_url):
    prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
    output_images = []
    while True:
        out = ws.recv()
        if isinstance(out, str):
            message = json.loads(out)
            if message["type"] == "executing":
                data = message["data"]
                if data["node"] is None and data["prompt_id"] == prompt_id:
                    break  # Execution is done
        else:
            continue  # previews are binary data

    history = get_history(prompt_id, base_url)[prompt_id]
    for o in history["outputs"]:
        for node_id in history["outputs"]:
            node_output = history["outputs"][node_id]
            if "images" in node_output:
                for image in node_output["images"]:
                    url = get_image_url(
                        image["filename"], image["subfolder"], image["type"], base_url
                    )
                    output_images.append({"url": url})

    return {"data": output_images}


class ImageGenerationPayload(BaseModel):
    prompt: str
    negative_prompt: Optional[str] = ""
    steps: Optional[int] = None
    seed: Optional[int] = None
    width: int
    height: int
    n: int = 1


def comfyui_generate_image(
    model: str, payload: ImageGenerationPayload, client_id, base_url
):
    host = base_url.replace("http://", "").replace("https://", "")

    comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)

    comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
    comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
    comfyui_prompt["5"]["inputs"]["width"] = payload.width
    comfyui_prompt["5"]["inputs"]["height"] = payload.height

    # set the text prompt for our positive CLIPTextEncode
    comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
    comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt

    if payload.steps:
        comfyui_prompt["3"]["inputs"]["steps"] = payload.steps

    comfyui_prompt["3"]["inputs"]["seed"] = (
        payload.seed if payload.seed else random.randint(0, 18446744073709551614)
    )

    try:
        ws = websocket.WebSocket()
        ws.connect(f"ws://{host}/ws?clientId={client_id}")
        print("WebSocket connection established.")
    except Exception as e:
        print(f"Failed to connect to WebSocket server: {e}")
        return None

    try:
        images = get_images(ws, comfyui_prompt, client_id, base_url)
    except Exception as e:
        print(f"Error while receiving images: {e}")
        images = None

    ws.close()

    return images
```
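Taken together, the module queues the default KSampler workflow over HTTP (`/prompt`), waits on the `/ws` WebSocket for the `executing` message that signals completion, then resolves image URLs from `/history/{prompt_id}`. A minimal, hedged sketch of how the images app above drives it end to end; the checkpoint name and ComfyUI URL are placeholders, not values from this commit:

```python
import uuid

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image

payload = ImageGenerationPayload(
    prompt="a watercolor lighthouse at dusk",
    negative_prompt="blurry, low quality",
    width=512,
    height=512,
    n=1,
    steps=20,
)

# client_id only scopes the WebSocket session; any unique string works.
res = comfyui_generate_image(
    "model.safetensors",       # a checkpoint name as listed by /object_info
    payload,
    str(uuid.uuid4()),
    "http://localhost:8188",   # placeholder COMFYUI_BASE_URL
)

if res:
    for image in res["data"]:
        print(image["url"])    # e.g. http://localhost:8188/view?filename=...
```

Note the seed handling: unless `seed` is set, a random 64-bit value is drawn per call, so identical payloads produce different images.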
backend/apps/litellm/main.py — view file @ ac294a74

```diff
+import logging
+
 from litellm.proxy.proxy_server import ProxyConfig, initialize
 from litellm.proxy.proxy_server import app
 
@@ -9,7 +11,10 @@ from starlette.responses import StreamingResponse
 import json
 
 from utils.utils import get_http_authorization_cred, get_current_user
-from config import ENV
+from config import SRC_LOG_LEVELS, ENV
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["LITELLM"])
 
 from config import (
@@ -49,7 +54,7 @@ async def auth_middleware(request: Request, call_next):
     try:
         user = get_current_user(get_http_authorization_cred(auth_header))
-        print(user)
+        log.debug(f"user: {user}")
         request.state.user = user
     except Exception as e:
         return JSONResponse(status_code=400, content={"detail": str(e)})
```
backend/apps/ollama/main.py — view file @ ac294a74

```diff
-from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
+from fastapi import (
+    FastAPI,
+    Request,
+    Response,
+    HTTPException,
+    Depends,
+    status,
+    UploadFile,
+    File,
+    BackgroundTasks,
+)
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 
 from pydantic import BaseModel, ConfigDict
 
+import os
+import copy
 import random
 import requests
 import json
 import uuid
 import aiohttp
 import asyncio
+import logging
+
+from urllib.parse import urlparse
+from typing import Optional, List, Union
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
-from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
 
-from typing import Optional, List, Union
+from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR
+from utils.misc import calculate_sha256
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 
 app = FastAPI()
 app.add_middleware(
@@ -69,7 +88,7 @@ class UrlUpdateForm(BaseModel):
 async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
     app.state.OLLAMA_BASE_URLS = form_data.urls
 
-    print(app.state.OLLAMA_BASE_URLS)
+    log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
     return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
@@ -90,7 +109,7 @@ async def fetch_url(url):
                 return await response.json()
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.error(f"Connection error: {e}")
         return None
@@ -114,7 +133,7 @@ def merge_models_lists(model_lists):
 
 async def get_all_models():
-    print("get_all_models")
+    log.info("get_all_models()")
 
     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
     responses = await asyncio.gather(*tasks)
@@ -155,7 +174,7 @@ async def get_ollama_tags(
         return r.json()
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -201,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
         return r.json()
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -227,18 +246,33 @@ async def pull_model(
     form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
 ):
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     r = None
 
     def get_request():
         nonlocal url
         nonlocal r
+
+        request_id = str(uuid.uuid4())
         try:
+            REQUEST_POOL.append(request_id)
 
             def stream_content():
                 try:
+                    yield json.dumps({"id": request_id, "done": False}) + "\n"
+
                     for chunk in r.iter_content(chunk_size=8192):
-                        yield chunk
+                        if request_id in REQUEST_POOL:
+                            yield chunk
+                        else:
+                            print("User: canceled request")
+                            break
                 finally:
                     if hasattr(r, "close"):
                         r.close()
+                        if request_id in REQUEST_POOL:
+                            REQUEST_POOL.remove(request_id)
 
             r = requests.request(
                 method="POST",
@@ -259,8 +293,9 @@ async def pull_model(
     try:
         return await run_in_threadpool(get_request)
+
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -299,7 +334,7 @@ async def push_model(
     )
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.debug(f"url: {url}")
 
     r = None
@@ -331,7 +366,7 @@ async def push_model(
     try:
         return await run_in_threadpool(get_request)
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -359,9 +394,9 @@ class CreateModelForm(BaseModel):
 async def create_model(
     form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
 ):
-    print(form_data)
+    log.debug(f"form_data: {form_data}")
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     r = None
@@ -383,7 +418,7 @@ async def create_model(
             r.raise_for_status()
 
-            print(r)
+            log.debug(f"r: {r}")
 
             return StreamingResponse(
                 stream_content(),
@@ -396,7 +431,7 @@ async def create_model(
     try:
         return await run_in_threadpool(get_request)
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -434,7 +469,7 @@ async def copy_model(
     )
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     try:
         r = requests.request(
@@ -444,11 +479,11 @@ async def copy_model(
         )
         r.raise_for_status()
 
-        print(r.text)
+        log.debug(f"r.text: {r.text}")
 
         return True
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -481,7 +516,7 @@ async def delete_model(
     )
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     try:
         r = requests.request(
@@ -491,11 +526,11 @@ async def delete_model(
         )
         r.raise_for_status()
 
-        print(r.text)
+        log.debug(f"r.text: {r.text}")
 
         return True
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -521,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)):
     url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     try:
         r = requests.request(
@@ -533,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)):
         return r.json()
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -573,7 +608,7 @@ async def generate_embeddings(
     )
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     try:
         r = requests.request(
@@ -585,7 +620,7 @@ async def generate_embeddings(
         return r.json()
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -633,7 +668,7 @@ async def generate_completion(
     )
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     r = None
@@ -654,7 +689,7 @@ async def generate_completion(
                     if request_id in REQUEST_POOL:
                         yield chunk
                     else:
-                        print("User: canceled request")
+                        log.warning("User: canceled request")
                         break
             finally:
                 if hasattr(r, "close"):
@@ -731,11 +766,11 @@ async def generate_chat_completion(
     )
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     r = None
 
-    print(form_data.model_dump_json(exclude_none=True).encode())
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
 
     def get_request():
         nonlocal form_data
@@ -754,7 +789,7 @@ async def generate_chat_completion(
                     if request_id in REQUEST_POOL:
                         yield chunk
                     else:
-                        print("User: canceled request")
+                        log.warning("User: canceled request")
                        break
             finally:
                 if hasattr(r, "close"):
@@ -777,7 +812,7 @@ async def generate_chat_completion(
             headers=dict(r.headers),
         )
     except Exception as e:
-        print(e)
+        log.exception(e)
         raise e
 
     try:
@@ -831,7 +866,7 @@ async def generate_openai_chat_completion(
     )
 
     url = app.state.OLLAMA_BASE_URLS[url_idx]
-    print(url)
+    log.info(f"url: {url}")
 
     r = None
@@ -854,7 +889,7 @@ async def generate_openai_chat_completion(
                     if request_id in REQUEST_POOL:
                         yield chunk
                     else:
-                        print("User: canceled request")
+                        log.warning("User: canceled request")
                         break
             finally:
                 if hasattr(r, "close"):
@@ -897,6 +932,211 @@ async def generate_openai_chat_completion(
     )
 
+class UrlForm(BaseModel):
+    url: str
+
+
+class UploadBlobForm(BaseModel):
+    filename: str
+
+
+def parse_huggingface_url(hf_url):
+    try:
+        # Parse the URL
+        parsed_url = urlparse(hf_url)
+
+        # Get the path and split it into components
+        path_components = parsed_url.path.split("/")
+
+        # Extract the desired output
+        user_repo = "/".join(path_components[1:3])
+        model_file = path_components[-1]
+
+        return model_file
+    except ValueError:
+        return None
+
+
+async def download_file_stream(
+    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
+):
+    done = False
+
+    if os.path.exists(file_path):
+        current_size = os.path.getsize(file_path)
+    else:
+        current_size = 0
+
+    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
+
+    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
+
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(file_url, headers=headers) as response:
+            total_size = int(response.headers.get("content-length", 0)) + current_size
+
+            with open(file_path, "ab+") as file:
+                async for data in response.content.iter_chunked(chunk_size):
+                    current_size += len(data)
+                    file.write(data)
+
+                    done = current_size == total_size
+                    progress = round((current_size / total_size) * 100, 2)
+
+                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
+
+                if done:
+                    file.seek(0)
+                    hashed = calculate_sha256(file)
+                    file.seek(0)
+
+                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
+                    response = requests.post(url, data=file)
+
+                    if response.ok:
+                        res = {
+                            "done": done,
+                            "blob": f"sha256:{hashed}",
+                            "name": file_name,
+                        }
+                        os.remove(file_path)
+
+                        yield f"data: {json.dumps(res)}\n\n"
+                    else:
+                        raise "Ollama: Could not create blob, Please try again."
+
+
+# def number_generator():
+#     for i in range(1, 101):
+#         yield f"data: {i}\n"
+
+
+# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
+@app.post("/models/download")
+@app.post("/models/download/{url_idx}")
+async def download_model(
+    form_data: UrlForm,
+    url_idx: Optional[int] = None,
+):
+    if url_idx == None:
+        url_idx = 0
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    file_name = parse_huggingface_url(form_data.url)
+
+    if file_name:
+        file_path = f"{UPLOAD_DIR}/{file_name}"
+
+        return StreamingResponse(
+            download_file_stream(url, form_data.url, file_path, file_name),
+        )
+    else:
+        return None
+
+
+@app.post("/models/upload")
+@app.post("/models/upload/{url_idx}")
+def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
+    if url_idx == None:
+        url_idx = 0
+    ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    file_path = f"{UPLOAD_DIR}/{file.filename}"
+
+    # Save file in chunks
+    with open(file_path, "wb+") as f:
+        for chunk in file.file:
+            f.write(chunk)
+
+    def file_process_stream():
+        nonlocal ollama_url
+        total_size = os.path.getsize(file_path)
+        chunk_size = 1024 * 1024
+        try:
+            with open(file_path, "rb") as f:
+                total = 0
+                done = False
+
+                while not done:
+                    chunk = f.read(chunk_size)
+                    if not chunk:
+                        done = True
+                        continue
+
+                    total += len(chunk)
+                    progress = round((total / total_size) * 100, 2)
+
+                    res = {
+                        "progress": progress,
+                        "total": total_size,
+                        "completed": total,
+                    }
+                    yield f"data: {json.dumps(res)}\n\n"
+
+                if done:
+                    f.seek(0)
+                    hashed = calculate_sha256(f)
+                    f.seek(0)
+
+                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
+                    response = requests.post(url, data=f)
+
+                    if response.ok:
+                        res = {
+                            "done": done,
+                            "blob": f"sha256:{hashed}",
+                            "name": file.filename,
+                        }
+                        os.remove(file_path)
+                        yield f"data: {json.dumps(res)}\n\n"
+                    else:
+                        raise Exception(
+                            "Ollama: Could not create blob, Please try again."
+                        )
+
+        except Exception as e:
+            res = {"error": str(e)}
+            yield f"data: {json.dumps(res)}\n\n"
+
+    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
+
+
+# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
+#     if url_idx == None:
+#         url_idx = 0
+#     url = app.state.OLLAMA_BASE_URLS[url_idx]
+#     file_location = os.path.join(UPLOAD_DIR, file.filename)
+#     total_size = file.size
+#     async def file_upload_generator(file):
+#         print(file)
+#         try:
+#             async with aiofiles.open(file_location, "wb") as f:
+#                 completed_size = 0
+#                 while True:
+#                     chunk = await file.read(1024*1024)
+#                     if not chunk:
+#                         break
#                     await f.write(chunk)
+#                     completed_size += len(chunk)
+#                     progress = (completed_size / total_size) * 100
+#                     print(progress)
+#                     yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
+#         except Exception as e:
+#             print(e)
+#             yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
+#         finally:
+#             await file.close()
+#             print("done")
+#             yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
+#     return StreamingResponse(
+#         file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
+#     )
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
     url = app.state.OLLAMA_BASE_URLS[0]
@@ -947,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
                     if request_id in REQUEST_POOL:
                         yield chunk
                     else:
-                        print("User: canceled request")
+                        log.warning("User: canceled request")
                         break
             finally:
                 if hasattr(r, "close"):
```
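Both new model endpoints stream progress as server-sent events: one `data: {...}` JSON object per chunk, ending with a `blob`/`name` record once the file has been hashed and registered with Ollama at `/api/blobs/sha256:<digest>`. A rough client-side sketch of consuming that stream with `requests`; the host, the `/ollama/api` mount path, and the model URL are placeholder assumptions, not values from this commit:

```python
import json
import requests

resp = requests.post(
    "http://localhost:8080/ollama/api/models/download",  # assumed mount path
    json={"url": "https://huggingface.co/<user>/<repo>/resolve/main/<file>.gguf"},
    stream=True,
)

for line in resp.iter_lines():
    if not line:
        continue  # skip the blank separator lines between SSE events
    event = json.loads(line.decode().removeprefix("data: "))
    if event.get("done"):
        # final event: the blob was created and named on the Ollama side
        print("registered blob", event["blob"], "as", event["name"])
    elif "progress" in event:
        print(f"{event['progress']}% ({event['completed']}/{event['total']} bytes)")
```

The `Range: bytes=<offset>-` header in `download_file_stream` means an interrupted download resumes from the existing partial file rather than starting over.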
backend/apps/openai/main.py — view file @ ac294a74

```diff
@@ -6,6 +6,7 @@ import requests
 import aiohttp
 import asyncio
 import json
+import logging
 
 from pydantic import BaseModel
@@ -19,6 +20,7 @@ from utils.utils import (
     get_admin_user,
 )
 from config import (
+    SRC_LOG_LEVELS,
     OPENAI_API_BASE_URLS,
     OPENAI_API_KEYS,
     CACHE_DIR,
@@ -31,6 +33,9 @@ from typing import List, Optional
 import hashlib
 from pathlib import Path
 
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["OPENAI"])
+
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             return FileResponse(file_path)
 
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -160,7 +165,7 @@ async def fetch_url(url, key):
                 return await response.json()
     except Exception as e:
         # Handle connection error here
-        print(f"Connection error: {e}")
+        log.error(f"Connection error: {e}")
         return None
@@ -182,7 +187,7 @@ def merge_models_lists(model_lists):
 
 async def get_all_models():
-    print("get_all_models")
+    log.info("get_all_models()")
 
     if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
         models = {"data": []}
@@ -208,7 +213,7 @@ async def get_all_models():
             )
         }
 
-        print(models)
+        log.info(f"models: {models}")
         app.state.MODELS = {model["id"]: model for model in models["data"]}
 
         return models
@@ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
         return response_data
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
@@ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             if body.get("model") == "gpt-4-vision-preview":
                 if "max_tokens" not in body:
                     body["max_tokens"] = 4000
-                print("Modified body_dict:", body)
+                log.debug("Modified body_dict:", body)
 
             # Fix for ChatGPT calls failing because the num_ctx key is in body
             if "num_ctx" in body:
@@ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             # Convert the modified body back to JSON
             body = json.dumps(body)
         except json.JSONDecodeError as e:
-            print("Error loading request body into a dictionary:", e)
+            log.error("Error loading request body into a dictionary:", e)
 
     url = app.state.OPENAI_API_BASE_URLS[idx]
     key = app.state.OPENAI_API_KEYS[idx]
@@ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             response_data = r.json()
             return response_data
     except Exception as e:
-        print(e)
+        log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
```
backend/apps/rag/main.py — view file @ ac294a74

```diff
@@ -8,7 +8,7 @@ from fastapi import (
     Form,
 )
 from fastapi.middleware.cors import CORSMiddleware
-import os, shutil
+import os, shutil, logging
 
 from pathlib import Path
 from typing import List
@@ -54,6 +54,7 @@ from utils.misc import (
 )
 from utils.utils import get_current_user, get_admin_user
 from config import (
+    SRC_LOG_LEVELS,
     UPLOAD_DIR,
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
@@ -66,6 +67,9 @@ from config import (
 
 from constants import ERROR_MESSAGES
 
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
 #
 # if RAG_EMBEDDING_MODEL:
 #    sentence_transformer_ef = SentenceTransformer(
@@ -110,40 +114,6 @@ class CollectionNameForm(BaseModel):
 class StoreWebForm(CollectionNameForm):
     url: str
 
-def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
-    )
-    docs = text_splitter.split_documents(data)
-
-    texts = [doc.page_content for doc in docs]
-    metadatas = [doc.metadata for doc in docs]
-
-    try:
-        if overwrite:
-            for collection in CHROMA_CLIENT.list_collections():
-                if collection_name == collection.name:
-                    print(f"deleting existing collection {collection_name}")
-                    CHROMA_CLIENT.delete_collection(name=collection_name)
-
-        collection = CHROMA_CLIENT.create_collection(
-            name=collection_name,
-            embedding_function=app.state.sentence_transformer_ef,
-        )
-
-        collection.add(
-            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
-        )
-        return True
-    except Exception as e:
-        print(e)
-        if e.__class__.__name__ == "UniqueConstraintError":
-            return True
-
-        return False
-
 @app.get("/")
 async def get_status():
     return {
@@ -274,7 +244,7 @@ def query_doc_handler(
             embedding_function=app.state.sentence_transformer_ef,
         )
     except Exception as e:
-        print(e)
+        log.exception(e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
@@ -318,13 +288,63 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
             "filename": form_data.url,
         }
     except Exception as e:
-        print(e)
+        log.exception(e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
         )
 
+def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=app.state.CHUNK_SIZE,
+        chunk_overlap=app.state.CHUNK_OVERLAP,
+        add_start_index=True,
+    )
+    docs = text_splitter.split_documents(data)
+    return store_docs_in_vector_db(docs, collection_name, overwrite)
+
+def store_text_in_vector_db(
+    text, metadata, collection_name, overwrite: bool = False
+) -> bool:
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=app.state.CHUNK_SIZE,
+        chunk_overlap=app.state.CHUNK_OVERLAP,
+        add_start_index=True,
+    )
+    docs = text_splitter.create_documents([text], metadatas=[metadata])
+    return store_docs_in_vector_db(docs, collection_name, overwrite)
+
+def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+    texts = [doc.page_content for doc in docs]
+    metadatas = [doc.metadata for doc in docs]
+
+    try:
+        if overwrite:
+            for collection in CHROMA_CLIENT.list_collections():
+                if collection_name == collection.name:
+                    print(f"deleting existing collection {collection_name}")
+                    CHROMA_CLIENT.delete_collection(name=collection_name)
+
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
+
+        collection.add(
+            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
+        )
+        return True
+    except Exception as e:
+        print(e)
+        if e.__class__.__name__ == "UniqueConstraintError":
+            return True
+
+        return False
+
 def get_loader(filename: str, file_content_type: str, file_path: str):
     file_ext = filename.split(".")[-1].lower()
     known_type = True
@@ -416,7 +436,7 @@ def store_doc(
 ):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
-    print(file.content_type)
+    log.info(f"file.content_type: {file.content_type}")
     try:
         filename = file.filename
         file_path = f"{UPLOAD_DIR}/{filename}"
@@ -447,7 +467,7 @@ def store_doc(
             detail=ERROR_MESSAGES.DEFAULT(),
         )
     except Exception as e:
-        print(e)
+        log.exception(e)
         if "No pandoc was found" in str(e):
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -460,6 +480,37 @@ def store_doc(
         )
 
+class TextRAGForm(BaseModel):
+    name: str
+    content: str
+    collection_name: Optional[str] = None
+
+@app.post("/text")
+def store_text(
+    form_data: TextRAGForm,
+    user=Depends(get_current_user),
+):
+    collection_name = form_data.collection_name
+    if collection_name == None:
+        collection_name = calculate_sha256_string(form_data.content)
+
+    result = store_text_in_vector_db(
+        form_data.content,
+        metadata={"name": form_data.name, "created_by": user.id},
+        collection_name=collection_name,
+    )
+
+    if result:
+        return {"status": True, "collection_name": collection_name}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(),
+        )
+
 @app.get("/scan")
 def scan_docs_dir(user=Depends(get_admin_user)):
     for path in Path(DOCS_DIR).rglob("./**/*"):
@@ -512,7 +563,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
             )
 
         except Exception as e:
-            print(e)
+            log.exception(e)
 
     return True
@@ -533,11 +584,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
             elif os.path.isdir(file_path):
                 shutil.rmtree(file_path)
         except Exception as e:
-            print("Failed to delete %s. Reason: %s" % (file_path, e))
+            log.error("Failed to delete %s. Reason: %s" % (file_path, e))
 
     try:
         CHROMA_CLIENT.reset()
     except Exception as e:
-        print(e)
+        log.exception(e)
 
     return True
```
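The new `/text` route lets a client index a raw string without uploading a file; when `collection_name` is omitted, the SHA-256 of the content is used, so identical text always maps to the same collection. A hedged sketch of a call; the host, token, and the `/rag/api/v1` mount path are assumptions for illustration:

```python
import requests

r = requests.post(
    "http://localhost:8080/rag/api/v1/text",  # assumed mount point of the RAG app
    headers={"Authorization": "Bearer <token>"},  # placeholder token
    json={
        "name": "release-notes",
        "content": "Open WebUI 0.1.115 adds ComfyUI image generation.",
        # "collection_name" omitted: the server derives one from the content hash
    },
)
print(r.json())  # {"status": True, "collection_name": "<sha256 of the content>"}
```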
backend/apps/rag/utils.py — view file @ ac294a74

```diff
 import re
+import logging
 from typing import List
 
-from config import CHROMA_CLIENT
+from config import SRC_LOG_LEVELS, CHROMA_CLIENT
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 def query_doc(collection_name: str, query: str, k: int, embedding_function):
@@ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str):
 
 def rag_messages(docs, messages, template, k, embedding_function):
-    print(docs)
+    log.debug(f"docs: {docs}")
 
     last_user_message_idx = None
     for i in range(len(messages) - 1, -1, -1):
@@ -137,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
                     k=k,
                     embedding_function=embedding_function,
                 )
+            elif doc["type"] == "text":
+                context = doc["content"]
             else:
                 context = query_doc(
                     collection_name=doc["collection_name"],
@@ -145,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
                     embedding_function=embedding_function,
                 )
         except Exception as e:
-            print(e)
+            log.exception(e)
             context = None
 
         relevant_contexts.append(context)
```
backend/apps/web/internal/db.py — view file @ ac294a74

```diff
 from peewee import *
-from config import DATA_DIR
+from config import SRC_LOG_LEVELS, DATA_DIR
 import os
+import logging
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["DB"])
 
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
     os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
-    print("File renamed successfully.")
+    log.info("File renamed successfully.")
 else:
     pass
```
backend/apps/web/models/auths.py — view file @ ac294a74

```diff
@@ -2,6 +2,7 @@ from pydantic import BaseModel
 from typing import List, Union, Optional
 import time
 import uuid
+import logging
 from peewee import *
 
 from apps.web.models.users import UserModel, Users
@@ -9,6 +10,10 @@ from utils.utils import verify_password
 
 from apps.web.internal.db import DB
 
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 ####################
 # DB MODEL
 ####################
@@ -86,7 +91,7 @@ class AuthsTable:
     def insert_new_auth(
         self, email: str, password: str, name: str, role: str = "pending"
     ) -> Optional[UserModel]:
-        print("insert_new_auth")
+        log.info("insert_new_auth")
 
         id = str(uuid.uuid4())
@@ -103,7 +108,7 @@ class AuthsTable:
             return None
 
     def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
-        print("authenticate_user", email)
+        log.info(f"authenticate_user: {email}")
         try:
             auth = Auth.get(Auth.email == email, Auth.active == True)
             if auth:
```
backend/apps/web/models/chats.py — view file @ ac294a74

```diff
@@ -95,20 +95,6 @@ class ChatTable:
         except:
             return None
 
-    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
-        try:
-            query = Chat.update(
-                chat=json.dumps(chat),
-                title=chat["title"] if "title" in chat else "New Chat",
-                timestamp=int(time.time()),
-            ).where(Chat.id == id)
-            query.execute()
-
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
-        except:
-            return None
-
     def get_chat_lists_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
```
backend/apps/web/models/documents.py — view file @ ac294a74

```diff
@@ -3,6 +3,7 @@ from peewee import *
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 import time
+import logging
 
 from utils.utils import decode_token
 from utils.misc import get_gravatar_url
@@ -11,6 +12,10 @@ from apps.web.internal.db import DB
 
 import json
 
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 ####################
 # Documents DB Schema
 ####################
@@ -118,7 +123,7 @@ class DocumentsTable:
             doc = Document.get(Document.name == form_data.name)
             return DocumentModel(**model_to_dict(doc))
         except Exception as e:
-            print(e)
+            log.exception(e)
             return None
 
     def update_doc_content_by_name(
@@ -138,7 +143,7 @@ class DocumentsTable:
             doc = Document.get(Document.name == name)
             return DocumentModel(**model_to_dict(doc))
         except Exception as e:
-            print(e)
+            log.exception(e)
             return None
 
     def delete_doc_by_name(self, name: str) -> bool:
```
backend/apps/web/models/tags.py — view file @ ac294a74

```diff
@@ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict
 import json
 import uuid
 import time
+import logging
 
 from apps.web.internal.db import DB
 
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 ####################
 # Tag DB Schema
 ####################
@@ -173,7 +178,7 @@ class TagTable:
             (ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
         )
         res = query.execute()  # Remove the rows, return number of rows removed.
-        print(res)
+        log.debug(f"res: {res}")
 
         tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
         if tag_count == 0:
@@ -185,7 +190,7 @@ class TagTable:
             return True
         except Exception as e:
-            print("delete_tag", e)
+            log.error(f"delete_tag: {e}")
             return False
 
     def delete_tag_by_tag_name_and_chat_id_and_user_id(
@@ -198,7 +203,7 @@ class TagTable:
                 & (ChatIdTag.user_id == user_id)
             )
             res = query.execute()  # Remove the rows, return number of rows removed.
-            print(res)
+            log.debug(f"res: {res}")
 
             tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
             if tag_count == 0:
@@ -210,7 +215,7 @@ class TagTable:
             return True
         except Exception as e:
-            print("delete_tag", e)
+            log.error(f"delete_tag: {e}")
             return False
 
     def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
```
backend/apps/web/routers/chats.py — view file @ ac294a74

```diff
@@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
+import logging
 
 from apps.web.models.users import Users
 from apps.web.models.chats import (
@@ -27,6 +28,10 @@ from apps.web.models.tags import (
 
 from constants import ERROR_MESSAGES
 
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 router = APIRouter()
 
 ############################
@@ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
         chat = Chats.insert_new_chat(user.id, form_data)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     except Exception as e:
-        print(e)
+        log.exception(e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
         )
@@ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)):
         tags = Tags.get_tags_by_user_id(user.id)
         return tags
     except Exception as e:
-        print(e)
+        log.exception(e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
         )
```
backend/apps/web/routers/users.py — view file @ ac294a74

```diff
@@ -7,6 +7,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 import time
 import uuid
+import logging
 
 from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
 from apps.web.models.auths import Auths
@@ -14,6 +15,10 @@ from apps.web.models.auths import Auths
 from utils.utils import get_current_user, get_password_hash, get_admin_user
 from constants import ERROR_MESSAGES
 
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 router = APIRouter()
 
 ############################
@@ -83,7 +88,7 @@ async def update_user_by_id(
 
         if form_data.password:
             hashed = get_password_hash(form_data.password)
-            print(hashed)
+            log.debug(f"hashed: {hashed}")
 
             Auths.update_user_password_by_id(user_id, hashed)
 
         Auths.update_email_by_id(user_id, form_data.email.lower())
```
backend/apps/web/routers/utils.py — view file @ ac294a74

```diff
@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
 
 router = APIRouter()
 
-class UploadBlobForm(BaseModel):
-    filename: str
-
-from urllib.parse import urlparse
-
-def parse_huggingface_url(hf_url):
-    try:
-        # Parse the URL
-        parsed_url = urlparse(hf_url)
-
-        # Get the path and split it into components
-        path_components = parsed_url.path.split("/")
-
-        # Extract the desired output
-        user_repo = "/".join(path_components[1:3])
-        model_file = path_components[-1]
-
-        return model_file
-    except ValueError:
-        return None
-
-async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
-    done = False
-
-    if os.path.exists(file_path):
-        current_size = os.path.getsize(file_path)
-    else:
-        current_size = 0
-
-    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
-
-    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
-
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        async with session.get(url, headers=headers) as response:
-            total_size = int(response.headers.get("content-length", 0)) + current_size
-
-            with open(file_path, "ab+") as file:
-                async for data in response.content.iter_chunked(chunk_size):
-                    current_size += len(data)
-                    file.write(data)
-
-                    done = current_size == total_size
-                    progress = round((current_size / total_size) * 100, 2)
-
-                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
-
-                if done:
-                    file.seek(0)
-                    hashed = calculate_sha256(file)
-                    file.seek(0)
-
-                    url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=file)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file_name,
-                        }
-                        os.remove(file_path)
-
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise "Ollama: Could not create blob, Please try again."
-
-@router.get("/download")
-async def download(
-    url: str,
-):
-    # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
-    file_name = parse_huggingface_url(url)
-
-    if file_name:
-        file_path = f"{UPLOAD_DIR}/{file_name}"
-
-        return StreamingResponse(
-            download_file_stream(url, file_path, file_name),
-            media_type="text/event-stream",
-        )
-    else:
-        return None
-
-@router.post("/upload")
-def upload(file: UploadFile = File(...)):
-    file_path = f"{UPLOAD_DIR}/{file.filename}"
-
-    # Save file in chunks
-    with open(file_path, "wb+") as f:
-        for chunk in file.file:
-            f.write(chunk)
-
-    def file_process_stream():
-        total_size = os.path.getsize(file_path)
-        chunk_size = 1024 * 1024
-        try:
-            with open(file_path, "rb") as f:
-                total = 0
-                done = False
-
-                while not done:
-                    chunk = f.read(chunk_size)
-                    if not chunk:
-                        done = True
-                        continue
-
-                    total += len(chunk)
-                    progress = round((total / total_size) * 100, 2)
-
-                    res = {
-                        "progress": progress,
-                        "total": total_size,
-                        "completed": total,
-                    }
-
-                    yield f"data: {json.dumps(res)}\n\n"
-
-                if done:
-                    f.seek(0)
-                    hashed = calculate_sha256(f)
-                    f.seek(0)
-
-                    url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=f)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file.filename,
-                        }
-                        os.remove(file_path)
-
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise Exception(
-                            "Ollama: Could not create blob, Please try again."
-                        )
-        except Exception as e:
-            res = {"error": str(e)}
-            yield f"data: {json.dumps(res)}\n\n"
-
-    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
-
 @router.get("/gravatar")
 async def get_gravatar(
     email: str,
```
backend/config.py — view file @ ac294a74

```diff
 import os
+import sys
+import logging
 import chromadb
 from chromadb import Settings
 from base64 import b64encode
@@ -21,7 +23,7 @@ try:
     load_dotenv(find_dotenv("../.env"))
 except ImportError:
-    print("dotenv not installed, skipping...")
+    log.warning("dotenv not installed, skipping...")
 
 WEBUI_NAME = "Open WebUI"
 shutil.copyfile("../build/favicon.png", "./static/favicon.png")
@@ -100,6 +102,34 @@ for version in soup.find_all("h2"):
 
 CHANGELOG = changelog_json
 
+####################################
+# LOGGING
+####################################
+log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
+
+GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
+if GLOBAL_LOG_LEVEL in log_levels:
+    logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
+else:
+    GLOBAL_LOG_LEVEL = "INFO"
+
+log = logging.getLogger(__name__)
+log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
+
+log_sources = [
+    "AUDIO",
+    "CONFIG",
+    "DB",
+    "IMAGES",
+    "LITELLM",
+    "MAIN",
+    "MODELS",
+    "OLLAMA",
+    "OPENAI",
+    "RAG",
+]
+
+SRC_LOG_LEVELS = {}
+
+for source in log_sources:
+    log_env_var = source + "_LOG_LEVEL"
+    SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
+    if SRC_LOG_LEVELS[source] not in log_levels:
+        SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
+    log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
+
+log.setLevel(SRC_LOG_LEVELS["CONFIG"])
+
 ####################################
 # CUSTOM_NAME
 ####################################
@@ -125,7 +155,7 @@ if CUSTOM_NAME:
             WEBUI_NAME = data["name"]
     except Exception as e:
-        print(e)
+        log.exception(e)
         pass
@@ -194,9 +224,9 @@ def create_config_file(file_path):
 LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"
 
 if not os.path.exists(LITELLM_CONFIG_PATH):
-    print("Config file doesn't exist. Creating...")
+    log.info("Config file doesn't exist. Creating...")
     create_config_file(LITELLM_CONFIG_PATH)
-    print("Config file created successfully.")
+    log.info("Config file created successfully.")
@@ -376,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
 ####################################
 
 AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
+COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
```
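With this block in place, log verbosity is controlled entirely through environment variables: `GLOBAL_LOG_LEVEL` sets the root level, and any per-source variable (`AUDIO_LOG_LEVEL`, `OLLAMA_LOG_LEVEL`, `RAG_LOG_LEVEL`, and so on) overrides it for that subsystem, falling back to the global level when unset or invalid. A small sketch of the resulting behavior; it assumes you run it from the backend directory so `config` is importable, since the variables are read at import time:

```python
import os

# Raise everything to WARNING but keep Ollama chatter at DEBUG.
# These must be set before backend/config.py is imported.
os.environ["GLOBAL_LOG_LEVEL"] = "WARNING"
os.environ["OLLAMA_LOG_LEVEL"] = "DEBUG"

import config  # noqa: E402  (config reads the variables at import time)

print(config.SRC_LOG_LEVELS)
# expected: every source maps to "WARNING" except "OLLAMA", which maps to "DEBUG"
```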
backend/main.py — view file @ ac294a74

```diff
@@ -4,6 +4,7 @@ import markdown
 import time
 import os
 import sys
+import logging
 import requests
 
 from fastapi import FastAPI, Request, Depends, status
@@ -38,10 +39,15 @@ from config import (
     FRONTEND_BUILD_DIR,
     MODEL_FILTER_ENABLED,
     MODEL_FILTER_LIST,
+    GLOBAL_LOG_LEVEL,
+    SRC_LOG_LEVELS,
     WEBHOOK_URL,
 )
 from constants import ERROR_MESSAGES
 
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 class SPAStaticFiles(StaticFiles):
     async def get_response(self, path: str, scope):
@@ -70,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
         if request.method == "POST" and (
             "/api/chat" in request.url.path or "/chat/completions" in request.url.path
         ):
-            print(request.url.path)
+            log.debug(f"request.url.path: {request.url.path}")
 
             # Read the original request body
             body = await request.body()
@@ -93,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 )
                 del data["docs"]
 
-                print(data["messages"])
+                log.debug(f"data['messages']: {data['messages']}")
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
```