Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
ac294a74
Unverified
Commit
ac294a74
authored
Mar 24, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
Mar 24, 2024
Browse files
Merge pull request #1277 from open-webui/dev
0.1.115
parents
2fa94956
4c959891
Changes
53
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
845 additions
and
292 deletions
+845
-292
.gitignore
.gitignore
+1
-1
CHANGELOG.md
CHANGELOG.md
+18
-0
backend/apps/audio/main.py
backend/apps/audio/main.py
+8
-4
backend/apps/images/main.py
backend/apps/images/main.py
+107
-13
backend/apps/images/utils/comfyui.py
backend/apps/images/utils/comfyui.py
+228
-0
backend/apps/litellm/main.py
backend/apps/litellm/main.py
+7
-2
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+277
-37
backend/apps/openai/main.py
backend/apps/openai/main.py
+13
-8
backend/apps/rag/main.py
backend/apps/rag/main.py
+93
-42
backend/apps/rag/utils.py
backend/apps/rag/utils.py
+9
-3
backend/apps/web/internal/db.py
backend/apps/web/internal/db.py
+5
-2
backend/apps/web/models/auths.py
backend/apps/web/models/auths.py
+7
-2
backend/apps/web/models/chats.py
backend/apps/web/models/chats.py
+0
-14
backend/apps/web/models/documents.py
backend/apps/web/models/documents.py
+7
-2
backend/apps/web/models/tags.py
backend/apps/web/models/tags.py
+9
-4
backend/apps/web/routers/chats.py
backend/apps/web/routers/chats.py
+7
-2
backend/apps/web/routers/users.py
backend/apps/web/routers/users.py
+6
-1
backend/apps/web/routers/utils.py
backend/apps/web/routers/utils.py
+0
-149
backend/config.py
backend/config.py
+35
-4
backend/main.py
backend/main.py
+8
-2
No files found.
.gitignore
View file @
ac294a74
...
@@ -166,7 +166,7 @@ cython_debug/
...
@@ -166,7 +166,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#
.idea/
.idea/
# Logs
# Logs
logs
logs
...
...
CHANGELOG.md
View file @
ac294a74
...
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
...
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
The format is based on
[
Keep a Changelog
](
https://keepachangelog.com/en/1.1.0/
)
,
The format is based on
[
Keep a Changelog
](
https://keepachangelog.com/en/1.1.0/
)
,
and this project adheres to
[
Semantic Versioning
](
https://semver.org/spec/v2.0.0.html
)
.
and this project adheres to
[
Semantic Versioning
](
https://semver.org/spec/v2.0.0.html
)
.
## [0.1.115] - 2024-03-24
### Added
-
**🔍 Custom Model Selector**
: Easily find and select custom models with the new search filter feature.
-
**🛑 Cancel Model Download**
: Added the ability to cancel model downloads.
-
**🎨 Image Generation ComfyUI**
: Image generation now supports ComfyUI.
-
**🌟 Updated Light Theme**
: Updated the light theme for a fresh look.
-
**🌍 Additional Language Support**
: Now supporting Bulgarian, Italian, Portuguese, Japanese, and Dutch.
### Fixed
-
**🔧 Fixed Broken Experimental GGUF Upload**
: Resolved issues with experimental GGUF upload functionality.
### Changed
-
**🔄 Vector Storage Reset Button**
: Moved the reset vector storage button to document settings.
## [0.1.114] - 2024-03-20
## [0.1.114] - 2024-03-20
### Added
### Added
...
...
backend/apps/audio/main.py
View file @
ac294a74
import
os
import
os
import
logging
from
fastapi
import
(
from
fastapi
import
(
FastAPI
,
FastAPI
,
Request
,
Request
,
...
@@ -21,7 +22,10 @@ from utils.utils import (
...
@@ -21,7 +22,10 @@ from utils.utils import (
)
)
from
utils.misc
import
calculate_sha256
from
utils.misc
import
calculate_sha256
from
config
import
CACHE_DIR
,
UPLOAD_DIR
,
WHISPER_MODEL
,
WHISPER_MODEL_DIR
from
config
import
SRC_LOG_LEVELS
,
CACHE_DIR
,
UPLOAD_DIR
,
WHISPER_MODEL
,
WHISPER_MODEL_DIR
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"AUDIO"
])
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
...
@@ -38,7 +42,7 @@ def transcribe(
...
@@ -38,7 +42,7 @@ def transcribe(
file
:
UploadFile
=
File
(...),
file
:
UploadFile
=
File
(...),
user
=
Depends
(
get_current_user
),
user
=
Depends
(
get_current_user
),
):
):
print
(
file
.
content_type
)
log
.
info
(
f
"file.content_type:
{
file
.
content_type
}
"
)
if
file
.
content_type
not
in
[
"audio/mpeg"
,
"audio/wav"
]:
if
file
.
content_type
not
in
[
"audio/mpeg"
,
"audio/wav"
]:
raise
HTTPException
(
raise
HTTPException
(
...
@@ -62,7 +66,7 @@ def transcribe(
...
@@ -62,7 +66,7 @@ def transcribe(
)
)
segments
,
info
=
model
.
transcribe
(
file_path
,
beam_size
=
5
)
segments
,
info
=
model
.
transcribe
(
file_path
,
beam_size
=
5
)
print
(
log
.
info
(
"Detected language '%s' with probability %f"
"Detected language '%s' with probability %f"
%
(
info
.
language
,
info
.
language_probability
)
%
(
info
.
language
,
info
.
language_probability
)
)
)
...
@@ -72,7 +76,7 @@ def transcribe(
...
@@ -72,7 +76,7 @@ def transcribe(
return
{
"text"
:
transcript
.
strip
()}
return
{
"text"
:
transcript
.
strip
()}
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
...
...
backend/apps/images/main.py
View file @
ac294a74
...
@@ -18,6 +18,8 @@ from utils.utils import (
...
@@ -18,6 +18,8 @@ from utils.utils import (
get_current_user
,
get_current_user
,
get_admin_user
,
get_admin_user
,
)
)
from
apps.images.utils.comfyui
import
ImageGenerationPayload
,
comfyui_generate_image
from
utils.misc
import
calculate_sha256
from
utils.misc
import
calculate_sha256
from
typing
import
Optional
from
typing
import
Optional
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
...
@@ -25,9 +27,13 @@ from pathlib import Path
...
@@ -25,9 +27,13 @@ from pathlib import Path
import
uuid
import
uuid
import
base64
import
base64
import
json
import
json
import
logging
from
config
import
SRC_LOG_LEVELS
,
CACHE_DIR
,
AUTOMATIC1111_BASE_URL
,
COMFYUI_BASE_URL
from
config
import
CACHE_DIR
,
AUTOMATIC1111_BASE_URL
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"IMAGES"
])
IMAGE_CACHE_DIR
=
Path
(
CACHE_DIR
).
joinpath
(
"./image/generations/"
)
IMAGE_CACHE_DIR
=
Path
(
CACHE_DIR
).
joinpath
(
"./image/generations/"
)
IMAGE_CACHE_DIR
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
IMAGE_CACHE_DIR
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
...
@@ -49,6 +55,8 @@ app.state.MODEL = ""
...
@@ -49,6 +55,8 @@ app.state.MODEL = ""
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
app
.
state
.
IMAGE_SIZE
=
"512x512"
app
.
state
.
IMAGE_SIZE
=
"512x512"
app
.
state
.
IMAGE_STEPS
=
50
app
.
state
.
IMAGE_STEPS
=
50
...
@@ -71,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
...
@@ -71,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
return
{
"engine"
:
app
.
state
.
ENGINE
,
"enabled"
:
app
.
state
.
ENABLED
}
return
{
"engine"
:
app
.
state
.
ENGINE
,
"enabled"
:
app
.
state
.
ENABLED
}
class
UrlUpdateForm
(
BaseModel
):
class
EngineUrlUpdateForm
(
BaseModel
):
url
:
str
AUTOMATIC1111_BASE_URL
:
Optional
[
str
]
=
None
COMFYUI_BASE_URL
:
Optional
[
str
]
=
None
@
app
.
get
(
"/url"
)
@
app
.
get
(
"/url"
)
async
def
get_automatic1111_url
(
user
=
Depends
(
get_admin_user
)):
async
def
get_engine_url
(
user
=
Depends
(
get_admin_user
)):
return
{
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
}
return
{
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
app
.
state
.
COMFYUI_BASE_URL
,
}
@
app
.
post
(
"/url/update"
)
@
app
.
post
(
"/url/update"
)
async
def
update_
automatic1111
_url
(
async
def
update_
engine
_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)
form_data
:
Engine
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
):
if
form_data
.
url
==
""
:
if
form_data
.
AUTOMATIC1111_BASE_URL
==
None
:
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
else
:
else
:
url
=
form_data
.
url
.
strip
(
"/"
)
url
=
form_data
.
AUTOMATIC1111_BASE_URL
.
strip
(
"/"
)
try
:
try
:
r
=
requests
.
head
(
url
)
r
=
requests
.
head
(
url
)
app
.
state
.
AUTOMATIC1111_BASE_URL
=
url
app
.
state
.
AUTOMATIC1111_BASE_URL
=
url
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
if
form_data
.
COMFYUI_BASE_URL
==
None
:
app
.
state
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
else
:
url
=
form_data
.
COMFYUI_BASE_URL
.
strip
(
"/"
)
try
:
r
=
requests
.
head
(
url
)
app
.
state
.
COMFYUI_BASE_URL
=
url
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
return
{
return
{
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
,
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
app
.
state
.
COMFYUI_BASE_URL
,
"status"
:
True
,
"status"
:
True
,
}
}
...
@@ -186,6 +210,18 @@ def get_models(user=Depends(get_current_user)):
...
@@ -186,6 +210,18 @@ def get_models(user=Depends(get_current_user)):
{
"id"
:
"dall-e-2"
,
"name"
:
"DALL·E 2"
},
{
"id"
:
"dall-e-2"
,
"name"
:
"DALL·E 2"
},
{
"id"
:
"dall-e-3"
,
"name"
:
"DALL·E 3"
},
{
"id"
:
"dall-e-3"
,
"name"
:
"DALL·E 3"
},
]
]
elif
app
.
state
.
ENGINE
==
"comfyui"
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
COMFYUI_BASE_URL
}
/object_info"
)
info
=
r
.
json
()
return
list
(
map
(
lambda
model
:
{
"id"
:
model
,
"name"
:
model
},
info
[
"CheckpointLoaderSimple"
][
"input"
][
"required"
][
"ckpt_name"
][
0
],
)
)
else
:
else
:
r
=
requests
.
get
(
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
...
@@ -207,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)):
...
@@ -207,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)):
try
:
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
ENGINE
==
"openai"
:
return
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
else
"dall-e-2"
}
return
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
else
"dall-e-2"
}
elif
app
.
state
.
ENGINE
==
"comfyui"
:
return
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
else
""
}
else
:
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
options
=
r
.
json
()
...
@@ -221,10 +259,12 @@ class UpdateModelForm(BaseModel):
...
@@ -221,10 +259,12 @@ class UpdateModelForm(BaseModel):
def
set_model_handler
(
model
:
str
):
def
set_model_handler
(
model
:
str
):
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
ENGINE
==
"openai"
:
app
.
state
.
MODEL
=
model
app
.
state
.
MODEL
=
model
return
app
.
state
.
MODEL
return
app
.
state
.
MODEL
if
app
.
state
.
ENGINE
==
"comfyui"
:
app
.
state
.
MODEL
=
model
return
app
.
state
.
MODEL
else
:
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
options
=
r
.
json
()
...
@@ -266,6 +306,23 @@ def save_b64_image(b64_str):
...
@@ -266,6 +306,23 @@ def save_b64_image(b64_str):
with
open
(
file_path
,
"wb"
)
as
f
:
with
open
(
file_path
,
"wb"
)
as
f
:
f
.
write
(
img_data
)
f
.
write
(
img_data
)
return
image_id
except
Exception
as
e
:
log
.
error
(
f
"Error saving image:
{
e
}
"
)
return
None
def
save_url_image
(
url
):
image_id
=
str
(
uuid
.
uuid4
())
file_path
=
IMAGE_CACHE_DIR
.
joinpath
(
f
"
{
image_id
}
.png"
)
try
:
r
=
requests
.
get
(
url
)
r
.
raise_for_status
()
with
open
(
file_path
,
"wb"
)
as
image_file
:
image_file
.
write
(
r
.
content
)
return
image_id
return
image_id
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"Error saving image:
{
e
}
"
)
print
(
f
"Error saving image:
{
e
}
"
)
...
@@ -278,6 +335,8 @@ def generate_image(
...
@@ -278,6 +335,8 @@ def generate_image(
user
=
Depends
(
get_current_user
),
user
=
Depends
(
get_current_user
),
):
):
width
,
height
=
tuple
(
map
(
int
,
app
.
state
.
IMAGE_SIZE
.
split
(
"x"
)))
r
=
None
r
=
None
try
:
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
ENGINE
==
"openai"
:
...
@@ -315,12 +374,47 @@ def generate_image(
...
@@ -315,12 +374,47 @@ def generate_image(
return
images
return
images
elif
app
.
state
.
ENGINE
==
"comfyui"
:
data
=
{
"prompt"
:
form_data
.
prompt
,
"width"
:
width
,
"height"
:
height
,
"n"
:
form_data
.
n
,
}
if
app
.
state
.
IMAGE_STEPS
!=
None
:
data
[
"steps"
]
=
app
.
state
.
IMAGE_STEPS
if
form_data
.
negative_prompt
!=
None
:
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
data
=
ImageGenerationPayload
(
**
data
)
res
=
comfyui_generate_image
(
app
.
state
.
MODEL
,
data
,
user
.
id
,
app
.
state
.
COMFYUI_BASE_URL
,
)
print
(
res
)
images
=
[]
for
image
in
res
[
"data"
]:
image_id
=
save_url_image
(
image
[
"url"
])
images
.
append
({
"url"
:
f
"/cache/image/generations/
{
image_id
}
.png"
})
file_body_path
=
IMAGE_CACHE_DIR
.
joinpath
(
f
"
{
image_id
}
.json"
)
with
open
(
file_body_path
,
"w"
)
as
f
:
json
.
dump
(
data
.
model_dump
(
exclude_none
=
True
),
f
)
print
(
images
)
return
images
else
:
else
:
if
form_data
.
model
:
if
form_data
.
model
:
set_model_handler
(
form_data
.
model
)
set_model_handler
(
form_data
.
model
)
width
,
height
=
tuple
(
map
(
int
,
app
.
state
.
IMAGE_SIZE
.
split
(
"x"
)))
data
=
{
data
=
{
"prompt"
:
form_data
.
prompt
,
"prompt"
:
form_data
.
prompt
,
"batch_size"
:
form_data
.
n
,
"batch_size"
:
form_data
.
n
,
...
@@ -341,7 +435,7 @@ def generate_image(
...
@@ -341,7 +435,7 @@ def generate_image(
res
=
r
.
json
()
res
=
r
.
json
()
print
(
res
)
log
.
debug
(
f
"res:
{
res
}
"
)
images
=
[]
images
=
[]
...
...
backend/apps/images/utils/comfyui.py
0 → 100644
View file @
ac294a74
import
websocket
# NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import
uuid
import
json
import
urllib.request
import
urllib.parse
import
random
from
pydantic
import
BaseModel
from
typing
import
Optional
COMFYUI_DEFAULT_PROMPT
=
"""
{
"3": {
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "model.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "Negative Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
"""
def
queue_prompt
(
prompt
,
client_id
,
base_url
):
print
(
"queue_prompt"
)
p
=
{
"prompt"
:
prompt
,
"client_id"
:
client_id
}
data
=
json
.
dumps
(
p
).
encode
(
"utf-8"
)
req
=
urllib
.
request
.
Request
(
f
"
{
base_url
}
/prompt"
,
data
=
data
)
return
json
.
loads
(
urllib
.
request
.
urlopen
(
req
).
read
())
def
get_image
(
filename
,
subfolder
,
folder_type
,
base_url
):
print
(
"get_image"
)
data
=
{
"filename"
:
filename
,
"subfolder"
:
subfolder
,
"type"
:
folder_type
}
url_values
=
urllib
.
parse
.
urlencode
(
data
)
with
urllib
.
request
.
urlopen
(
f
"
{
base_url
}
/view?
{
url_values
}
"
)
as
response
:
return
response
.
read
()
def
get_image_url
(
filename
,
subfolder
,
folder_type
,
base_url
):
print
(
"get_image"
)
data
=
{
"filename"
:
filename
,
"subfolder"
:
subfolder
,
"type"
:
folder_type
}
url_values
=
urllib
.
parse
.
urlencode
(
data
)
return
f
"
{
base_url
}
/view?
{
url_values
}
"
def
get_history
(
prompt_id
,
base_url
):
print
(
"get_history"
)
with
urllib
.
request
.
urlopen
(
f
"
{
base_url
}
/history/
{
prompt_id
}
"
)
as
response
:
return
json
.
loads
(
response
.
read
())
def
get_images
(
ws
,
prompt
,
client_id
,
base_url
):
prompt_id
=
queue_prompt
(
prompt
,
client_id
,
base_url
)[
"prompt_id"
]
output_images
=
[]
while
True
:
out
=
ws
.
recv
()
if
isinstance
(
out
,
str
):
message
=
json
.
loads
(
out
)
if
message
[
"type"
]
==
"executing"
:
data
=
message
[
"data"
]
if
data
[
"node"
]
is
None
and
data
[
"prompt_id"
]
==
prompt_id
:
break
# Execution is done
else
:
continue
# previews are binary data
history
=
get_history
(
prompt_id
,
base_url
)[
prompt_id
]
for
o
in
history
[
"outputs"
]:
for
node_id
in
history
[
"outputs"
]:
node_output
=
history
[
"outputs"
][
node_id
]
if
"images"
in
node_output
:
for
image
in
node_output
[
"images"
]:
url
=
get_image_url
(
image
[
"filename"
],
image
[
"subfolder"
],
image
[
"type"
],
base_url
)
output_images
.
append
({
"url"
:
url
})
return
{
"data"
:
output_images
}
class
ImageGenerationPayload
(
BaseModel
):
prompt
:
str
negative_prompt
:
Optional
[
str
]
=
""
steps
:
Optional
[
int
]
=
None
seed
:
Optional
[
int
]
=
None
width
:
int
height
:
int
n
:
int
=
1
def
comfyui_generate_image
(
model
:
str
,
payload
:
ImageGenerationPayload
,
client_id
,
base_url
):
host
=
base_url
.
replace
(
"http://"
,
""
).
replace
(
"https://"
,
""
)
comfyui_prompt
=
json
.
loads
(
COMFYUI_DEFAULT_PROMPT
)
comfyui_prompt
[
"4"
][
"inputs"
][
"ckpt_name"
]
=
model
comfyui_prompt
[
"5"
][
"inputs"
][
"batch_size"
]
=
payload
.
n
comfyui_prompt
[
"5"
][
"inputs"
][
"width"
]
=
payload
.
width
comfyui_prompt
[
"5"
][
"inputs"
][
"height"
]
=
payload
.
height
# set the text prompt for our positive CLIPTextEncode
comfyui_prompt
[
"6"
][
"inputs"
][
"text"
]
=
payload
.
prompt
comfyui_prompt
[
"7"
][
"inputs"
][
"text"
]
=
payload
.
negative_prompt
if
payload
.
steps
:
comfyui_prompt
[
"3"
][
"inputs"
][
"steps"
]
=
payload
.
steps
comfyui_prompt
[
"3"
][
"inputs"
][
"seed"
]
=
(
payload
.
seed
if
payload
.
seed
else
random
.
randint
(
0
,
18446744073709551614
)
)
try
:
ws
=
websocket
.
WebSocket
()
ws
.
connect
(
f
"ws://
{
host
}
/ws?clientId=
{
client_id
}
"
)
print
(
"WebSocket connection established."
)
except
Exception
as
e
:
print
(
f
"Failed to connect to WebSocket server:
{
e
}
"
)
return
None
try
:
images
=
get_images
(
ws
,
comfyui_prompt
,
client_id
,
base_url
)
except
Exception
as
e
:
print
(
f
"Error while receiving images:
{
e
}
"
)
images
=
None
ws
.
close
()
return
images
backend/apps/litellm/main.py
View file @
ac294a74
import
logging
from
litellm.proxy.proxy_server
import
ProxyConfig
,
initialize
from
litellm.proxy.proxy_server
import
ProxyConfig
,
initialize
from
litellm.proxy.proxy_server
import
app
from
litellm.proxy.proxy_server
import
app
...
@@ -9,7 +11,10 @@ from starlette.responses import StreamingResponse
...
@@ -9,7 +11,10 @@ from starlette.responses import StreamingResponse
import
json
import
json
from
utils.utils
import
get_http_authorization_cred
,
get_current_user
from
utils.utils
import
get_http_authorization_cred
,
get_current_user
from
config
import
ENV
from
config
import
SRC_LOG_LEVELS
,
ENV
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"LITELLM"
])
from
config
import
(
from
config
import
(
...
@@ -49,7 +54,7 @@ async def auth_middleware(request: Request, call_next):
...
@@ -49,7 +54,7 @@ async def auth_middleware(request: Request, call_next):
try
:
try
:
user
=
get_current_user
(
get_http_authorization_cred
(
auth_header
))
user
=
get_current_user
(
get_http_authorization_cred
(
auth_header
))
print
(
user
)
log
.
debug
(
f
"user:
{
user
}
"
)
request
.
state
.
user
=
user
request
.
state
.
user
=
user
except
Exception
as
e
:
except
Exception
as
e
:
return
JSONResponse
(
status_code
=
400
,
content
=
{
"detail"
:
str
(
e
)})
return
JSONResponse
(
status_code
=
400
,
content
=
{
"detail"
:
str
(
e
)})
...
...
backend/apps/ollama/main.py
View file @
ac294a74
from
fastapi
import
FastAPI
,
Request
,
Response
,
HTTPException
,
Depends
,
status
from
fastapi
import
(
FastAPI
,
Request
,
Response
,
HTTPException
,
Depends
,
status
,
UploadFile
,
File
,
BackgroundTasks
,
)
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
StreamingResponse
from
fastapi.responses
import
StreamingResponse
from
fastapi.concurrency
import
run_in_threadpool
from
fastapi.concurrency
import
run_in_threadpool
from
pydantic
import
BaseModel
,
ConfigDict
from
pydantic
import
BaseModel
,
ConfigDict
import
os
import
copy
import
random
import
random
import
requests
import
requests
import
json
import
json
import
uuid
import
uuid
import
aiohttp
import
aiohttp
import
asyncio
import
asyncio
import
logging
from
urllib.parse
import
urlparse
from
typing
import
Optional
,
List
,
Union
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
utils.utils
import
decode_token
,
get_current_user
,
get_admin_user
from
utils.utils
import
decode_token
,
get_current_user
,
get_admin_user
from
config
import
OLLAMA_BASE_URLS
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_LIST
from
typing
import
Optional
,
List
,
Union
from
config
import
SRC_LOG_LEVELS
,
OLLAMA_BASE_URLS
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_LIST
,
UPLOAD_DIR
from
utils.misc
import
calculate_sha256
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"OLLAMA"
])
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
...
@@ -69,7 +88,7 @@ class UrlUpdateForm(BaseModel):
...
@@ -69,7 +88,7 @@ class UrlUpdateForm(BaseModel):
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
OLLAMA_BASE_URLS
=
form_data
.
urls
app
.
state
.
OLLAMA_BASE_URLS
=
form_data
.
urls
print
(
app
.
state
.
OLLAMA_BASE_URLS
)
log
.
info
(
f
"
app.state.OLLAMA_BASE_URLS
:
{
app
.
state
.
OLLAMA_BASE_URLS
}
"
)
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
...
@@ -90,7 +109,7 @@ async def fetch_url(url):
...
@@ -90,7 +109,7 @@ async def fetch_url(url):
return
await
response
.
json
()
return
await
response
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
# Handle connection error here
# Handle connection error here
print
(
f
"Connection error:
{
e
}
"
)
log
.
error
(
f
"Connection error:
{
e
}
"
)
return
None
return
None
...
@@ -114,7 +133,7 @@ def merge_models_lists(model_lists):
...
@@ -114,7 +133,7 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
async
def
get_all_models
():
print
(
"get_all_models"
)
log
.
info
(
"get_all_models
()
"
)
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
await
asyncio
.
gather
(
*
tasks
)
...
@@ -155,7 +174,7 @@ async def get_ollama_tags(
...
@@ -155,7 +174,7 @@ async def get_ollama_tags(
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -201,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
...
@@ -201,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -227,18 +246,33 @@ async def pull_model(
...
@@ -227,18 +246,33 @@ async def pull_model(
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
):
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
def
get_request
():
def
get_request
():
nonlocal
url
nonlocal
url
nonlocal
r
nonlocal
r
request_id
=
str
(
uuid
.
uuid4
())
try
:
try
:
REQUEST_POOL
.
append
(
request_id
)
def
stream_content
():
def
stream_content
():
try
:
yield
json
.
dumps
({
"id"
:
request_id
,
"done"
:
False
})
+
"
\n
"
for
chunk
in
r
.
iter_content
(
chunk_size
=
8192
):
for
chunk
in
r
.
iter_content
(
chunk_size
=
8192
):
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
print
(
"User: canceled request"
)
break
finally
:
if
hasattr
(
r
,
"close"
):
r
.
close
()
if
request_id
in
REQUEST_POOL
:
REQUEST_POOL
.
remove
(
request_id
)
r
=
requests
.
request
(
r
=
requests
.
request
(
method
=
"POST"
,
method
=
"POST"
,
...
@@ -259,8 +293,9 @@ async def pull_model(
...
@@ -259,8 +293,9 @@ async def pull_model(
try
:
try
:
return
await
run_in_threadpool
(
get_request
)
return
await
run_in_threadpool
(
get_request
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -299,7 +334,7 @@ async def push_model(
...
@@ -299,7 +334,7 @@ async def push_model(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
debug
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -331,7 +366,7 @@ async def push_model(
...
@@ -331,7 +366,7 @@ async def push_model(
try
:
try
:
return
await
run_in_threadpool
(
get_request
)
return
await
run_in_threadpool
(
get_request
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -359,9 +394,9 @@ class CreateModelForm(BaseModel):
...
@@ -359,9 +394,9 @@ class CreateModelForm(BaseModel):
async
def
create_model
(
async
def
create_model
(
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
):
print
(
form_data
)
log
.
debug
(
f
"form_data:
{
form_data
}
"
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -383,7 +418,7 @@ async def create_model(
...
@@ -383,7 +418,7 @@ async def create_model(
r
.
raise_for_status
()
r
.
raise_for_status
()
print
(
r
)
log
.
debug
(
f
"r:
{
r
}
"
)
return
StreamingResponse
(
return
StreamingResponse
(
stream_content
(),
stream_content
(),
...
@@ -396,7 +431,7 @@ async def create_model(
...
@@ -396,7 +431,7 @@ async def create_model(
try
:
try
:
return
await
run_in_threadpool
(
get_request
)
return
await
run_in_threadpool
(
get_request
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -434,7 +469,7 @@ async def copy_model(
...
@@ -434,7 +469,7 @@ async def copy_model(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -444,11 +479,11 @@ async def copy_model(
...
@@ -444,11 +479,11 @@ async def copy_model(
)
)
r
.
raise_for_status
()
r
.
raise_for_status
()
print
(
r
.
text
)
log
.
debug
(
f
"r.text:
{
r
.
text
}
"
)
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -481,7 +516,7 @@ async def delete_model(
...
@@ -481,7 +516,7 @@ async def delete_model(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -491,11 +526,11 @@ async def delete_model(
...
@@ -491,11 +526,11 @@ async def delete_model(
)
)
r
.
raise_for_status
()
r
.
raise_for_status
()
print
(
r
.
text
)
log
.
debug
(
f
"r.text:
{
r
.
text
}
"
)
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -521,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
...
@@ -521,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
])
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
])
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -533,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
...
@@ -533,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -573,7 +608,7 @@ async def generate_embeddings(
...
@@ -573,7 +608,7 @@ async def generate_embeddings(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
r
=
requests
.
request
(
r
=
requests
.
request
(
...
@@ -585,7 +620,7 @@ async def generate_embeddings(
...
@@ -585,7 +620,7 @@ async def generate_embeddings(
return
r
.
json
()
return
r
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -633,7 +668,7 @@ async def generate_completion(
...
@@ -633,7 +668,7 @@ async def generate_completion(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -654,7 +689,7 @@ async def generate_completion(
...
@@ -654,7 +689,7 @@ async def generate_completion(
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
@@ -731,11 +766,11 @@ async def generate_chat_completion(
...
@@ -731,11 +766,11 @@ async def generate_chat_completion(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
prin
t
(
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
())
log
.
debug
(
"form_data.model_dump_json(exclude_none=True).encode(): {0} "
.
forma
t
(
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
())
)
def
get_request
():
def
get_request
():
nonlocal
form_data
nonlocal
form_data
...
@@ -754,7 +789,7 @@ async def generate_chat_completion(
...
@@ -754,7 +789,7 @@ async def generate_chat_completion(
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
@@ -777,7 +812,7 @@ async def generate_chat_completion(
...
@@ -777,7 +812,7 @@ async def generate_chat_completion(
headers
=
dict
(
r
.
headers
),
headers
=
dict
(
r
.
headers
),
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
e
raise
e
try
:
try
:
...
@@ -831,7 +866,7 @@ async def generate_openai_chat_completion(
...
@@ -831,7 +866,7 @@ async def generate_openai_chat_completion(
)
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
print
(
url
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -854,7 +889,7 @@ async def generate_openai_chat_completion(
...
@@ -854,7 +889,7 @@ async def generate_openai_chat_completion(
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
@@ -897,6 +932,211 @@ async def generate_openai_chat_completion(
...
@@ -897,6 +932,211 @@ async def generate_openai_chat_completion(
)
)
class
UrlForm
(
BaseModel
):
url
:
str
class
UploadBlobForm
(
BaseModel
):
filename
:
str
def
parse_huggingface_url
(
hf_url
):
try
:
# Parse the URL
parsed_url
=
urlparse
(
hf_url
)
# Get the path and split it into components
path_components
=
parsed_url
.
path
.
split
(
"/"
)
# Extract the desired output
user_repo
=
"/"
.
join
(
path_components
[
1
:
3
])
model_file
=
path_components
[
-
1
]
return
model_file
except
ValueError
:
return
None
async
def
download_file_stream
(
ollama_url
,
file_url
,
file_path
,
file_name
,
chunk_size
=
1024
*
1024
):
done
=
False
if
os
.
path
.
exists
(
file_path
):
current_size
=
os
.
path
.
getsize
(
file_path
)
else
:
current_size
=
0
headers
=
{
"Range"
:
f
"bytes=
{
current_size
}
-"
}
if
current_size
>
0
else
{}
timeout
=
aiohttp
.
ClientTimeout
(
total
=
600
)
# Set the timeout
async
with
aiohttp
.
ClientSession
(
timeout
=
timeout
)
as
session
:
async
with
session
.
get
(
file_url
,
headers
=
headers
)
as
response
:
total_size
=
int
(
response
.
headers
.
get
(
"content-length"
,
0
))
+
current_size
with
open
(
file_path
,
"ab+"
)
as
file
:
async
for
data
in
response
.
content
.
iter_chunked
(
chunk_size
):
current_size
+=
len
(
data
)
file
.
write
(
data
)
done
=
current_size
==
total_size
progress
=
round
((
current_size
/
total_size
)
*
100
,
2
)
yield
f
'data: {{"progress":
{
progress
}
, "completed":
{
current_size
}
, "total":
{
total_size
}
}}
\n\n
'
if
done
:
file
.
seek
(
0
)
hashed
=
calculate_sha256
(
file
)
file
.
seek
(
0
)
url
=
f
"
{
ollama_url
}
/api/blobs/sha256:
{
hashed
}
"
response
=
requests
.
post
(
url
,
data
=
file
)
if
response
.
ok
:
res
=
{
"done"
:
done
,
"blob"
:
f
"sha256:
{
hashed
}
"
,
"name"
:
file_name
,
}
os
.
remove
(
file_path
)
yield
f
"data:
{
json
.
dumps
(
res
)
}
\n\n
"
else
:
raise
"Ollama: Could not create blob, Please try again."
# def number_generator():
# for i in range(1, 101):
# yield f"data: {i}\n"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@
app
.
post
(
"/models/download"
)
@
app
.
post
(
"/models/download/{url_idx}"
)
async
def
download_model
(
form_data
:
UrlForm
,
url_idx
:
Optional
[
int
]
=
None
,
):
if
url_idx
==
None
:
url_idx
=
0
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
file_name
=
parse_huggingface_url
(
form_data
.
url
)
if
file_name
:
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
file_name
}
"
return
StreamingResponse
(
download_file_stream
(
url
,
form_data
.
url
,
file_path
,
file_name
),
)
else
:
return
None
@
app
.
post
(
"/models/upload"
)
@
app
.
post
(
"/models/upload/{url_idx}"
)
def
upload_model
(
file
:
UploadFile
=
File
(...),
url_idx
:
Optional
[
int
]
=
None
):
if
url_idx
==
None
:
url_idx
=
0
ollama_url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
file
.
filename
}
"
# Save file in chunks
with
open
(
file_path
,
"wb+"
)
as
f
:
for
chunk
in
file
.
file
:
f
.
write
(
chunk
)
def
file_process_stream
():
nonlocal
ollama_url
total_size
=
os
.
path
.
getsize
(
file_path
)
chunk_size
=
1024
*
1024
try
:
with
open
(
file_path
,
"rb"
)
as
f
:
total
=
0
done
=
False
while
not
done
:
chunk
=
f
.
read
(
chunk_size
)
if
not
chunk
:
done
=
True
continue
total
+=
len
(
chunk
)
progress
=
round
((
total
/
total_size
)
*
100
,
2
)
res
=
{
"progress"
:
progress
,
"total"
:
total_size
,
"completed"
:
total
,
}
yield
f
"data:
{
json
.
dumps
(
res
)
}
\n\n
"
if
done
:
f
.
seek
(
0
)
hashed
=
calculate_sha256
(
f
)
f
.
seek
(
0
)
url
=
f
"
{
ollama_url
}
/api/blobs/sha256:
{
hashed
}
"
response
=
requests
.
post
(
url
,
data
=
f
)
if
response
.
ok
:
res
=
{
"done"
:
done
,
"blob"
:
f
"sha256:
{
hashed
}
"
,
"name"
:
file
.
filename
,
}
os
.
remove
(
file_path
)
yield
f
"data:
{
json
.
dumps
(
res
)
}
\n\n
"
else
:
raise
Exception
(
"Ollama: Could not create blob, Please try again."
)
except
Exception
as
e
:
res
=
{
"error"
:
str
(
e
)}
yield
f
"data:
{
json
.
dumps
(
res
)
}
\n\n
"
return
StreamingResponse
(
file_process_stream
(),
media_type
=
"text/event-stream"
)
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
async
def
deprecated_proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_current_user
)):
async
def
deprecated_proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_current_user
)):
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
0
]
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
0
]
...
@@ -947,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
...
@@ -947,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
if
request_id
in
REQUEST_POOL
:
if
request_id
in
REQUEST_POOL
:
yield
chunk
yield
chunk
else
:
else
:
pr
in
t
(
"User: canceled request"
)
log
.
warn
in
g
(
"User: canceled request"
)
break
break
finally
:
finally
:
if
hasattr
(
r
,
"close"
):
if
hasattr
(
r
,
"close"
):
...
...
backend/apps/openai/main.py
View file @
ac294a74
...
@@ -6,6 +6,7 @@ import requests
...
@@ -6,6 +6,7 @@ import requests
import
aiohttp
import
aiohttp
import
asyncio
import
asyncio
import
json
import
json
import
logging
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
...
@@ -19,6 +20,7 @@ from utils.utils import (
...
@@ -19,6 +20,7 @@ from utils.utils import (
get_admin_user
,
get_admin_user
,
)
)
from
config
import
(
from
config
import
(
SRC_LOG_LEVELS
,
OPENAI_API_BASE_URLS
,
OPENAI_API_BASE_URLS
,
OPENAI_API_KEYS
,
OPENAI_API_KEYS
,
CACHE_DIR
,
CACHE_DIR
,
...
@@ -31,6 +33,9 @@ from typing import List, Optional
...
@@ -31,6 +33,9 @@ from typing import List, Optional
import
hashlib
import
hashlib
from
pathlib
import
Path
from
pathlib
import
Path
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"OPENAI"
])
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
CORSMiddleware
,
CORSMiddleware
,
...
@@ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
...
@@ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return
FileResponse
(
file_path
)
return
FileResponse
(
file_path
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -160,7 +165,7 @@ async def fetch_url(url, key):
...
@@ -160,7 +165,7 @@ async def fetch_url(url, key):
return
await
response
.
json
()
return
await
response
.
json
()
except
Exception
as
e
:
except
Exception
as
e
:
# Handle connection error here
# Handle connection error here
print
(
f
"Connection error:
{
e
}
"
)
log
.
error
(
f
"Connection error:
{
e
}
"
)
return
None
return
None
...
@@ -182,7 +187,7 @@ def merge_models_lists(model_lists):
...
@@ -182,7 +187,7 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
async
def
get_all_models
():
print
(
"get_all_models"
)
log
.
info
(
"get_all_models
()
"
)
if
len
(
app
.
state
.
OPENAI_API_KEYS
)
==
1
and
app
.
state
.
OPENAI_API_KEYS
[
0
]
==
""
:
if
len
(
app
.
state
.
OPENAI_API_KEYS
)
==
1
and
app
.
state
.
OPENAI_API_KEYS
[
0
]
==
""
:
models
=
{
"data"
:
[]}
models
=
{
"data"
:
[]}
...
@@ -208,7 +213,7 @@ async def get_all_models():
...
@@ -208,7 +213,7 @@ async def get_all_models():
)
)
}
}
print
(
models
)
log
.
info
(
f
"models:
{
models
}
"
)
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
return
models
return
models
...
@@ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
...
@@ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return
response_data
return
response_data
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
@@ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if
body
.
get
(
"model"
)
==
"gpt-4-vision-preview"
:
if
body
.
get
(
"model"
)
==
"gpt-4-vision-preview"
:
if
"max_tokens"
not
in
body
:
if
"max_tokens"
not
in
body
:
body
[
"max_tokens"
]
=
4000
body
[
"max_tokens"
]
=
4000
print
(
"Modified body_dict:"
,
body
)
log
.
debug
(
"Modified body_dict:"
,
body
)
# Fix for ChatGPT calls failing because the num_ctx key is in body
# Fix for ChatGPT calls failing because the num_ctx key is in body
if
"num_ctx"
in
body
:
if
"num_ctx"
in
body
:
...
@@ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
# Convert the modified body back to JSON
# Convert the modified body back to JSON
body
=
json
.
dumps
(
body
)
body
=
json
.
dumps
(
body
)
except
json
.
JSONDecodeError
as
e
:
except
json
.
JSONDecodeError
as
e
:
print
(
"Error loading request body into a dictionary:"
,
e
)
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
url
=
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]
url
=
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
OPENAI_API_KEYS
[
idx
]
key
=
app
.
state
.
OPENAI_API_KEYS
[
idx
]
...
@@ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
response_data
=
r
.
json
()
response_data
=
r
.
json
()
return
response_data
return
response_data
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
is
not
None
:
try
:
try
:
...
...
backend/apps/rag/main.py
View file @
ac294a74
...
@@ -8,7 +8,7 @@ from fastapi import (
...
@@ -8,7 +8,7 @@ from fastapi import (
Form
,
Form
,
)
)
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
import
os
,
shutil
import
os
,
shutil
,
logging
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
from
typing
import
List
...
@@ -54,6 +54,7 @@ from utils.misc import (
...
@@ -54,6 +54,7 @@ from utils.misc import (
)
)
from
utils.utils
import
get_current_user
,
get_admin_user
from
utils.utils
import
get_current_user
,
get_admin_user
from
config
import
(
from
config
import
(
SRC_LOG_LEVELS
,
UPLOAD_DIR
,
UPLOAD_DIR
,
DOCS_DIR
,
DOCS_DIR
,
RAG_EMBEDDING_MODEL
,
RAG_EMBEDDING_MODEL
,
...
@@ -66,6 +67,9 @@ from config import (
...
@@ -66,6 +67,9 @@ from config import (
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
#
#
# if RAG_EMBEDDING_MODEL:
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# sentence_transformer_ef = SentenceTransformer(
...
@@ -110,40 +114,6 @@ class CollectionNameForm(BaseModel):
...
@@ -110,40 +114,6 @@ class CollectionNameForm(BaseModel):
class
StoreWebForm
(
CollectionNameForm
):
class
StoreWebForm
(
CollectionNameForm
):
url
:
str
url
:
str
def
store_data_in_vector_db
(
data
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
text_splitter
=
RecursiveCharacterTextSplitter
(
chunk_size
=
app
.
state
.
CHUNK_SIZE
,
chunk_overlap
=
app
.
state
.
CHUNK_OVERLAP
)
docs
=
text_splitter
.
split_documents
(
data
)
texts
=
[
doc
.
page_content
for
doc
in
docs
]
metadatas
=
[
doc
.
metadata
for
doc
in
docs
]
try
:
if
overwrite
:
for
collection
in
CHROMA_CLIENT
.
list_collections
():
if
collection_name
==
collection
.
name
:
print
(
f
"deleting existing collection
{
collection_name
}
"
)
CHROMA_CLIENT
.
delete_collection
(
name
=
collection_name
)
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
collection
.
add
(
documents
=
texts
,
metadatas
=
metadatas
,
ids
=
[
str
(
uuid
.
uuid1
())
for
_
in
texts
]
)
return
True
except
Exception
as
e
:
print
(
e
)
if
e
.
__class__
.
__name__
==
"UniqueConstraintError"
:
return
True
return
False
@
app
.
get
(
"/"
)
@
app
.
get
(
"/"
)
async
def
get_status
():
async
def
get_status
():
return
{
return
{
...
@@ -274,7 +244,7 @@ def query_doc_handler(
...
@@ -274,7 +244,7 @@ def query_doc_handler(
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
...
@@ -318,13 +288,63 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
...
@@ -318,13 +288,63 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
"filename"
:
form_data
.
url
,
"filename"
:
form_data
.
url
,
}
}
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
)
)
def
store_data_in_vector_db
(
data
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
text_splitter
=
RecursiveCharacterTextSplitter
(
chunk_size
=
app
.
state
.
CHUNK_SIZE
,
chunk_overlap
=
app
.
state
.
CHUNK_OVERLAP
,
add_start_index
=
True
,
)
docs
=
text_splitter
.
split_documents
(
data
)
return
store_docs_in_vector_db
(
docs
,
collection_name
,
overwrite
)
def
store_text_in_vector_db
(
text
,
metadata
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
text_splitter
=
RecursiveCharacterTextSplitter
(
chunk_size
=
app
.
state
.
CHUNK_SIZE
,
chunk_overlap
=
app
.
state
.
CHUNK_OVERLAP
,
add_start_index
=
True
,
)
docs
=
text_splitter
.
create_documents
([
text
],
metadatas
=
[
metadata
])
return
store_docs_in_vector_db
(
docs
,
collection_name
,
overwrite
)
def
store_docs_in_vector_db
(
docs
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
texts
=
[
doc
.
page_content
for
doc
in
docs
]
metadatas
=
[
doc
.
metadata
for
doc
in
docs
]
try
:
if
overwrite
:
for
collection
in
CHROMA_CLIENT
.
list_collections
():
if
collection_name
==
collection
.
name
:
print
(
f
"deleting existing collection
{
collection_name
}
"
)
CHROMA_CLIENT
.
delete_collection
(
name
=
collection_name
)
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
,
embedding_function
=
app
.
state
.
sentence_transformer_ef
,
)
collection
.
add
(
documents
=
texts
,
metadatas
=
metadatas
,
ids
=
[
str
(
uuid
.
uuid1
())
for
_
in
texts
]
)
return
True
except
Exception
as
e
:
print
(
e
)
if
e
.
__class__
.
__name__
==
"UniqueConstraintError"
:
return
True
return
False
def
get_loader
(
filename
:
str
,
file_content_type
:
str
,
file_path
:
str
):
def
get_loader
(
filename
:
str
,
file_content_type
:
str
,
file_path
:
str
):
file_ext
=
filename
.
split
(
"."
)[
-
1
].
lower
()
file_ext
=
filename
.
split
(
"."
)[
-
1
].
lower
()
known_type
=
True
known_type
=
True
...
@@ -416,7 +436,7 @@ def store_doc(
...
@@ -416,7 +436,7 @@ def store_doc(
):
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
print
(
file
.
content_type
)
log
.
info
(
f
"file.content_type:
{
file
.
content_type
}
"
)
try
:
try
:
filename
=
file
.
filename
filename
=
file
.
filename
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
filename
}
"
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
filename
}
"
...
@@ -447,7 +467,7 @@ def store_doc(
...
@@ -447,7 +467,7 @@ def store_doc(
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
if
"No pandoc was found"
in
str
(
e
):
if
"No pandoc was found"
in
str
(
e
):
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
...
@@ -460,6 +480,37 @@ def store_doc(
...
@@ -460,6 +480,37 @@ def store_doc(
)
)
class
TextRAGForm
(
BaseModel
):
name
:
str
content
:
str
collection_name
:
Optional
[
str
]
=
None
@
app
.
post
(
"/text"
)
def
store_text
(
form_data
:
TextRAGForm
,
user
=
Depends
(
get_current_user
),
):
collection_name
=
form_data
.
collection_name
if
collection_name
==
None
:
collection_name
=
calculate_sha256_string
(
form_data
.
content
)
result
=
store_text_in_vector_db
(
form_data
.
content
,
metadata
=
{
"name"
:
form_data
.
name
,
"created_by"
:
user
.
id
},
collection_name
=
collection_name
,
)
if
result
:
return
{
"status"
:
True
,
"collection_name"
:
collection_name
}
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_500_INTERNAL_SERVER_ERROR
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
)
@
app
.
get
(
"/scan"
)
@
app
.
get
(
"/scan"
)
def
scan_docs_dir
(
user
=
Depends
(
get_admin_user
)):
def
scan_docs_dir
(
user
=
Depends
(
get_admin_user
)):
for
path
in
Path
(
DOCS_DIR
).
rglob
(
"./**/*"
):
for
path
in
Path
(
DOCS_DIR
).
rglob
(
"./**/*"
):
...
@@ -512,7 +563,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
...
@@ -512,7 +563,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
True
return
True
...
@@ -533,11 +584,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
...
@@ -533,11 +584,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
elif
os
.
path
.
isdir
(
file_path
):
elif
os
.
path
.
isdir
(
file_path
):
shutil
.
rmtree
(
file_path
)
shutil
.
rmtree
(
file_path
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"Failed to delete %s. Reason: %s"
%
(
file_path
,
e
))
log
.
error
(
"Failed to delete %s. Reason: %s"
%
(
file_path
,
e
))
try
:
try
:
CHROMA_CLIENT
.
reset
()
CHROMA_CLIENT
.
reset
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
True
return
True
backend/apps/rag/utils.py
View file @
ac294a74
import
re
import
re
import
logging
from
typing
import
List
from
typing
import
List
from
config
import
CHROMA_CLIENT
from
config
import
SRC_LOG_LEVELS
,
CHROMA_CLIENT
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"RAG"
])
def
query_doc
(
collection_name
:
str
,
query
:
str
,
k
:
int
,
embedding_function
):
def
query_doc
(
collection_name
:
str
,
query
:
str
,
k
:
int
,
embedding_function
):
...
@@ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str):
...
@@ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str):
def
rag_messages
(
docs
,
messages
,
template
,
k
,
embedding_function
):
def
rag_messages
(
docs
,
messages
,
template
,
k
,
embedding_function
):
print
(
docs
)
log
.
debug
(
f
"docs:
{
docs
}
"
)
last_user_message_idx
=
None
last_user_message_idx
=
None
for
i
in
range
(
len
(
messages
)
-
1
,
-
1
,
-
1
):
for
i
in
range
(
len
(
messages
)
-
1
,
-
1
,
-
1
):
...
@@ -137,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
...
@@ -137,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
k
=
k
,
k
=
k
,
embedding_function
=
embedding_function
,
embedding_function
=
embedding_function
,
)
)
elif
doc
[
"type"
]
==
"text"
:
context
=
doc
[
"content"
]
else
:
else
:
context
=
query_doc
(
context
=
query_doc
(
collection_name
=
doc
[
"collection_name"
],
collection_name
=
doc
[
"collection_name"
],
...
@@ -145,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
...
@@ -145,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
embedding_function
=
embedding_function
,
embedding_function
=
embedding_function
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
context
=
None
context
=
None
relevant_contexts
.
append
(
context
)
relevant_contexts
.
append
(
context
)
...
...
backend/apps/web/internal/db.py
View file @
ac294a74
from
peewee
import
*
from
peewee
import
*
from
config
import
DATA_DIR
from
config
import
SRC_LOG_LEVELS
,
DATA_DIR
import
os
import
os
import
logging
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
# Check if the file exists
# Check if the file exists
if
os
.
path
.
exists
(
f
"
{
DATA_DIR
}
/ollama.db"
):
if
os
.
path
.
exists
(
f
"
{
DATA_DIR
}
/ollama.db"
):
# Rename the file
# Rename the file
os
.
rename
(
f
"
{
DATA_DIR
}
/ollama.db"
,
f
"
{
DATA_DIR
}
/webui.db"
)
os
.
rename
(
f
"
{
DATA_DIR
}
/ollama.db"
,
f
"
{
DATA_DIR
}
/webui.db"
)
print
(
"File renamed successfully."
)
log
.
info
(
"File renamed successfully."
)
else
:
else
:
pass
pass
...
...
backend/apps/web/models/auths.py
View file @
ac294a74
...
@@ -2,6 +2,7 @@ from pydantic import BaseModel
...
@@ -2,6 +2,7 @@ from pydantic import BaseModel
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
uuid
import
uuid
import
logging
from
peewee
import
*
from
peewee
import
*
from
apps.web.models.users
import
UserModel
,
Users
from
apps.web.models.users
import
UserModel
,
Users
...
@@ -9,6 +10,10 @@ from utils.utils import verify_password
...
@@ -9,6 +10,10 @@ from utils.utils import verify_password
from
apps.web.internal.db
import
DB
from
apps.web.internal.db
import
DB
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
####################
####################
# DB MODEL
# DB MODEL
####################
####################
...
@@ -86,7 +91,7 @@ class AuthsTable:
...
@@ -86,7 +91,7 @@ class AuthsTable:
def
insert_new_auth
(
def
insert_new_auth
(
self
,
email
:
str
,
password
:
str
,
name
:
str
,
role
:
str
=
"pending"
self
,
email
:
str
,
password
:
str
,
name
:
str
,
role
:
str
=
"pending"
)
->
Optional
[
UserModel
]:
)
->
Optional
[
UserModel
]:
print
(
"insert_new_auth"
)
log
.
info
(
"insert_new_auth"
)
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
...
@@ -103,7 +108,7 @@ class AuthsTable:
...
@@ -103,7 +108,7 @@ class AuthsTable:
return
None
return
None
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
print
(
"authenticate_user
"
,
email
)
log
.
info
(
f
"authenticate_user
:
{
email
}
"
)
try
:
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
if
auth
:
if
auth
:
...
...
backend/apps/web/models/chats.py
View file @
ac294a74
...
@@ -95,20 +95,6 @@ class ChatTable:
...
@@ -95,20 +95,6 @@ class ChatTable:
except
:
except
:
return
None
return
None
def
update_chat_by_id
(
self
,
id
:
str
,
chat
:
dict
)
->
Optional
[
ChatModel
]:
try
:
query
=
Chat
.
update
(
chat
=
json
.
dumps
(
chat
),
title
=
chat
[
"title"
]
if
"title"
in
chat
else
"New Chat"
,
timestamp
=
int
(
time
.
time
()),
).
where
(
Chat
.
id
==
id
)
query
.
execute
()
chat
=
Chat
.
get
(
Chat
.
id
==
id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
except
:
return
None
def
get_chat_lists_by_user_id
(
def
get_chat_lists_by_user_id
(
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
)
->
List
[
ChatModel
]:
...
...
backend/apps/web/models/documents.py
View file @
ac294a74
...
@@ -3,6 +3,7 @@ from peewee import *
...
@@ -3,6 +3,7 @@ from peewee import *
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
time
import
time
import
logging
from
utils.utils
import
decode_token
from
utils.utils
import
decode_token
from
utils.misc
import
get_gravatar_url
from
utils.misc
import
get_gravatar_url
...
@@ -11,6 +12,10 @@ from apps.web.internal.db import DB
...
@@ -11,6 +12,10 @@ from apps.web.internal.db import DB
import
json
import
json
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
####################
####################
# Documents DB Schema
# Documents DB Schema
####################
####################
...
@@ -118,7 +123,7 @@ class DocumentsTable:
...
@@ -118,7 +123,7 @@ class DocumentsTable:
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
doc
=
Document
.
get
(
Document
.
name
==
form_data
.
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
return
DocumentModel
(
**
model_to_dict
(
doc
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
None
return
None
def
update_doc_content_by_name
(
def
update_doc_content_by_name
(
...
@@ -138,7 +143,7 @@ class DocumentsTable:
...
@@ -138,7 +143,7 @@ class DocumentsTable:
doc
=
Document
.
get
(
Document
.
name
==
name
)
doc
=
Document
.
get
(
Document
.
name
==
name
)
return
DocumentModel
(
**
model_to_dict
(
doc
))
return
DocumentModel
(
**
model_to_dict
(
doc
))
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
return
None
return
None
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
def
delete_doc_by_name
(
self
,
name
:
str
)
->
bool
:
...
...
backend/apps/web/models/tags.py
View file @
ac294a74
...
@@ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict
...
@@ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict
import
json
import
json
import
uuid
import
uuid
import
time
import
time
import
logging
from
apps.web.internal.db
import
DB
from
apps.web.internal.db
import
DB
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
####################
####################
# Tag DB Schema
# Tag DB Schema
####################
####################
...
@@ -173,7 +178,7 @@ class TagTable:
...
@@ -173,7 +178,7 @@ class TagTable:
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
(
ChatIdTag
.
tag_name
==
tag_name
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
print
(
res
)
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
if
tag_count
==
0
:
...
@@ -185,7 +190,7 @@ class TagTable:
...
@@ -185,7 +190,7 @@ class TagTable:
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"delete_tag
"
,
e
)
log
.
error
(
f
"delete_tag
:
{
e
}
"
)
return
False
return
False
def
delete_tag_by_tag_name_and_chat_id_and_user_id
(
def
delete_tag_by_tag_name_and_chat_id_and_user_id
(
...
@@ -198,7 +203,7 @@ class TagTable:
...
@@ -198,7 +203,7 @@ class TagTable:
&
(
ChatIdTag
.
user_id
==
user_id
)
&
(
ChatIdTag
.
user_id
==
user_id
)
)
)
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
res
=
query
.
execute
()
# Remove the rows, return number of rows removed.
print
(
res
)
log
.
debug
(
f
"res:
{
res
}
"
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
tag_count
=
self
.
count_chat_ids_by_tag_name_and_user_id
(
tag_name
,
user_id
)
if
tag_count
==
0
:
if
tag_count
==
0
:
...
@@ -210,7 +215,7 @@ class TagTable:
...
@@ -210,7 +215,7 @@ class TagTable:
return
True
return
True
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"delete_tag
"
,
e
)
log
.
error
(
f
"delete_tag
:
{
e
}
"
)
return
False
return
False
def
delete_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_tags_by_chat_id_and_user_id
(
self
,
chat_id
:
str
,
user_id
:
str
)
->
bool
:
...
...
backend/apps/web/routers/chats.py
View file @
ac294a74
...
@@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
...
@@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
from
fastapi
import
APIRouter
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
import
json
import
json
import
logging
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
apps.web.models.chats
import
(
from
apps.web.models.chats
import
(
...
@@ -27,6 +28,10 @@ from apps.web.models.tags import (
...
@@ -27,6 +28,10 @@ from apps.web.models.tags import (
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
router
=
APIRouter
()
router
=
APIRouter
()
############################
############################
...
@@ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
...
@@ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat
=
Chats
.
insert_new_chat
(
user
.
id
,
form_data
)
chat
=
Chats
.
insert_new_chat
(
user
.
id
,
form_data
)
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
)
)
...
@@ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)):
...
@@ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)):
tags
=
Tags
.
get_tags_by_user_id
(
user
.
id
)
tags
=
Tags
.
get_tags_by_user_id
(
user
.
id
)
return
tags
return
tags
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
()
)
)
...
...
backend/apps/web/routers/users.py
View file @
ac294a74
...
@@ -7,6 +7,7 @@ from fastapi import APIRouter
...
@@ -7,6 +7,7 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
import
time
import
time
import
uuid
import
uuid
import
logging
from
apps.web.models.users
import
UserModel
,
UserUpdateForm
,
UserRoleUpdateForm
,
Users
from
apps.web.models.users
import
UserModel
,
UserUpdateForm
,
UserRoleUpdateForm
,
Users
from
apps.web.models.auths
import
Auths
from
apps.web.models.auths
import
Auths
...
@@ -14,6 +15,10 @@ from apps.web.models.auths import Auths
...
@@ -14,6 +15,10 @@ from apps.web.models.auths import Auths
from
utils.utils
import
get_current_user
,
get_password_hash
,
get_admin_user
from
utils.utils
import
get_current_user
,
get_password_hash
,
get_admin_user
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
router
=
APIRouter
()
router
=
APIRouter
()
############################
############################
...
@@ -83,7 +88,7 @@ async def update_user_by_id(
...
@@ -83,7 +88,7 @@ async def update_user_by_id(
if
form_data
.
password
:
if
form_data
.
password
:
hashed
=
get_password_hash
(
form_data
.
password
)
hashed
=
get_password_hash
(
form_data
.
password
)
print
(
hashed
)
log
.
debug
(
f
"hashed:
{
hashed
}
"
)
Auths
.
update_user_password_by_id
(
user_id
,
hashed
)
Auths
.
update_user_password_by_id
(
user_id
,
hashed
)
Auths
.
update_email_by_id
(
user_id
,
form_data
.
email
.
lower
())
Auths
.
update_email_by_id
(
user_id
,
form_data
.
email
.
lower
())
...
...
backend/apps/web/routers/utils.py
View file @
ac294a74
...
@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
...
@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
router
=
APIRouter
()
router
=
APIRouter
()
class UploadBlobForm(BaseModel):
    """Request body for blob uploads: carries the target file name."""

    # Name of the file to register as a blob.
    filename: str
from
urllib.parse
import
urlparse
def parse_huggingface_url(hf_url):
    """Extract the model file name (last path component) from a Hugging Face URL.

    Example:
        https://huggingface.co/TheBloke/repo/resolve/main/model.Q2_K.gguf
        -> "model.Q2_K.gguf"

    Args:
        hf_url: Full URL to a file hosted on Hugging Face.

    Returns:
        The file name (final URL path component), or None if the URL
        cannot be parsed.
    """
    try:
        parsed_url = urlparse(hf_url)
        # The file name is the final component of the URL path.
        # (The previously computed "user/repo" prefix was unused and removed.)
        path_components = parsed_url.path.split("/")
        return path_components[-1]
    except ValueError:
        # urlparse raises ValueError only for malformed URLs (e.g. bad ports).
        return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
    """Stream a remote file to disk, yielding SSE progress events.

    Resumes a partial download via an HTTP Range header when ``file_path``
    already exists. On completion the file is hashed, POSTed to the first
    Ollama backend as a blob, and removed from disk.

    Args:
        url: Remote file URL to download.
        file_path: Local path where the file is written (append mode).
        file_name: Name reported in the final SSE payload.
        chunk_size: Read granularity in bytes (default 1 MiB).

    Yields:
        Server-sent-event strings with progress / completion payloads.

    Raises:
        Exception: if the Ollama blob upload fails.
    """
    done = False

    # Resume from the current size if a partial file is already on disk.
    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # overall request timeout (s)

    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url, headers=headers) as response:
            # content-length covers only the remaining bytes when resuming,
            # so add what is already on disk to get the full size.
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    # Guard against a missing/zero content-length header,
                    # which previously caused ZeroDivisionError.
                    progress = (
                        round((current_size / total_size) * 100, 2)
                        if total_size > 0
                        else 0
                    )

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=file)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        # Fix: raising a bare string is a TypeError in
                        # Python 3 — raise a real exception instead.
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
@router.get("/download")
async def download(
    url: str,
):
    """Stream a Hugging Face model download back to the client as SSE.

    Returns a text/event-stream response of progress events, or None when
    no file name can be extracted from the URL.
    """
    file_name = parse_huggingface_url(url)

    # Guard clause: nothing to download if the URL yields no file name.
    if not file_name:
        return None

    file_path = f"{UPLOAD_DIR}/{file_name}"
    return StreamingResponse(
        download_file_stream(url, file_path, file_name),
        media_type="text/event-stream",
    )
@router.post("/upload")
def upload(file: UploadFile = File(...)):
    """Accept a model file upload and push it to Ollama as a blob.

    The file is first persisted to UPLOAD_DIR, then re-read in chunks while
    SSE progress events are streamed back to the client. Once fully read,
    the file is hashed, POSTed to the first Ollama backend's blob endpoint,
    and deleted from disk.
    """
    file_path = f"{UPLOAD_DIR}/{file.filename}"

    # Save file in chunks
    with open(file_path, "wb+") as f:
        for chunk in file.file:
            f.write(chunk)

    def file_process_stream():
        # Generator yielding SSE progress events while re-reading the file.
        total_size = os.path.getsize(file_path)
        chunk_size = 1024 * 1024
        try:
            with open(file_path, "rb") as f:
                total = 0
                done = False

                while not done:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        # EOF: fall through to the upload step below.
                        done = True
                        continue

                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }

                    yield f"data: {json.dumps(res)}\n\n"

                if done:
                    # Rewind so the whole file is hashed, then rewind again
                    # so the POST body starts from byte 0.
                    f.seek(0)
                    hashed = calculate_sha256(f)
                    f.seek(0)

                    # NOTE(review): this posts to "/blobs/..." while the
                    # download path posts to "/api/blobs/..." — confirm which
                    # endpoint the target backend expects.
                    url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
                    response = requests.post(url, data=f)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file.filename,
                        }
                        # Local copy is no longer needed once the blob exists.
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
        except Exception as e:
            # Surface any failure to the client as a terminal SSE error event.
            res = {"error": str(e)}

            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
@
router
.
get
(
"/gravatar"
)
@
router
.
get
(
"/gravatar"
)
async
def
get_gravatar
(
async
def
get_gravatar
(
email
:
str
,
email
:
str
,
...
...
backend/config.py
View file @
ac294a74
import
os
import
os
import
sys
import
logging
import
chromadb
import
chromadb
from
chromadb
import
Settings
from
chromadb
import
Settings
from
base64
import
b64encode
from
base64
import
b64encode
...
@@ -21,7 +23,7 @@ try:
...
@@ -21,7 +23,7 @@ try:
load_dotenv
(
find_dotenv
(
"../.env"
))
load_dotenv
(
find_dotenv
(
"../.env"
))
except
ImportError
:
except
ImportError
:
pr
in
t
(
"dotenv not installed, skipping..."
)
log
.
warn
in
g
(
"dotenv not installed, skipping..."
)
WEBUI_NAME
=
"Open WebUI"
WEBUI_NAME
=
"Open WebUI"
shutil
.
copyfile
(
"../build/favicon.png"
,
"./static/favicon.png"
)
shutil
.
copyfile
(
"../build/favicon.png"
,
"./static/favicon.png"
)
...
@@ -100,6 +102,34 @@ for version in soup.find_all("h2"):
...
@@ -100,6 +102,34 @@ for version in soup.find_all("h2"):
CHANGELOG
=
changelog_json
CHANGELOG
=
changelog_json
####################################
# LOGGING
####################################
# Valid logging level names, ordered most to least severe.
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]

GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
    # force=True replaces any handlers installed by earlier imports.
    logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
    # Unset or invalid value: fall back to INFO. Note basicConfig is NOT
    # called in this branch, so logging keeps its default configuration.
    GLOBAL_LOG_LEVEL = "INFO"

log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

# Per-subsystem sources; each can be overridden via a <SOURCE>_LOG_LEVEL
# environment variable (e.g. OLLAMA_LOG_LEVEL=DEBUG).
log_sources = [
    "AUDIO",
    "CONFIG",
    "DB",
    "IMAGES",
    "LITELLM",
    "MAIN",
    "MODELS",
    "OLLAMA",
    "OPENAI",
    "RAG",
]

# Mapping of source name -> effective level, consumed by the app modules.
SRC_LOG_LEVELS = {}

for source in log_sources:
    log_env_var = source + "_LOG_LEVEL"
    SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
    if SRC_LOG_LEVELS[source] not in log_levels:
        # Invalid or unset per-source level falls back to the global level.
        SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
    log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")

log.setLevel(SRC_LOG_LEVELS["CONFIG"])
####################################
####################################
# CUSTOM_NAME
# CUSTOM_NAME
####################################
####################################
...
@@ -125,7 +155,7 @@ if CUSTOM_NAME:
...
@@ -125,7 +155,7 @@ if CUSTOM_NAME:
WEBUI_NAME
=
data
[
"name"
]
WEBUI_NAME
=
data
[
"name"
]
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
log
.
exception
(
e
)
pass
pass
...
@@ -194,9 +224,9 @@ def create_config_file(file_path):
...
@@ -194,9 +224,9 @@ def create_config_file(file_path):
LITELLM_CONFIG_PATH
=
f
"
{
DATA_DIR
}
/litellm/config.yaml"
LITELLM_CONFIG_PATH
=
f
"
{
DATA_DIR
}
/litellm/config.yaml"
if
not
os
.
path
.
exists
(
LITELLM_CONFIG_PATH
):
if
not
os
.
path
.
exists
(
LITELLM_CONFIG_PATH
):
print
(
"Config file doesn't exist. Creating..."
)
log
.
info
(
"Config file doesn't exist. Creating..."
)
create_config_file
(
LITELLM_CONFIG_PATH
)
create_config_file
(
LITELLM_CONFIG_PATH
)
print
(
"Config file created successfully."
)
log
.
info
(
"Config file created successfully."
)
####################################
####################################
...
@@ -376,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
...
@@ -376,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
####################################
####################################
AUTOMATIC1111_BASE_URL
=
os
.
getenv
(
"AUTOMATIC1111_BASE_URL"
,
""
)
AUTOMATIC1111_BASE_URL
=
os
.
getenv
(
"AUTOMATIC1111_BASE_URL"
,
""
)
COMFYUI_BASE_URL
=
os
.
getenv
(
"COMFYUI_BASE_URL"
,
""
)
backend/main.py
View file @
ac294a74
...
@@ -4,6 +4,7 @@ import markdown
...
@@ -4,6 +4,7 @@ import markdown
import
time
import
time
import
os
import
os
import
sys
import
sys
import
logging
import
requests
import
requests
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
...
@@ -38,10 +39,15 @@ from config import (
...
@@ -38,10 +39,15 @@ from config import (
FRONTEND_BUILD_DIR
,
FRONTEND_BUILD_DIR
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_ENABLED
,
MODEL_FILTER_LIST
,
MODEL_FILTER_LIST
,
GLOBAL_LOG_LEVEL
,
SRC_LOG_LEVELS
,
WEBHOOK_URL
,
WEBHOOK_URL
,
)
)
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
GLOBAL_LOG_LEVEL
)
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MAIN"
])
class
SPAStaticFiles
(
StaticFiles
):
class
SPAStaticFiles
(
StaticFiles
):
async
def
get_response
(
self
,
path
:
str
,
scope
):
async
def
get_response
(
self
,
path
:
str
,
scope
):
...
@@ -70,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
...
@@ -70,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
if
request
.
method
==
"POST"
and
(
if
request
.
method
==
"POST"
and
(
"/api/chat"
in
request
.
url
.
path
or
"/chat/completions"
in
request
.
url
.
path
"/api/chat"
in
request
.
url
.
path
or
"/chat/completions"
in
request
.
url
.
path
):
):
print
(
request
.
url
.
path
)
log
.
debug
(
f
"request.url.path:
{
request
.
url
.
path
}
"
)
# Read the original request body
# Read the original request body
body
=
await
request
.
body
()
body
=
await
request
.
body
()
...
@@ -93,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
...
@@ -93,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
)
)
del
data
[
"docs"
]
del
data
[
"docs"
]
print
(
data
[
"
messages"
]
)
log
.
debug
(
f
"data['messages']:
{
data
[
'
messages
'
]
}
"
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment