Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
8b0144cd
Unverified
Commit
8b0144cd
authored
May 13, 2024
by
Timothy Jaeryang Baek
Committed by
GitHub
May 13, 2024
Browse files
Merge pull request #2156 from cheahjs/feat/save-config
feat: save UI config changes to config.json
parents
7e0d3496
0c033b5b
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
627 additions
and
364 deletions
+627
-364
backend/apps/audio/main.py
backend/apps/audio/main.py
+20
-19
backend/apps/images/main.py
backend/apps/images/main.py
+89
-68
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+29
-25
backend/apps/openai/main.py
backend/apps/openai/main.py
+25
-18
backend/apps/rag/main.py
backend/apps/rag/main.py
+126
-107
backend/apps/web/main.py
backend/apps/web/main.py
+13
-9
backend/apps/web/routers/auths.py
backend/apps/web/routers/auths.py
+16
-16
backend/apps/web/routers/configs.py
backend/apps/web/routers/configs.py
+4
-4
backend/apps/web/routers/users.py
backend/apps/web/routers/users.py
+3
-3
backend/config.py
backend/config.py
+278
-73
backend/main.py
backend/main.py
+24
-22
No files found.
backend/apps/audio/main.py
View file @
8b0144cd
...
...
@@ -45,6 +45,7 @@ from config import (
AUDIO_OPENAI_API_KEY
,
AUDIO_OPENAI_API_MODEL
,
AUDIO_OPENAI_API_VOICE
,
AppConfig
,
)
log
=
logging
.
getLogger
(
__name__
)
...
...
@@ -59,11 +60,11 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
app
.
state
.
OPENAI_API_BASE_URL
=
AUDIO_OPENAI_API_BASE_URL
app
.
state
.
OPENAI_API_KEY
=
AUDIO_OPENAI_API_KEY
app
.
state
.
OPENAI_API_MODEL
=
AUDIO_OPENAI_API_MODEL
app
.
state
.
OPENAI_API_VOICE
=
AUDIO_OPENAI_API_VOICE
app
.
state
.
config
=
AppConfig
()
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
AUDIO_OPENAI_API_BASE_URL
app
.
state
.
config
.
OPENAI_API_KEY
=
AUDIO_OPENAI_API_KEY
app
.
state
.
config
.
OPENAI_API_MODEL
=
AUDIO_OPENAI_API_MODEL
app
.
state
.
config
.
OPENAI_API_VOICE
=
AUDIO_OPENAI_API_VOICE
# setting device type for whisper model
whisper_device_type
=
DEVICE_TYPE
if
DEVICE_TYPE
and
DEVICE_TYPE
==
"cuda"
else
"cpu"
...
...
@@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@
app
.
get
(
"/config"
)
async
def
get_openai_config
(
user
=
Depends
(
get_admin_user
)):
return
{
"OPENAI_API_BASE_URL"
:
app
.
state
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
OPENAI_API_KEY
,
"OPENAI_API_MODEL"
:
app
.
state
.
OPENAI_API_MODEL
,
"OPENAI_API_VOICE"
:
app
.
state
.
OPENAI_API_VOICE
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
"OPENAI_API_MODEL"
:
app
.
state
.
config
.
OPENAI_API_MODEL
,
"OPENAI_API_VOICE"
:
app
.
state
.
config
.
OPENAI_API_VOICE
,
}
...
...
@@ -97,17 +98,17 @@ async def update_openai_config(
if
form_data
.
key
==
""
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
app
.
state
.
OPENAI_API_BASE_URL
=
form_data
.
url
app
.
state
.
OPENAI_API_KEY
=
form_data
.
key
app
.
state
.
OPENAI_API_MODEL
=
form_data
.
model
app
.
state
.
OPENAI_API_VOICE
=
form_data
.
speaker
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
form_data
.
url
app
.
state
.
config
.
OPENAI_API_KEY
=
form_data
.
key
app
.
state
.
config
.
OPENAI_API_MODEL
=
form_data
.
model
app
.
state
.
config
.
OPENAI_API_VOICE
=
form_data
.
speaker
return
{
"status"
:
True
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
OPENAI_API_KEY
,
"OPENAI_API_MODEL"
:
app
.
state
.
OPENAI_API_MODEL
,
"OPENAI_API_VOICE"
:
app
.
state
.
OPENAI_API_VOICE
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
"OPENAI_API_MODEL"
:
app
.
state
.
config
.
OPENAI_API_MODEL
,
"OPENAI_API_VOICE"
:
app
.
state
.
config
.
OPENAI_API_VOICE
,
}
...
...
@@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return
FileResponse
(
file_path
)
headers
=
{}
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
OPENAI_API_KEY
}
"
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
config
.
OPENAI_API_KEY
}
"
headers
[
"Content-Type"
]
=
"application/json"
r
=
None
try
:
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
OPENAI_API_BASE_URL
}
/audio/speech"
,
url
=
f
"
{
app
.
state
.
config
.
OPENAI_API_BASE_URL
}
/audio/speech"
,
data
=
body
,
headers
=
headers
,
stream
=
True
,
...
...
backend/apps/images/main.py
View file @
8b0144cd
...
...
@@ -42,6 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL
,
IMAGE_SIZE
,
IMAGE_STEPS
,
AppConfig
,
)
...
...
@@ -60,26 +61,31 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
app
.
state
.
ENGINE
=
IMAGE_GENERATION_ENGINE
app
.
state
.
ENABLED
=
ENABLE_IMAGE_GENERATION
app
.
state
.
config
=
AppConfig
()
app
.
state
.
OPENAI_API_BASE_URL
=
IMAGES_OPENAI_API_BASE_URL
app
.
state
.
OPENAI_API_KEY
=
IMAGES_OPENAI_API_KEY
app
.
state
.
config
.
ENGINE
=
IMAGE_GENERATION_ENGINE
app
.
state
.
config
.
ENABLED
=
ENABLE_IMAGE_GENERATION
app
.
state
.
MODEL
=
IMAGE_GENERATION_MODEL
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
IMAGES_OPENAI_API_BASE_URL
app
.
state
.
config
.
OPENAI_API_KEY
=
IMAGES_OPENAI_API_KEY
app
.
state
.
config
.
MODEL
=
IMAGE_GENERATION_MODEL
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
config
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
app
.
state
.
IMAGE_SIZE
=
IMAGE_SIZE
app
.
state
.
IMAGE_STEPS
=
IMAGE_STEPS
app
.
state
.
config
.
IMAGE_SIZE
=
IMAGE_SIZE
app
.
state
.
config
.
IMAGE_STEPS
=
IMAGE_STEPS
@
app
.
get
(
"/config"
)
async
def
get_config
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
{
"engine"
:
app
.
state
.
ENGINE
,
"enabled"
:
app
.
state
.
ENABLED
}
return
{
"engine"
:
app
.
state
.
config
.
ENGINE
,
"enabled"
:
app
.
state
.
config
.
ENABLED
,
}
class
ConfigUpdateForm
(
BaseModel
):
...
...
@@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel):
@
app
.
post
(
"/config/update"
)
async
def
update_config
(
form_data
:
ConfigUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
ENGINE
=
form_data
.
engine
app
.
state
.
ENABLED
=
form_data
.
enabled
return
{
"engine"
:
app
.
state
.
ENGINE
,
"enabled"
:
app
.
state
.
ENABLED
}
app
.
state
.
config
.
ENGINE
=
form_data
.
engine
app
.
state
.
config
.
ENABLED
=
form_data
.
enabled
return
{
"engine"
:
app
.
state
.
config
.
ENGINE
,
"enabled"
:
app
.
state
.
config
.
ENABLED
,
}
class
EngineUrlUpdateForm
(
BaseModel
):
...
...
@@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@
app
.
get
(
"/url"
)
async
def
get_engine_url
(
user
=
Depends
(
get_admin_user
)):
return
{
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
app
.
state
.
COMFYUI_BASE_URL
,
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
app
.
state
.
config
.
COMFYUI_BASE_URL
,
}
...
...
@@ -113,29 +122,29 @@ async def update_engine_url(
):
if
form_data
.
AUTOMATIC1111_BASE_URL
==
None
:
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
else
:
url
=
form_data
.
AUTOMATIC1111_BASE_URL
.
strip
(
"/"
)
try
:
r
=
requests
.
head
(
url
)
app
.
state
.
AUTOMATIC1111_BASE_URL
=
url
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
=
url
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
if
form_data
.
COMFYUI_BASE_URL
==
None
:
app
.
state
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
app
.
state
.
config
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
else
:
url
=
form_data
.
COMFYUI_BASE_URL
.
strip
(
"/"
)
try
:
r
=
requests
.
head
(
url
)
app
.
state
.
COMFYUI_BASE_URL
=
url
app
.
state
.
config
.
COMFYUI_BASE_URL
=
url
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
return
{
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
app
.
state
.
COMFYUI_BASE_URL
,
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
app
.
state
.
config
.
COMFYUI_BASE_URL
,
"status"
:
True
,
}
...
...
@@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@
app
.
get
(
"/openai/config"
)
async
def
get_openai_config
(
user
=
Depends
(
get_admin_user
)):
return
{
"OPENAI_API_BASE_URL"
:
app
.
state
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
OPENAI_API_KEY
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
}
...
...
@@ -160,13 +169,13 @@ async def update_openai_config(
if
form_data
.
key
==
""
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
app
.
state
.
OPENAI_API_BASE_URL
=
form_data
.
url
app
.
state
.
OPENAI_API_KEY
=
form_data
.
key
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
form_data
.
url
app
.
state
.
config
.
OPENAI_API_KEY
=
form_data
.
key
return
{
"status"
:
True
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
OPENAI_API_KEY
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
}
...
...
@@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@
app
.
get
(
"/size"
)
async
def
get_image_size
(
user
=
Depends
(
get_admin_user
)):
return
{
"IMAGE_SIZE"
:
app
.
state
.
IMAGE_SIZE
}
return
{
"IMAGE_SIZE"
:
app
.
state
.
config
.
IMAGE_SIZE
}
@
app
.
post
(
"/size/update"
)
...
...
@@ -185,9 +194,9 @@ async def update_image_size(
):
pattern
=
r
"^\d+x\d+$"
# Regular expression pattern
if
re
.
match
(
pattern
,
form_data
.
size
):
app
.
state
.
IMAGE_SIZE
=
form_data
.
size
app
.
state
.
config
.
IMAGE_SIZE
=
form_data
.
size
return
{
"IMAGE_SIZE"
:
app
.
state
.
IMAGE_SIZE
,
"IMAGE_SIZE"
:
app
.
state
.
config
.
IMAGE_SIZE
,
"status"
:
True
,
}
else
:
...
...
@@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@
app
.
get
(
"/steps"
)
async
def
get_image_size
(
user
=
Depends
(
get_admin_user
)):
return
{
"IMAGE_STEPS"
:
app
.
state
.
IMAGE_STEPS
}
return
{
"IMAGE_STEPS"
:
app
.
state
.
config
.
IMAGE_STEPS
}
@
app
.
post
(
"/steps/update"
)
...
...
@@ -211,9 +220,9 @@ async def update_image_size(
form_data
:
ImageStepsUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
if
form_data
.
steps
>=
0
:
app
.
state
.
IMAGE_STEPS
=
form_data
.
steps
app
.
state
.
config
.
IMAGE_STEPS
=
form_data
.
steps
return
{
"IMAGE_STEPS"
:
app
.
state
.
IMAGE_STEPS
,
"IMAGE_STEPS"
:
app
.
state
.
config
.
IMAGE_STEPS
,
"status"
:
True
,
}
else
:
...
...
@@ -226,14 +235,14 @@ async def update_image_size(
@
app
.
get
(
"/models"
)
def
get_models
(
user
=
Depends
(
get_current_user
)):
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
config
.
ENGINE
==
"openai"
:
return
[
{
"id"
:
"dall-e-2"
,
"name"
:
"DALL·E 2"
},
{
"id"
:
"dall-e-3"
,
"name"
:
"DALL·E 3"
},
]
elif
app
.
state
.
ENGINE
==
"comfyui"
:
elif
app
.
state
.
config
.
ENGINE
==
"comfyui"
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
COMFYUI_BASE_URL
}
/object_info"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
config
.
COMFYUI_BASE_URL
}
/object_info"
)
info
=
r
.
json
()
return
list
(
...
...
@@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
)
models
=
r
.
json
()
return
list
(
...
...
@@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)):
)
)
except
Exception
as
e
:
app
.
state
.
ENABLED
=
False
app
.
state
.
config
.
ENABLED
=
False
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
@
app
.
get
(
"/models/default"
)
async
def
get_default_model
(
user
=
Depends
(
get_admin_user
)):
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
return
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
else
"dall-e-2"
}
elif
app
.
state
.
ENGINE
==
"comfyui"
:
return
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
else
""
}
if
app
.
state
.
config
.
ENGINE
==
"openai"
:
return
{
"model"
:
(
app
.
state
.
config
.
MODEL
if
app
.
state
.
config
.
MODEL
else
"dall-e-2"
)
}
elif
app
.
state
.
config
.
ENGINE
==
"comfyui"
:
return
{
"model"
:
(
app
.
state
.
config
.
MODEL
if
app
.
state
.
config
.
MODEL
else
""
)}
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
return
{
"model"
:
options
[
"sd_model_checkpoint"
]}
except
Exception
as
e
:
app
.
state
.
ENABLED
=
False
app
.
state
.
config
.
ENABLED
=
False
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
...
...
@@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel):
def
set_model_handler
(
model
:
str
):
if
app
.
state
.
ENGINE
==
"openai"
:
app
.
state
.
MODEL
=
model
return
app
.
state
.
MODEL
if
app
.
state
.
ENGINE
==
"comfyui"
:
app
.
state
.
MODEL
=
model
return
app
.
state
.
MODEL
if
app
.
state
.
config
.
ENGINE
in
[
"openai"
,
"comfyui"
]:
app
.
state
.
config
.
MODEL
=
model
return
app
.
state
.
config
.
MODEL
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
if
model
!=
options
[
"sd_model_checkpoint"
]:
options
[
"sd_model_checkpoint"
]
=
model
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
,
json
=
options
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
,
json
=
options
,
)
return
options
...
...
@@ -382,26 +397,32 @@ def generate_image(
user
=
Depends
(
get_current_user
),
):
width
,
height
=
tuple
(
map
(
int
,
app
.
state
.
IMAGE_SIZE
.
split
(
"x"
))
)
width
,
height
=
tuple
(
map
(
int
,
app
.
state
.
config
.
IMAGE_SIZE
)
.
split
(
"x"
))
r
=
None
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
config
.
ENGINE
==
"openai"
:
headers
=
{}
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
OPENAI_API_KEY
}
"
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
config
.
OPENAI_API_KEY
}
"
headers
[
"Content-Type"
]
=
"application/json"
data
=
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
!=
""
else
"dall-e-2"
,
"model"
:
(
app
.
state
.
config
.
MODEL
if
app
.
state
.
config
.
MODEL
!=
""
else
"dall-e-2"
),
"prompt"
:
form_data
.
prompt
,
"n"
:
form_data
.
n
,
"size"
:
form_data
.
size
if
form_data
.
size
else
app
.
state
.
IMAGE_SIZE
,
"size"
:
(
form_data
.
size
if
form_data
.
size
else
app
.
state
.
config
.
IMAGE_SIZE
),
"response_format"
:
"b64_json"
,
}
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
OPENAI_API_BASE_URL
}
/images/generations"
,
url
=
f
"
{
app
.
state
.
config
.
OPENAI_API_BASE_URL
}
/images/generations"
,
json
=
data
,
headers
=
headers
,
)
...
...
@@ -421,7 +442,7 @@ def generate_image(
return
images
elif
app
.
state
.
ENGINE
==
"comfyui"
:
elif
app
.
state
.
config
.
ENGINE
==
"comfyui"
:
data
=
{
"prompt"
:
form_data
.
prompt
,
...
...
@@ -430,19 +451,19 @@ def generate_image(
"n"
:
form_data
.
n
,
}
if
app
.
state
.
IMAGE_STEPS
!=
None
:
data
[
"steps"
]
=
app
.
state
.
IMAGE_STEPS
if
app
.
state
.
config
.
IMAGE_STEPS
is
not
None
:
data
[
"steps"
]
=
app
.
state
.
config
.
IMAGE_STEPS
if
form_data
.
negative_prompt
!=
None
:
if
form_data
.
negative_prompt
is
not
None
:
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
data
=
ImageGenerationPayload
(
**
data
)
res
=
comfyui_generate_image
(
app
.
state
.
MODEL
,
app
.
state
.
config
.
MODEL
,
data
,
user
.
id
,
app
.
state
.
COMFYUI_BASE_URL
,
app
.
state
.
config
.
COMFYUI_BASE_URL
,
)
log
.
debug
(
f
"res:
{
res
}
"
)
...
...
@@ -469,14 +490,14 @@ def generate_image(
"height"
:
height
,
}
if
app
.
state
.
IMAGE_STEPS
!=
None
:
data
[
"steps"
]
=
app
.
state
.
IMAGE_STEPS
if
app
.
state
.
config
.
IMAGE_STEPS
is
not
None
:
data
[
"steps"
]
=
app
.
state
.
config
.
IMAGE_STEPS
if
form_data
.
negative_prompt
!=
None
:
if
form_data
.
negative_prompt
is
not
None
:
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/txt2img"
,
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/txt2img"
,
json
=
data
,
)
...
...
backend/apps/ollama/main.py
View file @
8b0144cd
...
...
@@ -46,6 +46,7 @@ from config import (
ENABLE_MODEL_FILTER
,
MODEL_FILTER_LIST
,
UPLOAD_DIR
,
AppConfig
,
)
from
utils.misc
import
calculate_sha256
...
...
@@ -61,11 +62,12 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
app
.
state
.
config
=
AppConfig
()
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
MODELS
=
{}
...
...
@@ -96,7 +98,7 @@ async def get_status():
@
app
.
get
(
"/urls"
)
async
def
get_ollama_api_urls
(
user
=
Depends
(
get_admin_user
)):
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
class
UrlUpdateForm
(
BaseModel
):
...
...
@@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@
app
.
post
(
"/urls/update"
)
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
OLLAMA_BASE_URLS
=
form_data
.
urls
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
form_data
.
urls
log
.
info
(
f
"app.state.OLLAMA_BASE_URLS:
{
app
.
state
.
OLLAMA_BASE_URLS
}
"
)
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
log
.
info
(
f
"app.state.
config.
OLLAMA_BASE_URLS:
{
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
"
)
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
@
app
.
get
(
"/cancel/{request_id}"
)
...
...
@@ -153,7 +155,7 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
log
.
info
(
"get_all_models()"
)
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
config
.
OLLAMA_BASE_URLS
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
models
=
{
...
...
@@ -186,7 +188,7 @@ async def get_ollama_tags(
return
models
return
models
else
:
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/tags"
)
r
.
raise_for_status
()
...
...
@@ -216,7 +218,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if
url_idx
==
None
:
# returns lowest version
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/version"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/version"
)
for
url
in
app
.
state
.
config
.
OLLAMA_BASE_URLS
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
list
(
filter
(
lambda
x
:
x
is
not
None
,
responses
))
...
...
@@ -235,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail
=
ERROR_MESSAGES
.
OLLAMA_NOT_FOUND
,
)
else
:
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/version"
)
r
.
raise_for_status
()
...
...
@@ -267,7 +271,7 @@ class ModelNameForm(BaseModel):
async
def
pull_model
(
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -355,7 +359,7 @@ async def push_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
name
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
debug
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -417,7 +421,7 @@ async def create_model(
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
log
.
debug
(
f
"form_data:
{
form_data
}
"
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -490,7 +494,7 @@ async def copy_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
source
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
try
:
...
...
@@ -537,7 +541,7 @@ async def delete_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
name
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
try
:
...
...
@@ -577,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
)
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
])
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
try
:
...
...
@@ -634,7 +638,7 @@ async def generate_embeddings(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
try
:
...
...
@@ -684,7 +688,7 @@ def generate_ollama_embeddings(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
try
:
...
...
@@ -753,7 +757,7 @@ async def generate_completion(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -856,7 +860,7 @@ async def generate_chat_completion(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -965,7 +969,7 @@ async def generate_openai_chat_completion(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -1064,7 +1068,7 @@ async def get_openai_models(
}
else
:
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/tags"
)
r
.
raise_for_status
()
...
...
@@ -1198,7 +1202,7 @@ async def download_model(
if
url_idx
==
None
:
url_idx
=
0
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
file_name
=
parse_huggingface_url
(
form_data
.
url
)
...
...
@@ -1217,7 +1221,7 @@ async def download_model(
def
upload_model
(
file
:
UploadFile
=
File
(...),
url_idx
:
Optional
[
int
]
=
None
):
if
url_idx
==
None
:
url_idx
=
0
ollama_url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
ollama_url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
file
.
filename
}
"
...
...
@@ -1282,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx]
# url = app.state.
config.
OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
...
...
@@ -1319,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async
def
deprecated_proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)
):
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
0
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
0
]
target_url
=
f
"
{
url
}
/
{
path
}
"
body
=
await
request
.
body
()
...
...
backend/apps/openai/main.py
View file @
8b0144cd
...
...
@@ -26,6 +26,7 @@ from config import (
CACHE_DIR
,
ENABLE_MODEL_FILTER
,
MODEL_FILTER_LIST
,
AppConfig
,
)
from
typing
import
List
,
Optional
...
...
@@ -45,11 +46,13 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
app
.
state
.
config
=
AppConfig
()
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
OPENAI_API_BASE_URLS
=
OPENAI_API_BASE_URLS
app
.
state
.
OPENAI_API_KEYS
=
OPENAI_API_KEYS
app
.
state
.
config
.
OPENAI_API_BASE_URLS
=
OPENAI_API_BASE_URLS
app
.
state
.
config
.
OPENAI_API_KEYS
=
OPENAI_API_KEYS
app
.
state
.
MODELS
=
{}
...
...
@@ -75,32 +78,32 @@ class KeysUpdateForm(BaseModel):
@
app
.
get
(
"/urls"
)
async
def
get_openai_urls
(
user
=
Depends
(
get_admin_user
)):
return
{
"OPENAI_API_BASE_URLS"
:
app
.
state
.
OPENAI_API_BASE_URLS
}
return
{
"OPENAI_API_BASE_URLS"
:
app
.
state
.
config
.
OPENAI_API_BASE_URLS
}
@
app
.
post
(
"/urls/update"
)
async
def
update_openai_urls
(
form_data
:
UrlsUpdateForm
,
user
=
Depends
(
get_admin_user
)):
await
get_all_models
()
app
.
state
.
OPENAI_API_BASE_URLS
=
form_data
.
urls
return
{
"OPENAI_API_BASE_URLS"
:
app
.
state
.
OPENAI_API_BASE_URLS
}
app
.
state
.
config
.
OPENAI_API_BASE_URLS
=
form_data
.
urls
return
{
"OPENAI_API_BASE_URLS"
:
app
.
state
.
config
.
OPENAI_API_BASE_URLS
}
@
app
.
get
(
"/keys"
)
async
def
get_openai_keys
(
user
=
Depends
(
get_admin_user
)):
return
{
"OPENAI_API_KEYS"
:
app
.
state
.
OPENAI_API_KEYS
}
return
{
"OPENAI_API_KEYS"
:
app
.
state
.
config
.
OPENAI_API_KEYS
}
@
app
.
post
(
"/keys/update"
)
async
def
update_openai_key
(
form_data
:
KeysUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
OPENAI_API_KEYS
=
form_data
.
keys
return
{
"OPENAI_API_KEYS"
:
app
.
state
.
OPENAI_API_KEYS
}
app
.
state
.
config
.
OPENAI_API_KEYS
=
form_data
.
keys
return
{
"OPENAI_API_KEYS"
:
app
.
state
.
config
.
OPENAI_API_KEYS
}
@
app
.
post
(
"/audio/speech"
)
async
def
speech
(
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
idx
=
None
try
:
idx
=
app
.
state
.
OPENAI_API_BASE_URLS
.
index
(
"https://api.openai.com/v1"
)
idx
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
.
index
(
"https://api.openai.com/v1"
)
body
=
await
request
.
body
()
name
=
hashlib
.
sha256
(
body
).
hexdigest
()
...
...
@@ -114,7 +117,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return
FileResponse
(
file_path
)
headers
=
{}
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
OPENAI_API_KEYS
[
idx
]
}
"
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
}
"
headers
[
"Content-Type"
]
=
"application/json"
if
"openrouter.ai"
in
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]:
headers
[
'HTTP-Referer'
]
=
"https://openwebui.com/"
...
...
@@ -122,7 +125,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
r
=
None
try
:
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]
}
/audio/speech"
,
url
=
f
"
{
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
}
/audio/speech"
,
data
=
body
,
headers
=
headers
,
stream
=
True
,
...
...
@@ -182,7 +185,8 @@ def merge_models_lists(model_lists):
[
{
**
model
,
"urlIdx"
:
idx
}
for
model
in
models
if
"api.openai.com"
not
in
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]
if
"api.openai.com"
not
in
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
or
"gpt"
in
model
[
"id"
]
]
)
...
...
@@ -193,12 +197,15 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
log
.
info
(
"get_all_models()"
)
if
len
(
app
.
state
.
OPENAI_API_KEYS
)
==
1
and
app
.
state
.
OPENAI_API_KEYS
[
0
]
==
""
:
if
(
len
(
app
.
state
.
config
.
OPENAI_API_KEYS
)
==
1
and
app
.
state
.
config
.
OPENAI_API_KEYS
[
0
]
==
""
):
models
=
{
"data"
:
[]}
else
:
tasks
=
[
fetch_url
(
f
"
{
url
}
/models"
,
app
.
state
.
OPENAI_API_KEYS
[
idx
])
for
idx
,
url
in
enumerate
(
app
.
state
.
OPENAI_API_BASE_URLS
)
fetch_url
(
f
"
{
url
}
/models"
,
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
])
for
idx
,
url
in
enumerate
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
)
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
...
...
@@ -241,7 +248,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return
models
return
models
else
:
url
=
app
.
state
.
OPENAI_API_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
url_idx
]
r
=
None
...
...
@@ -305,8 +312,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except
json
.
JSONDecodeError
as
e
:
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
url
=
app
.
state
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
OPENAI_API_KEYS
[
idx
]
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
target_url
=
f
"
{
url
}
/
{
path
}
"
...
...
backend/apps/rag/main.py
View file @
8b0144cd
...
...
@@ -93,6 +93,7 @@ from config import (
RAG_TEMPLATE
,
ENABLE_RAG_LOCAL_WEB_FETCH
,
YOUTUBE_LOADER_LANGUAGE
,
AppConfig
,
)
from
constants
import
ERROR_MESSAGES
...
...
@@ -102,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app
=
FastAPI
()
app
.
state
.
TOP_K
=
RAG_TOP_K
app
.
state
.
RELEVANCE_THRESHOLD
=
RAG_RELEVANCE_THRESHOLD
app
.
state
.
config
=
AppConfig
()
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
=
ENABLE_RAG_HYBRID_SEARCH
app
.
state
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
=
(
app
.
state
.
config
.
TOP_K
=
RAG_TOP_K
app
.
state
.
config
.
RELEVANCE_THRESHOLD
=
RAG_RELEVANCE_THRESHOLD
app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
=
ENABLE_RAG_HYBRID_SEARCH
app
.
state
.
config
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
=
(
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app
.
state
.
CHUNK_SIZE
=
CHUNK_SIZE
app
.
state
.
CHUNK_OVERLAP
=
CHUNK_OVERLAP
app
.
state
.
config
.
CHUNK_SIZE
=
CHUNK_SIZE
app
.
state
.
config
.
CHUNK_OVERLAP
=
CHUNK_OVERLAP
app
.
state
.
RAG_EMBEDDING_ENGINE
=
RAG_EMBEDDING_ENGINE
app
.
state
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
RAG_RERANKING_MODEL
=
RAG_RERANKING_MODEL
app
.
state
.
RAG_TEMPLATE
=
RAG_TEMPLATE
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
=
RAG_EMBEDDING_ENGINE
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
=
RAG_EMBEDDING_MODEL
app
.
state
.
config
.
RAG_RERANKING_MODEL
=
RAG_RERANKING_MODEL
app
.
state
.
config
.
RAG_TEMPLATE
=
RAG_TEMPLATE
app
.
state
.
OPENAI_API_BASE_URL
=
RAG_OPENAI_API_BASE_URL
app
.
state
.
OPENAI_API_KEY
=
RAG_OPENAI_API_KEY
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
RAG_OPENAI_API_BASE_URL
app
.
state
.
config
.
OPENAI_API_KEY
=
RAG_OPENAI_API_KEY
app
.
state
.
PDF_EXTRACT_IMAGES
=
PDF_EXTRACT_IMAGES
app
.
state
.
config
.
PDF_EXTRACT_IMAGES
=
PDF_EXTRACT_IMAGES
app
.
state
.
YOUTUBE_LOADER_LANGUAGE
=
YOUTUBE_LOADER_LANGUAGE
app
.
state
.
config
.
YOUTUBE_LOADER_LANGUAGE
=
YOUTUBE_LOADER_LANGUAGE
app
.
state
.
YOUTUBE_LOADER_TRANSLATION
=
None
...
...
@@ -133,7 +136,7 @@ def update_embedding_model(
embedding_model
:
str
,
update_model
:
bool
=
False
,
):
if
embedding_model
and
app
.
state
.
RAG_EMBEDDING_ENGINE
==
""
:
if
embedding_model
and
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
==
""
:
app
.
state
.
sentence_transformer_ef
=
sentence_transformers
.
SentenceTransformer
(
get_model_path
(
embedding_model
,
update_model
),
device
=
DEVICE_TYPE
,
...
...
@@ -158,22 +161,22 @@ def update_reranking_model(
update_embedding_model
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
,
RAG_EMBEDDING_MODEL_AUTO_UPDATE
,
)
update_reranking_model
(
app
.
state
.
RAG_RERANKING_MODEL
,
app
.
state
.
config
.
RAG_RERANKING_MODEL
,
RAG_RERANKING_MODEL_AUTO_UPDATE
,
)
app
.
state
.
EMBEDDING_FUNCTION
=
get_embedding_function
(
app
.
state
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
sentence_transformer_ef
,
app
.
state
.
OPENAI_API_KEY
,
app
.
state
.
OPENAI_API_BASE_URL
,
app
.
state
.
config
.
OPENAI_API_KEY
,
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
)
origins
=
[
"*"
]
...
...
@@ -200,12 +203,12 @@ class UrlForm(CollectionNameForm):
async
def
get_status
():
return
{
"status"
:
True
,
"chunk_size"
:
app
.
state
.
CHUNK_SIZE
,
"chunk_overlap"
:
app
.
state
.
CHUNK_OVERLAP
,
"template"
:
app
.
state
.
RAG_TEMPLATE
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"reranking_model"
:
app
.
state
.
RAG_RERANKING_MODEL
,
"chunk_size"
:
app
.
state
.
config
.
CHUNK_SIZE
,
"chunk_overlap"
:
app
.
state
.
config
.
CHUNK_OVERLAP
,
"template"
:
app
.
state
.
config
.
RAG_TEMPLATE
,
"embedding_engine"
:
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
,
"reranking_model"
:
app
.
state
.
config
.
RAG_RERANKING_MODEL
,
}
...
...
@@ -213,18 +216,21 @@ async def get_status():
async
def
get_embedding_config
(
user
=
Depends
(
get_admin_user
)):
return
{
"status"
:
True
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"embedding_engine"
:
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
,
"openai_config"
:
{
"url"
:
app
.
state
.
OPENAI_API_BASE_URL
,
"key"
:
app
.
state
.
OPENAI_API_KEY
,
"url"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"key"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
},
}
@
app
.
get
(
"/reranking"
)
async
def
get_reraanking_config
(
user
=
Depends
(
get_admin_user
)):
return
{
"status"
:
True
,
"reranking_model"
:
app
.
state
.
RAG_RERANKING_MODEL
}
return
{
"status"
:
True
,
"reranking_model"
:
app
.
state
.
config
.
RAG_RERANKING_MODEL
,
}
class
OpenAIConfigForm
(
BaseModel
):
...
...
@@ -243,34 +249,34 @@ async def update_embedding_config(
form_data
:
EmbeddingModelUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
log
.
info
(
f
"Updating embedding model:
{
app
.
state
.
RAG_EMBEDDING_MODEL
}
to
{
form_data
.
embedding_model
}
"
f
"Updating embedding model:
{
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
}
to
{
form_data
.
embedding_model
}
"
)
try
:
app
.
state
.
RAG_EMBEDDING_ENGINE
=
form_data
.
embedding_engine
app
.
state
.
RAG_EMBEDDING_MODEL
=
form_data
.
embedding_model
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
=
form_data
.
embedding_engine
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
=
form_data
.
embedding_model
if
app
.
state
.
RAG_EMBEDDING_ENGINE
in
[
"ollama"
,
"openai"
]:
if
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
in
[
"ollama"
,
"openai"
]:
if
form_data
.
openai_config
!=
None
:
app
.
state
.
OPENAI_API_BASE_URL
=
form_data
.
openai_config
.
url
app
.
state
.
OPENAI_API_KEY
=
form_data
.
openai_config
.
key
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
form_data
.
openai_config
.
url
app
.
state
.
config
.
OPENAI_API_KEY
=
form_data
.
openai_config
.
key
update_embedding_model
(
app
.
state
.
RAG_EMBEDDING_MODEL
,
True
)
update_embedding_model
(
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
)
,
True
app
.
state
.
EMBEDDING_FUNCTION
=
get_embedding_function
(
app
.
state
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
sentence_transformer_ef
,
app
.
state
.
OPENAI_API_KEY
,
app
.
state
.
OPENAI_API_BASE_URL
,
app
.
state
.
config
.
OPENAI_API_KEY
,
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
)
return
{
"status"
:
True
,
"embedding_engine"
:
app
.
state
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
RAG_EMBEDDING_MODEL
,
"embedding_engine"
:
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
,
"embedding_model"
:
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
,
"openai_config"
:
{
"url"
:
app
.
state
.
OPENAI_API_BASE_URL
,
"key"
:
app
.
state
.
OPENAI_API_KEY
,
"url"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"key"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
},
}
except
Exception
as
e
:
...
...
@@ -290,16 +296,16 @@ async def update_reranking_config(
form_data
:
RerankingModelUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
log
.
info
(
f
"Updating reranking model:
{
app
.
state
.
RAG_RERANKING_MODEL
}
to
{
form_data
.
reranking_model
}
"
f
"Updating reranking model:
{
app
.
state
.
config
.
RAG_RERANKING_MODEL
}
to
{
form_data
.
reranking_model
}
"
)
try
:
app
.
state
.
RAG_RERANKING_MODEL
=
form_data
.
reranking_model
app
.
state
.
config
.
RAG_RERANKING_MODEL
=
form_data
.
reranking_model
update_reranking_model
(
app
.
state
.
RAG_RERANKING_MODEL
,
True
)
update_reranking_model
(
app
.
state
.
config
.
RAG_RERANKING_MODEL
)
,
True
return
{
"status"
:
True
,
"reranking_model"
:
app
.
state
.
RAG_RERANKING_MODEL
,
"reranking_model"
:
app
.
state
.
config
.
RAG_RERANKING_MODEL
,
}
except
Exception
as
e
:
log
.
exception
(
f
"Problem updating reranking model:
{
e
}
"
)
...
...
@@ -313,14 +319,14 @@ async def update_reranking_config(
async
def
get_rag_config
(
user
=
Depends
(
get_admin_user
)):
return
{
"status"
:
True
,
"pdf_extract_images"
:
app
.
state
.
PDF_EXTRACT_IMAGES
,
"pdf_extract_images"
:
app
.
state
.
config
.
PDF_EXTRACT_IMAGES
,
"chunk"
:
{
"chunk_size"
:
app
.
state
.
CHUNK_SIZE
,
"chunk_overlap"
:
app
.
state
.
CHUNK_OVERLAP
,
"chunk_size"
:
app
.
state
.
config
.
CHUNK_SIZE
,
"chunk_overlap"
:
app
.
state
.
config
.
CHUNK_OVERLAP
,
},
"web_loader_ssl_verification"
:
app
.
state
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
,
"web_loader_ssl_verification"
:
app
.
state
.
config
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
,
"youtube"
:
{
"language"
:
app
.
state
.
YOUTUBE_LOADER_LANGUAGE
,
"language"
:
app
.
state
.
config
.
YOUTUBE_LOADER_LANGUAGE
,
"translation"
:
app
.
state
.
YOUTUBE_LOADER_TRANSLATION
,
},
}
...
...
@@ -345,50 +351,52 @@ class ConfigUpdateForm(BaseModel):
@
app
.
post
(
"/config/update"
)
async
def
update_rag_config
(
form_data
:
ConfigUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
PDF_EXTRACT_IMAGES
=
(
app
.
state
.
config
.
PDF_EXTRACT_IMAGES
=
(
form_data
.
pdf_extract_images
if
form_data
.
pdf_extract_images
!=
None
else
app
.
state
.
PDF_EXTRACT_IMAGES
if
form_data
.
pdf_extract_images
is
not
None
else
app
.
state
.
config
.
PDF_EXTRACT_IMAGES
)
app
.
state
.
CHUNK_SIZE
=
(
form_data
.
chunk
.
chunk_size
if
form_data
.
chunk
!=
None
else
app
.
state
.
CHUNK_SIZE
app
.
state
.
config
.
CHUNK_SIZE
=
(
form_data
.
chunk
.
chunk_size
if
form_data
.
chunk
is
not
None
else
app
.
state
.
config
.
CHUNK_SIZE
)
app
.
state
.
CHUNK_OVERLAP
=
(
app
.
state
.
config
.
CHUNK_OVERLAP
=
(
form_data
.
chunk
.
chunk_overlap
if
form_data
.
chunk
!=
None
else
app
.
state
.
CHUNK_OVERLAP
if
form_data
.
chunk
is
not
None
else
app
.
state
.
config
.
CHUNK_OVERLAP
)
app
.
state
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
=
(
app
.
state
.
config
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
=
(
form_data
.
web_loader_ssl_verification
if
form_data
.
web_loader_ssl_verification
!=
None
else
app
.
state
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
else
app
.
state
.
config
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app
.
state
.
YOUTUBE_LOADER_LANGUAGE
=
(
app
.
state
.
config
.
YOUTUBE_LOADER_LANGUAGE
=
(
form_data
.
youtube
.
language
if
form_data
.
youtube
!=
None
else
app
.
state
.
YOUTUBE_LOADER_LANGUAGE
if
form_data
.
youtube
is
not
None
else
app
.
state
.
config
.
YOUTUBE_LOADER_LANGUAGE
)
app
.
state
.
YOUTUBE_LOADER_TRANSLATION
=
(
form_data
.
youtube
.
translation
if
form_data
.
youtube
!=
None
if
form_data
.
youtube
is
not
None
else
app
.
state
.
YOUTUBE_LOADER_TRANSLATION
)
return
{
"status"
:
True
,
"pdf_extract_images"
:
app
.
state
.
PDF_EXTRACT_IMAGES
,
"pdf_extract_images"
:
app
.
state
.
config
.
PDF_EXTRACT_IMAGES
,
"chunk"
:
{
"chunk_size"
:
app
.
state
.
CHUNK_SIZE
,
"chunk_overlap"
:
app
.
state
.
CHUNK_OVERLAP
,
"chunk_size"
:
app
.
state
.
config
.
CHUNK_SIZE
,
"chunk_overlap"
:
app
.
state
.
config
.
CHUNK_OVERLAP
,
},
"web_loader_ssl_verification"
:
app
.
state
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
,
"web_loader_ssl_verification"
:
app
.
state
.
config
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
,
"youtube"
:
{
"language"
:
app
.
state
.
YOUTUBE_LOADER_LANGUAGE
,
"language"
:
app
.
state
.
config
.
YOUTUBE_LOADER_LANGUAGE
,
"translation"
:
app
.
state
.
YOUTUBE_LOADER_TRANSLATION
,
},
}
...
...
@@ -398,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async
def
get_rag_template
(
user
=
Depends
(
get_current_user
)):
return
{
"status"
:
True
,
"template"
:
app
.
state
.
RAG_TEMPLATE
,
"template"
:
app
.
state
.
config
.
RAG_TEMPLATE
,
}
...
...
@@ -406,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async
def
get_query_settings
(
user
=
Depends
(
get_admin_user
)):
return
{
"status"
:
True
,
"template"
:
app
.
state
.
RAG_TEMPLATE
,
"k"
:
app
.
state
.
TOP_K
,
"r"
:
app
.
state
.
RELEVANCE_THRESHOLD
,
"hybrid"
:
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
,
"template"
:
app
.
state
.
config
.
RAG_TEMPLATE
,
"k"
:
app
.
state
.
config
.
TOP_K
,
"r"
:
app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
"hybrid"
:
app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
}
...
...
@@ -424,16 +432,20 @@ class QuerySettingsForm(BaseModel):
async
def
update_query_settings
(
form_data
:
QuerySettingsForm
,
user
=
Depends
(
get_admin_user
)
):
app
.
state
.
RAG_TEMPLATE
=
form_data
.
template
if
form_data
.
template
else
RAG_TEMPLATE
app
.
state
.
TOP_K
=
form_data
.
k
if
form_data
.
k
else
4
app
.
state
.
RELEVANCE_THRESHOLD
=
form_data
.
r
if
form_data
.
r
else
0.0
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
=
form_data
.
hybrid
if
form_data
.
hybrid
else
False
app
.
state
.
config
.
RAG_TEMPLATE
=
(
form_data
.
template
if
form_data
.
template
else
RAG_TEMPLATE
,
)
app
.
state
.
config
.
TOP_K
=
form_data
.
k
if
form_data
.
k
else
4
app
.
state
.
config
.
RELEVANCE_THRESHOLD
=
form_data
.
r
if
form_data
.
r
else
0.0
app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
=
(
form_data
.
hybrid
if
form_data
.
hybrid
else
False
,
)
return
{
"status"
:
True
,
"template"
:
app
.
state
.
RAG_TEMPLATE
,
"k"
:
app
.
state
.
TOP_K
,
"r"
:
app
.
state
.
RELEVANCE_THRESHOLD
,
"hybrid"
:
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
,
"template"
:
app
.
state
.
config
.
RAG_TEMPLATE
,
"k"
:
app
.
state
.
config
.
TOP_K
,
"r"
:
app
.
state
.
config
.
RELEVANCE_THRESHOLD
,
"hybrid"
:
app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
,
}
...
...
@@ -451,21 +463,23 @@ def query_doc_handler(
user
=
Depends
(
get_current_user
),
):
try
:
if
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
:
if
app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
:
return
query_doc_with_hybrid_search
(
collection_name
=
form_data
.
collection_name
,
query
=
form_data
.
query
,
embedding_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
config
.
TOP_K
,
reranking_function
=
app
.
state
.
sentence_transformer_rf
,
r
=
form_data
.
r
if
form_data
.
r
else
app
.
state
.
RELEVANCE_THRESHOLD
,
r
=
(
form_data
.
r
if
form_data
.
r
else
app
.
state
.
config
.
RELEVANCE_THRESHOLD
),
)
else
:
return
query_doc
(
collection_name
=
form_data
.
collection_name
,
query
=
form_data
.
query
,
embedding_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
config
.
TOP_K
,
)
except
Exception
as
e
:
log
.
exception
(
e
)
...
...
@@ -489,21 +503,23 @@ def query_collection_handler(
user
=
Depends
(
get_current_user
),
):
try
:
if
app
.
state
.
ENABLE_RAG_HYBRID_SEARCH
:
if
app
.
state
.
config
.
ENABLE_RAG_HYBRID_SEARCH
:
return
query_collection_with_hybrid_search
(
collection_names
=
form_data
.
collection_names
,
query
=
form_data
.
query
,
embedding_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
config
.
TOP_K
,
reranking_function
=
app
.
state
.
sentence_transformer_rf
,
r
=
form_data
.
r
if
form_data
.
r
else
app
.
state
.
RELEVANCE_THRESHOLD
,
r
=
(
form_data
.
r
if
form_data
.
r
else
app
.
state
.
config
.
RELEVANCE_THRESHOLD
),
)
else
:
return
query_collection
(
collection_names
=
form_data
.
collection_names
,
query
=
form_data
.
query
,
embedding_function
=
app
.
state
.
EMBEDDING_FUNCTION
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
TOP_K
,
k
=
form_data
.
k
if
form_data
.
k
else
app
.
state
.
config
.
TOP_K
,
)
except
Exception
as
e
:
...
...
@@ -520,7 +536,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader
=
YoutubeLoader
.
from_youtube_url
(
form_data
.
url
,
add_video_info
=
True
,
language
=
app
.
state
.
YOUTUBE_LOADER_LANGUAGE
,
language
=
app
.
state
.
config
.
YOUTUBE_LOADER_LANGUAGE
,
translation
=
app
.
state
.
YOUTUBE_LOADER_TRANSLATION
,
)
data
=
loader
.
load
()
...
...
@@ -548,7 +564,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try
:
loader
=
get_web_loader
(
form_data
.
url
,
verify_ssl
=
app
.
state
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
form_data
.
url
,
verify_ssl
=
app
.
state
.
config
.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
,
)
data
=
loader
.
load
()
...
...
@@ -604,8 +621,8 @@ def resolve_hostname(hostname):
def
store_data_in_vector_db
(
data
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
text_splitter
=
RecursiveCharacterTextSplitter
(
chunk_size
=
app
.
state
.
CHUNK_SIZE
,
chunk_overlap
=
app
.
state
.
CHUNK_OVERLAP
,
chunk_size
=
app
.
state
.
config
.
CHUNK_SIZE
,
chunk_overlap
=
app
.
state
.
config
.
CHUNK_OVERLAP
,
add_start_index
=
True
,
)
...
...
@@ -622,8 +639,8 @@ def store_text_in_vector_db(
text
,
metadata
,
collection_name
,
overwrite
:
bool
=
False
)
->
bool
:
text_splitter
=
RecursiveCharacterTextSplitter
(
chunk_size
=
app
.
state
.
CHUNK_SIZE
,
chunk_overlap
=
app
.
state
.
CHUNK_OVERLAP
,
chunk_size
=
app
.
state
.
config
.
CHUNK_SIZE
,
chunk_overlap
=
app
.
state
.
config
.
CHUNK_OVERLAP
,
add_start_index
=
True
,
)
docs
=
text_splitter
.
create_documents
([
text
],
metadatas
=
[
metadata
])
...
...
@@ -646,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection
=
CHROMA_CLIENT
.
create_collection
(
name
=
collection_name
)
embedding_func
=
get_embedding_function
(
app
.
state
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
config
.
RAG_EMBEDDING_ENGINE
,
app
.
state
.
config
.
RAG_EMBEDDING_MODEL
,
app
.
state
.
sentence_transformer_ef
,
app
.
state
.
OPENAI_API_KEY
,
app
.
state
.
OPENAI_API_BASE_URL
,
app
.
state
.
config
.
OPENAI_API_KEY
,
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
)
embedding_texts
=
list
(
map
(
lambda
x
:
x
.
replace
(
"
\n
"
,
" "
),
texts
))
...
...
@@ -724,7 +741,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
]
if
file_ext
==
"pdf"
:
loader
=
PyPDFLoader
(
file_path
,
extract_images
=
app
.
state
.
PDF_EXTRACT_IMAGES
)
loader
=
PyPDFLoader
(
file_path
,
extract_images
=
app
.
state
.
config
.
PDF_EXTRACT_IMAGES
)
elif
file_ext
==
"csv"
:
loader
=
CSVLoader
(
file_path
)
elif
file_ext
==
"rst"
:
...
...
backend/apps/web/main.py
View file @
8b0144cd
...
...
@@ -21,20 +21,24 @@ from config import (
USER_PERMISSIONS
,
WEBHOOK_URL
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
JWT_EXPIRES_IN
,
AppConfig
,
)
app
=
FastAPI
()
origins
=
[
"*"
]
app
.
state
.
ENABLE_SIGNUP
=
ENABLE_SIGNUP
app
.
state
.
JWT_EXPIRES_IN
=
"-1"
app
.
state
.
config
=
AppConfig
()
app
.
state
.
DEFAULT_MODELS
=
DEFAULT_MODELS
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
=
DEFAULT_PROMPT_SUGGESTIONS
app
.
state
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
ENABLE_SIGNUP
=
ENABLE_SIGNUP
app
.
state
.
config
.
JWT_EXPIRES_IN
=
JWT_EXPIRES_IN
app
.
state
.
config
.
DEFAULT_MODELS
=
DEFAULT_MODELS
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
=
DEFAULT_PROMPT_SUGGESTIONS
app
.
state
.
config
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
config
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
=
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app
.
add_middleware
(
...
...
@@ -61,6 +65,6 @@ async def get_status():
return
{
"status"
:
True
,
"auth"
:
WEBUI_AUTH
,
"default_models"
:
app
.
state
.
DEFAULT_MODELS
,
"default_prompt_suggestions"
:
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
,
"default_models"
:
app
.
state
.
config
.
DEFAULT_MODELS
,
"default_prompt_suggestions"
:
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
,
}
backend/apps/web/routers/auths.py
View file @
8b0144cd
...
...
@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if
user
:
token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
request
.
app
.
state
.
JWT_EXPIRES_IN
),
expires_delta
=
parse_duration
(
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
return
{
...
...
@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@
router
.
post
(
"/signup"
,
response_model
=
SigninResponse
)
async
def
signup
(
request
:
Request
,
form_data
:
SignupForm
):
if
not
request
.
app
.
state
.
ENABLE_SIGNUP
and
WEBUI_AUTH
:
if
not
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
and
WEBUI_AUTH
:
raise
HTTPException
(
status
.
HTTP_403_FORBIDDEN
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
...
...
@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role
=
(
"admin"
if
Users
.
get_num_users
()
==
0
else
request
.
app
.
state
.
DEFAULT_USER_ROLE
else
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
)
hashed
=
get_password_hash
(
form_data
.
password
)
user
=
Auths
.
insert_new_auth
(
...
...
@@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
if
user
:
token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
request
.
app
.
state
.
JWT_EXPIRES_IN
),
expires_delta
=
parse_duration
(
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
# response.set_cookie(key='token', value=token, httponly=True)
if
request
.
app
.
state
.
WEBHOOK_URL
:
if
request
.
app
.
state
.
config
.
WEBHOOK_URL
:
post_webhook
(
request
.
app
.
state
.
WEBHOOK_URL
,
request
.
app
.
state
.
config
.
WEBHOOK_URL
,
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
{
"action"
:
"signup"
,
...
...
@@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@
router
.
get
(
"/signup/enabled"
,
response_model
=
bool
)
async
def
get_sign_up_status
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
request
.
app
.
state
.
ENABLE_SIGNUP
return
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
@
router
.
get
(
"/signup/enabled/toggle"
,
response_model
=
bool
)
async
def
toggle_sign_up
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
request
.
app
.
state
.
ENABLE_SIGNUP
=
not
request
.
app
.
state
.
ENABLE_SIGNUP
return
request
.
app
.
state
.
ENABLE_SIGNUP
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
=
not
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
return
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
############################
...
...
@@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@
router
.
get
(
"/signup/user/role"
)
async
def
get_default_user_role
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
request
.
app
.
state
.
DEFAULT_USER_ROLE
return
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
class
UpdateRoleForm
(
BaseModel
):
...
...
@@ -304,8 +304,8 @@ async def update_default_user_role(
request
:
Request
,
form_data
:
UpdateRoleForm
,
user
=
Depends
(
get_admin_user
)
):
if
form_data
.
role
in
[
"pending"
,
"user"
,
"admin"
]:
request
.
app
.
state
.
DEFAULT_USER_ROLE
=
form_data
.
role
return
request
.
app
.
state
.
DEFAULT_USER_ROLE
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
=
form_data
.
role
return
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
############################
...
...
@@ -315,7 +315,7 @@ async def update_default_user_role(
@
router
.
get
(
"/token/expires"
)
async
def
get_token_expires_duration
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
request
.
app
.
state
.
JWT_EXPIRES_IN
return
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
class
UpdateJWTExpiresDurationForm
(
BaseModel
):
...
...
@@ -332,10 +332,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern
if
re
.
match
(
pattern
,
form_data
.
duration
):
request
.
app
.
state
.
JWT_EXPIRES_IN
=
form_data
.
duration
return
request
.
app
.
state
.
JWT_EXPIRES_IN
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
=
form_data
.
duration
return
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
else
:
return
request
.
app
.
state
.
JWT_EXPIRES_IN
return
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
############################
...
...
backend/apps/web/routers/configs.py
View file @
8b0144cd
...
...
@@ -44,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
async
def
set_global_default_models
(
request
:
Request
,
form_data
:
SetDefaultModelsForm
,
user
=
Depends
(
get_admin_user
)
):
request
.
app
.
state
.
DEFAULT_MODELS
=
form_data
.
models
return
request
.
app
.
state
.
DEFAULT_MODELS
request
.
app
.
state
.
config
.
DEFAULT_MODELS
=
form_data
.
models
return
request
.
app
.
state
.
config
.
DEFAULT_MODELS
@
router
.
post
(
"/default/suggestions"
,
response_model
=
List
[
PromptSuggestion
])
...
...
@@ -55,5 +55,5 @@ async def set_global_default_suggestions(
user
=
Depends
(
get_admin_user
),
):
data
=
form_data
.
model_dump
()
request
.
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
=
data
[
"suggestions"
]
return
request
.
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
request
.
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
=
data
[
"suggestions"
]
return
request
.
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
backend/apps/web/routers/users.py
View file @
8b0144cd
...
...
@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
@
router
.
get
(
"/permissions/user"
)
async
def
get_user_permissions
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
request
.
app
.
state
.
USER_PERMISSIONS
return
request
.
app
.
state
.
config
.
USER_PERMISSIONS
@
router
.
post
(
"/permissions/user"
)
async
def
update_user_permissions
(
request
:
Request
,
form_data
:
dict
,
user
=
Depends
(
get_admin_user
)
):
request
.
app
.
state
.
USER_PERMISSIONS
=
form_data
return
request
.
app
.
state
.
USER_PERMISSIONS
request
.
app
.
state
.
config
.
USER_PERMISSIONS
=
form_data
return
request
.
app
.
state
.
config
.
USER_PERMISSIONS
############################
...
...
backend/config.py
View file @
8b0144cd
...
...
@@ -5,6 +5,7 @@ import chromadb
from
chromadb
import
Settings
from
base64
import
b64encode
from
bs4
import
BeautifulSoup
from
typing
import
TypeVar
,
Generic
,
Union
from
pathlib
import
Path
import
json
...
...
@@ -17,7 +18,6 @@ import shutil
from
secrets
import
token_bytes
from
constants
import
ERROR_MESSAGES
####################################
# Load .env file
####################################
...
...
@@ -71,7 +71,6 @@ for source in log_sources:
log
.
setLevel
(
SRC_LOG_LEVELS
[
"CONFIG"
])
WEBUI_NAME
=
os
.
environ
.
get
(
"WEBUI_NAME"
,
"Open WebUI"
)
if
WEBUI_NAME
!=
"Open WebUI"
:
WEBUI_NAME
+=
" (Open WebUI)"
...
...
@@ -161,16 +160,6 @@ CHANGELOG = changelog_json
WEBUI_VERSION
=
os
.
environ
.
get
(
"WEBUI_VERSION"
,
"v1.0.0-alpha.100"
)
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH
=
os
.
environ
.
get
(
"WEBUI_AUTH"
,
"True"
).
lower
()
==
"true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
=
os
.
environ
.
get
(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER"
,
None
)
####################################
# DATA/FRONTEND BUILD DIR
####################################
...
...
@@ -184,6 +173,108 @@ try:
except
:
CONFIG_DATA
=
{}
####################################
# Config helpers
####################################
def
save_config
():
try
:
with
open
(
f
"
{
DATA_DIR
}
/config.json"
,
"w"
)
as
f
:
json
.
dump
(
CONFIG_DATA
,
f
,
indent
=
"
\t
"
)
except
Exception
as
e
:
log
.
exception
(
e
)
def
get_config_value
(
config_path
:
str
):
path_parts
=
config_path
.
split
(
"."
)
cur_config
=
CONFIG_DATA
for
key
in
path_parts
:
if
key
in
cur_config
:
cur_config
=
cur_config
[
key
]
else
:
return
None
return
cur_config
T
=
TypeVar
(
"T"
)
class
PersistentConfig
(
Generic
[
T
]):
def
__init__
(
self
,
env_name
:
str
,
config_path
:
str
,
env_value
:
T
):
self
.
env_name
=
env_name
self
.
config_path
=
config_path
self
.
env_value
=
env_value
self
.
config_value
=
get_config_value
(
config_path
)
if
self
.
config_value
is
not
None
:
log
.
info
(
f
"'
{
env_name
}
' loaded from config.json"
)
self
.
value
=
self
.
config_value
else
:
self
.
value
=
env_value
def
__str__
(
self
):
return
str
(
self
.
value
)
@
property
def
__dict__
(
self
):
raise
TypeError
(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
def
__getattribute__
(
self
,
item
):
if
item
==
"__dict__"
:
raise
TypeError
(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
return
super
().
__getattribute__
(
item
)
def
save
(
self
):
# Don't save if the value is the same as the env value and the config value
if
self
.
env_value
==
self
.
value
:
if
self
.
config_value
==
self
.
value
:
return
log
.
info
(
f
"Saving '
{
self
.
env_name
}
' to config.json"
)
path_parts
=
self
.
config_path
.
split
(
"."
)
config
=
CONFIG_DATA
for
key
in
path_parts
[:
-
1
]:
if
key
not
in
config
:
config
[
key
]
=
{}
config
=
config
[
key
]
config
[
path_parts
[
-
1
]]
=
self
.
value
save_config
()
self
.
config_value
=
self
.
value
class
AppConfig
:
_state
:
dict
[
str
,
PersistentConfig
]
def
__init__
(
self
):
super
().
__setattr__
(
"_state"
,
{})
def
__setattr__
(
self
,
key
,
value
):
if
isinstance
(
value
,
PersistentConfig
):
self
.
_state
[
key
]
=
value
else
:
self
.
_state
[
key
].
value
=
value
self
.
_state
[
key
].
save
()
def
__getattr__
(
self
,
key
):
return
self
.
_state
[
key
].
value
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH
=
os
.
environ
.
get
(
"WEBUI_AUTH"
,
"True"
).
lower
()
==
"true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
=
os
.
environ
.
get
(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER"
,
None
)
JWT_EXPIRES_IN
=
PersistentConfig
(
"JWT_EXPIRES_IN"
,
"auth.jwt_expiry"
,
os
.
environ
.
get
(
"JWT_EXPIRES_IN"
,
"-1"
)
)
####################################
# Static DIR
####################################
...
...
@@ -318,7 +409,9 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
if
OLLAMA_BASE_URLS
!=
""
else
OLLAMA_BASE_URL
OLLAMA_BASE_URLS
=
[
url
.
strip
()
for
url
in
OLLAMA_BASE_URLS
.
split
(
";"
)]
OLLAMA_BASE_URLS
=
PersistentConfig
(
"OLLAMA_BASE_URLS"
,
"ollama.base_urls"
,
OLLAMA_BASE_URLS
)
####################################
# OPENAI_API
...
...
@@ -335,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS
=
OPENAI_API_KEYS
if
OPENAI_API_KEYS
!=
""
else
OPENAI_API_KEY
OPENAI_API_KEYS
=
[
url
.
strip
()
for
url
in
OPENAI_API_KEYS
.
split
(
";"
)]
OPENAI_API_KEYS
=
PersistentConfig
(
"OPENAI_API_KEYS"
,
"openai.api_keys"
,
OPENAI_API_KEYS
)
OPENAI_API_BASE_URLS
=
os
.
environ
.
get
(
"OPENAI_API_BASE_URLS"
,
""
)
OPENAI_API_BASE_URLS
=
(
...
...
@@ -346,37 +441,42 @@ OPENAI_API_BASE_URLS = [
url
.
strip
()
if
url
!=
""
else
"https://api.openai.com/v1"
for
url
in
OPENAI_API_BASE_URLS
.
split
(
";"
)
]
OPENAI_API_BASE_URLS
=
PersistentConfig
(
"OPENAI_API_BASE_URLS"
,
"openai.api_base_urls"
,
OPENAI_API_BASE_URLS
)
OPENAI_API_KEY
=
""
try
:
OPENAI_API_KEY
=
OPENAI_API_KEYS
[
OPENAI_API_BASE_URLS
.
index
(
"https://api.openai.com/v1"
)
OPENAI_API_KEY
=
OPENAI_API_KEYS
.
value
[
OPENAI_API_BASE_URLS
.
value
.
index
(
"https://api.openai.com/v1"
)
]
except
:
pass
OPENAI_API_BASE_URL
=
"https://api.openai.com/v1"
####################################
# WEBUI
####################################
ENABLE_SIGNUP
=
(
False
if
WEBUI_AUTH
==
False
else
os
.
environ
.
get
(
"ENABLE_SIGNUP"
,
"True"
).
lower
()
==
"true"
ENABLE_SIGNUP
=
PersistentConfig
(
"ENABLE_SIGNUP"
,
"ui.enable_signup"
,
(
False
if
not
WEBUI_AUTH
else
os
.
environ
.
get
(
"ENABLE_SIGNUP"
,
"True"
).
lower
()
==
"true"
),
)
DEFAULT_MODELS
=
PersistentConfig
(
"DEFAULT_MODELS"
,
"ui.default_models"
,
os
.
environ
.
get
(
"DEFAULT_MODELS"
,
None
)
)
DEFAULT_MODELS
=
os
.
environ
.
get
(
"DEFAULT_MODELS"
,
None
)
DEFAULT_PROMPT_SUGGESTIONS
=
(
CONFIG_DATA
[
"ui"
][
"prompt_suggestions"
]
if
"ui"
in
CONFIG_DATA
and
"prompt_suggestions"
in
CONFIG_DATA
[
"ui"
]
and
type
(
CONFIG_DATA
[
"ui"
][
"prompt_suggestions"
])
is
list
else
[
DEFAULT_PROMPT_SUGGESTIONS
=
PersistentConfig
(
"DEFAULT_PROMPT_SUGGESTIONS"
,
"ui.prompt_suggestions"
,
[
{
"title"
:
[
"Help me study"
,
"vocabulary for a college entrance exam"
],
"content"
:
"Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
,
...
...
@@ -404,23 +504,40 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title"
:
[
"Overcome procrastination"
,
"give me tips"
],
"content"
:
"Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?"
,
},
]
]
,
)
DEFAULT_USER_ROLE
=
os
.
getenv
(
"DEFAULT_USER_ROLE"
,
"pending"
)
DEFAULT_USER_ROLE
=
PersistentConfig
(
"DEFAULT_USER_ROLE"
,
"ui.default_user_role"
,
os
.
getenv
(
"DEFAULT_USER_ROLE"
,
"pending"
),
)
USER_PERMISSIONS_CHAT_DELETION
=
(
os
.
environ
.
get
(
"USER_PERMISSIONS_CHAT_DELETION"
,
"True"
).
lower
()
==
"true"
)
USER_PERMISSIONS
=
{
"chat"
:
{
"deletion"
:
USER_PERMISSIONS_CHAT_DELETION
}}
USER_PERMISSIONS
=
PersistentConfig
(
"USER_PERMISSIONS"
,
"ui.user_permissions"
,
{
"chat"
:
{
"deletion"
:
USER_PERMISSIONS_CHAT_DELETION
}},
)
ENABLE_MODEL_FILTER
=
os
.
environ
.
get
(
"ENABLE_MODEL_FILTER"
,
"False"
).
lower
()
==
"true"
ENABLE_MODEL_FILTER
=
PersistentConfig
(
"ENABLE_MODEL_FILTER"
,
"model_filter.enable"
,
os
.
environ
.
get
(
"ENABLE_MODEL_FILTER"
,
"False"
).
lower
()
==
"true"
,
)
MODEL_FILTER_LIST
=
os
.
environ
.
get
(
"MODEL_FILTER_LIST"
,
""
)
MODEL_FILTER_LIST
=
[
model
.
strip
()
for
model
in
MODEL_FILTER_LIST
.
split
(
";"
)]
MODEL_FILTER_LIST
=
PersistentConfig
(
"MODEL_FILTER_LIST"
,
"model_filter.list"
,
[
model
.
strip
()
for
model
in
MODEL_FILTER_LIST
.
split
(
";"
)],
)
WEBHOOK_URL
=
os
.
environ
.
get
(
"WEBHOOK_URL"
,
""
)
WEBHOOK_URL
=
PersistentConfig
(
"WEBHOOK_URL"
,
"webhook_url"
,
os
.
environ
.
get
(
"WEBHOOK_URL"
,
""
)
)
ENABLE_ADMIN_EXPORT
=
os
.
environ
.
get
(
"ENABLE_ADMIN_EXPORT"
,
"True"
).
lower
()
==
"true"
...
...
@@ -458,26 +575,45 @@ else:
CHROMA_HTTP_SSL
=
os
.
environ
.
get
(
"CHROMA_HTTP_SSL"
,
"false"
).
lower
()
==
"true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K
=
int
(
os
.
environ
.
get
(
"RAG_TOP_K"
,
"5"
))
RAG_RELEVANCE_THRESHOLD
=
float
(
os
.
environ
.
get
(
"RAG_RELEVANCE_THRESHOLD"
,
"0.0"
))
ENABLE_RAG_HYBRID_SEARCH
=
(
os
.
environ
.
get
(
"ENABLE_RAG_HYBRID_SEARCH"
,
""
).
lower
()
==
"true"
RAG_TOP_K
=
PersistentConfig
(
"RAG_TOP_K"
,
"rag.top_k"
,
int
(
os
.
environ
.
get
(
"RAG_TOP_K"
,
"5"
))
)
RAG_RELEVANCE_THRESHOLD
=
PersistentConfig
(
"RAG_RELEVANCE_THRESHOLD"
,
"rag.relevance_threshold"
,
float
(
os
.
environ
.
get
(
"RAG_RELEVANCE_THRESHOLD"
,
"0.0"
)),
)
ENABLE_RAG_HYBRID_SEARCH
=
PersistentConfig
(
"ENABLE_RAG_HYBRID_SEARCH"
,
"rag.enable_hybrid_search"
,
os
.
environ
.
get
(
"ENABLE_RAG_HYBRID_SEARCH"
,
""
).
lower
()
==
"true"
,
)
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
=
(
os
.
environ
.
get
(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION"
,
"True"
).
lower
()
==
"true"
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
=
PersistentConfig
(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION"
,
"rag.enable_web_loader_ssl_verification"
,
os
.
environ
.
get
(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION"
,
"True"
).
lower
()
==
"true"
,
)
RAG_EMBEDDING_ENGINE
=
os
.
environ
.
get
(
"RAG_EMBEDDING_ENGINE"
,
""
)
RAG_EMBEDDING_ENGINE
=
PersistentConfig
(
"RAG_EMBEDDING_ENGINE"
,
"rag.embedding_engine"
,
os
.
environ
.
get
(
"RAG_EMBEDDING_ENGINE"
,
""
),
)
PDF_EXTRACT_IMAGES
=
os
.
environ
.
get
(
"PDF_EXTRACT_IMAGES"
,
"False"
).
lower
()
==
"true"
PDF_EXTRACT_IMAGES
=
PersistentConfig
(
"PDF_EXTRACT_IMAGES"
,
"rag.pdf_extract_images"
,
os
.
environ
.
get
(
"PDF_EXTRACT_IMAGES"
,
"False"
).
lower
()
==
"true"
,
)
RAG_EMBEDDING_MODEL
=
os
.
environ
.
get
(
"RAG_EMBEDDING_MODEL"
,
"sentence-transformers/all-MiniLM-L6-v2"
RAG_EMBEDDING_MODEL
=
PersistentConfig
(
"RAG_EMBEDDING_MODEL"
,
"rag.embedding_model"
,
os
.
environ
.
get
(
"RAG_EMBEDDING_MODEL"
,
"sentence-transformers/all-MiniLM-L6-v2"
),
)
log
.
info
(
f
"Embedding model set:
{
RAG_EMBEDDING_MODEL
}
"
),
log
.
info
(
f
"Embedding model set:
{
RAG_EMBEDDING_MODEL
.
value
}
"
),
RAG_EMBEDDING_MODEL_AUTO_UPDATE
=
(
os
.
environ
.
get
(
"RAG_EMBEDDING_MODEL_AUTO_UPDATE"
,
""
).
lower
()
==
"true"
...
...
@@ -487,9 +623,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os
.
environ
.
get
(
"RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE"
,
""
).
lower
()
==
"true"
)
RAG_RERANKING_MODEL
=
os
.
environ
.
get
(
"RAG_RERANKING_MODEL"
,
""
)
if
not
RAG_RERANKING_MODEL
==
""
:
log
.
info
(
f
"Reranking model set:
{
RAG_RERANKING_MODEL
}
"
),
RAG_RERANKING_MODEL
=
PersistentConfig
(
"RAG_RERANKING_MODEL"
,
"rag.reranking_model"
,
os
.
environ
.
get
(
"RAG_RERANKING_MODEL"
,
""
),
)
if
RAG_RERANKING_MODEL
.
value
!=
""
:
log
.
info
(
f
"Reranking model set:
{
RAG_RERANKING_MODEL
.
value
}
"
),
RAG_RERANKING_MODEL_AUTO_UPDATE
=
(
os
.
environ
.
get
(
"RAG_RERANKING_MODEL_AUTO_UPDATE"
,
""
).
lower
()
==
"true"
...
...
@@ -527,9 +667,14 @@ if USE_CUDA.lower() == "true":
else
:
DEVICE_TYPE
=
"cpu"
CHUNK_SIZE
=
int
(
os
.
environ
.
get
(
"CHUNK_SIZE"
,
"1500"
))
CHUNK_OVERLAP
=
int
(
os
.
environ
.
get
(
"CHUNK_OVERLAP"
,
"100"
))
CHUNK_SIZE
=
PersistentConfig
(
"CHUNK_SIZE"
,
"rag.chunk_size"
,
int
(
os
.
environ
.
get
(
"CHUNK_SIZE"
,
"1500"
))
)
CHUNK_OVERLAP
=
PersistentConfig
(
"CHUNK_OVERLAP"
,
"rag.chunk_overlap"
,
int
(
os
.
environ
.
get
(
"CHUNK_OVERLAP"
,
"100"
)),
)
DEFAULT_RAG_TEMPLATE
=
"""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
...
...
@@ -545,16 +690,32 @@ And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
RAG_TEMPLATE
=
os
.
environ
.
get
(
"RAG_TEMPLATE"
,
DEFAULT_RAG_TEMPLATE
)
RAG_TEMPLATE
=
PersistentConfig
(
"RAG_TEMPLATE"
,
"rag.template"
,
os
.
environ
.
get
(
"RAG_TEMPLATE"
,
DEFAULT_RAG_TEMPLATE
),
)
RAG_OPENAI_API_BASE_URL
=
os
.
getenv
(
"RAG_OPENAI_API_BASE_URL"
,
OPENAI_API_BASE_URL
)
RAG_OPENAI_API_KEY
=
os
.
getenv
(
"RAG_OPENAI_API_KEY"
,
OPENAI_API_KEY
)
RAG_OPENAI_API_BASE_URL
=
PersistentConfig
(
"RAG_OPENAI_API_BASE_URL"
,
"rag.openai_api_base_url"
,
os
.
getenv
(
"RAG_OPENAI_API_BASE_URL"
,
OPENAI_API_BASE_URL
),
)
RAG_OPENAI_API_KEY
=
PersistentConfig
(
"RAG_OPENAI_API_KEY"
,
"rag.openai_api_key"
,
os
.
getenv
(
"RAG_OPENAI_API_KEY"
,
OPENAI_API_KEY
),
)
ENABLE_RAG_LOCAL_WEB_FETCH
=
(
os
.
getenv
(
"ENABLE_RAG_LOCAL_WEB_FETCH"
,
"False"
).
lower
()
==
"true"
)
YOUTUBE_LOADER_LANGUAGE
=
os
.
getenv
(
"YOUTUBE_LOADER_LANGUAGE"
,
"en"
).
split
(
","
)
YOUTUBE_LOADER_LANGUAGE
=
PersistentConfig
(
"YOUTUBE_LOADER_LANGUAGE"
,
"rag.youtube_loader_language"
,
os
.
getenv
(
"YOUTUBE_LOADER_LANGUAGE"
,
"en"
).
split
(
","
),
)
####################################
# Transcribe
...
...
@@ -571,34 +732,78 @@ WHISPER_MODEL_AUTO_UPDATE = (
# Images
####################################
IMAGE_GENERATION_ENGINE
=
os
.
getenv
(
"IMAGE_GENERATION_ENGINE"
,
""
)
IMAGE_GENERATION_ENGINE
=
PersistentConfig
(
"IMAGE_GENERATION_ENGINE"
,
"image_generation.engine"
,
os
.
getenv
(
"IMAGE_GENERATION_ENGINE"
,
""
),
)
ENABLE_IMAGE_GENERATION
=
(
os
.
environ
.
get
(
"ENABLE_IMAGE_GENERATION"
,
""
).
lower
()
==
"true"
ENABLE_IMAGE_GENERATION
=
PersistentConfig
(
"ENABLE_IMAGE_GENERATION"
,
"image_generation.enable"
,
os
.
environ
.
get
(
"ENABLE_IMAGE_GENERATION"
,
""
).
lower
()
==
"true"
,
)
AUTOMATIC1111_BASE_URL
=
PersistentConfig
(
"AUTOMATIC1111_BASE_URL"
,
"image_generation.automatic1111.base_url"
,
os
.
getenv
(
"AUTOMATIC1111_BASE_URL"
,
""
),
)
AUTOMATIC1111_BASE_URL
=
os
.
getenv
(
"AUTOMATIC1111_BASE_URL"
,
""
)
COMFYUI_BASE_URL
=
os
.
getenv
(
"COMFYUI_BASE_URL"
,
""
)
COMFYUI_BASE_URL
=
PersistentConfig
(
"COMFYUI_BASE_URL"
,
"image_generation.comfyui.base_url"
,
os
.
getenv
(
"COMFYUI_BASE_URL"
,
""
),
)
IMAGES_OPENAI_API_BASE_URL
=
os
.
getenv
(
"IMAGES_OPENAI_API_BASE_URL"
,
OPENAI_API_BASE_URL
IMAGES_OPENAI_API_BASE_URL
=
PersistentConfig
(
"IMAGES_OPENAI_API_BASE_URL"
,
"image_generation.openai.api_base_url"
,
os
.
getenv
(
"IMAGES_OPENAI_API_BASE_URL"
,
OPENAI_API_BASE_URL
),
)
IMAGES_OPENAI_API_KEY
=
PersistentConfig
(
"IMAGES_OPENAI_API_KEY"
,
"image_generation.openai.api_key"
,
os
.
getenv
(
"IMAGES_OPENAI_API_KEY"
,
OPENAI_API_KEY
),
)
IMAGES_OPENAI_API_KEY
=
os
.
getenv
(
"IMAGES_OPENAI_API_KEY"
,
OPENAI_API_KEY
)
IMAGE_SIZE
=
os
.
getenv
(
"IMAGE_SIZE"
,
"512x512"
)
IMAGE_SIZE
=
PersistentConfig
(
"IMAGE_SIZE"
,
"image_generation.size"
,
os
.
getenv
(
"IMAGE_SIZE"
,
"512x512"
)
)
IMAGE_STEPS
=
int
(
os
.
getenv
(
"IMAGE_STEPS"
,
50
))
IMAGE_STEPS
=
PersistentConfig
(
"IMAGE_STEPS"
,
"image_generation.steps"
,
int
(
os
.
getenv
(
"IMAGE_STEPS"
,
50
))
)
IMAGE_GENERATION_MODEL
=
os
.
getenv
(
"IMAGE_GENERATION_MODEL"
,
""
)
IMAGE_GENERATION_MODEL
=
PersistentConfig
(
"IMAGE_GENERATION_MODEL"
,
"image_generation.model"
,
os
.
getenv
(
"IMAGE_GENERATION_MODEL"
,
""
),
)
####################################
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL
=
os
.
getenv
(
"AUDIO_OPENAI_API_BASE_URL"
,
OPENAI_API_BASE_URL
)
AUDIO_OPENAI_API_KEY
=
os
.
getenv
(
"AUDIO_OPENAI_API_KEY"
,
OPENAI_API_KEY
)
AUDIO_OPENAI_API_MODEL
=
os
.
getenv
(
"AUDIO_OPENAI_API_MODEL"
,
"tts-1"
)
AUDIO_OPENAI_API_VOICE
=
os
.
getenv
(
"AUDIO_OPENAI_API_VOICE"
,
"alloy"
)
AUDIO_OPENAI_API_BASE_URL
=
PersistentConfig
(
"AUDIO_OPENAI_API_BASE_URL"
,
"audio.openai.api_base_url"
,
os
.
getenv
(
"AUDIO_OPENAI_API_BASE_URL"
,
OPENAI_API_BASE_URL
),
)
AUDIO_OPENAI_API_KEY
=
PersistentConfig
(
"AUDIO_OPENAI_API_KEY"
,
"audio.openai.api_key"
,
os
.
getenv
(
"AUDIO_OPENAI_API_KEY"
,
OPENAI_API_KEY
),
)
AUDIO_OPENAI_API_MODEL
=
PersistentConfig
(
"AUDIO_OPENAI_API_MODEL"
,
"audio.openai.api_model"
,
os
.
getenv
(
"AUDIO_OPENAI_API_MODEL"
,
"tts-1"
),
)
AUDIO_OPENAI_API_VOICE
=
PersistentConfig
(
"AUDIO_OPENAI_API_VOICE"
,
"audio.openai.api_voice"
,
os
.
getenv
(
"AUDIO_OPENAI_API_VOICE"
,
"alloy"
),
)
####################################
# LiteLLM
...
...
backend/main.py
View file @
8b0144cd
...
...
@@ -59,6 +59,7 @@ from config import (
SRC_LOG_LEVELS
,
WEBHOOK_URL
,
ENABLE_ADMIN_EXPORT
,
AppConfig
,
)
from
constants
import
ERROR_MESSAGES
...
...
@@ -107,10 +108,11 @@ app = FastAPI(
docs_url
=
"/docs"
if
ENV
==
"dev"
else
None
,
redoc_url
=
None
,
lifespan
=
lifespan
)
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
=
AppConfig
()
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
origins
=
[
"*"
]
...
...
@@ -250,9 +252,9 @@ async def get_app_config():
"version"
:
VERSION
,
"auth"
:
WEBUI_AUTH
,
"default_locale"
:
default_locale
,
"images"
:
images_app
.
state
.
ENABLED
,
"default_models"
:
webui_app
.
state
.
DEFAULT_MODELS
,
"default_prompt_suggestions"
:
webui_app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
,
"images"
:
images_app
.
state
.
config
.
ENABLED
,
"default_models"
:
webui_app
.
state
.
config
.
DEFAULT_MODELS
,
"default_prompt_suggestions"
:
webui_app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
,
"trusted_header_auth"
:
bool
(
webui_app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
),
"admin_export_enabled"
:
ENABLE_ADMIN_EXPORT
,
}
...
...
@@ -261,8 +263,8 @@ async def get_app_config():
@
app
.
get
(
"/api/config/model/filter"
)
async
def
get_model_filter_config
(
user
=
Depends
(
get_admin_user
)):
return
{
"enabled"
:
app
.
state
.
ENABLE_MODEL_FILTER
,
"models"
:
app
.
state
.
MODEL_FILTER_LIST
,
"enabled"
:
app
.
state
.
config
.
ENABLE_MODEL_FILTER
,
"models"
:
app
.
state
.
config
.
MODEL_FILTER_LIST
,
}
...
...
@@ -275,28 +277,28 @@ class ModelFilterConfigForm(BaseModel):
async
def
update_model_filter_config
(
form_data
:
ModelFilterConfigForm
,
user
=
Depends
(
get_admin_user
)
):
app
.
state
.
ENABLE_MODEL_FILTER
=
form_data
.
enabled
app
.
state
.
MODEL_FILTER_LIST
=
form_data
.
models
app
.
state
.
config
.
ENABLE_MODEL_FILTER
,
form_data
.
enabled
app
.
state
.
config
.
MODEL_FILTER_LIST
,
form_data
.
models
ollama_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
ENABLE_MODEL_FILTER
ollama_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
MODEL_FILTER_LIST
ollama_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
config
.
ENABLE_MODEL_FILTER
ollama_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
config
.
MODEL_FILTER_LIST
openai_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
ENABLE_MODEL_FILTER
openai_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
MODEL_FILTER_LIST
openai_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
config
.
ENABLE_MODEL_FILTER
openai_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
config
.
MODEL_FILTER_LIST
litellm_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
ENABLE_MODEL_FILTER
litellm_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
MODEL_FILTER_LIST
litellm_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
config
.
ENABLE_MODEL_FILTER
litellm_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
config
.
MODEL_FILTER_LIST
return
{
"enabled"
:
app
.
state
.
ENABLE_MODEL_FILTER
,
"models"
:
app
.
state
.
MODEL_FILTER_LIST
,
"enabled"
:
app
.
state
.
config
.
ENABLE_MODEL_FILTER
,
"models"
:
app
.
state
.
config
.
MODEL_FILTER_LIST
,
}
@
app
.
get
(
"/api/webhook"
)
async
def
get_webhook_url
(
user
=
Depends
(
get_admin_user
)):
return
{
"url"
:
app
.
state
.
WEBHOOK_URL
,
"url"
:
app
.
state
.
config
.
WEBHOOK_URL
,
}
...
...
@@ -306,12 +308,12 @@ class UrlForm(BaseModel):
@
app
.
post
(
"/api/webhook"
)
async
def
update_webhook_url
(
form_data
:
UrlForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
WEBHOOK_URL
=
form_data
.
url
app
.
state
.
config
.
WEBHOOK_URL
=
form_data
.
url
webui_app
.
state
.
WEBHOOK_URL
=
app
.
state
.
WEBHOOK_URL
webui_app
.
state
.
WEBHOOK_URL
=
app
.
state
.
config
.
WEBHOOK_URL
return
{
"url"
:
app
.
state
.
WEBHOOK_URL
,
"url"
:
app
.
state
.
config
.
WEBHOOK_URL
,
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment