Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
298e6848
Commit
298e6848
authored
May 10, 2024
by
Jun Siang Cheah
Browse files
feat: switch to config proxy, remove config_get/set
parent
f712c900
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
342 additions
and
381 deletions
+342
-381
backend/apps/audio/main.py
backend/apps/audio/main.py
+20
-26
backend/apps/images/main.py
backend/apps/images/main.py
+77
-75
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+28
-32
backend/apps/openai/main.py
backend/apps/openai/main.py
+24
-28
backend/apps/rag/main.py
backend/apps/rag/main.py
+117
-141
backend/apps/web/main.py
backend/apps/web/main.py
+12
-10
backend/apps/web/routers/auths.py
backend/apps/web/routers/auths.py
+17
-21
backend/apps/web/routers/configs.py
backend/apps/web/routers/configs.py
+4
-5
backend/apps/web/routers/users.py
backend/apps/web/routers/users.py
+4
-4
backend/config.py
backend/config.py
+15
-13
backend/main.py
backend/main.py
+24
-26
No files found.
backend/apps/audio/main.py
View file @
298e6848
...
@@ -45,8 +45,7 @@ from config import (
...
@@ -45,8 +45,7 @@ from config import (
AUDIO_OPENAI_API_KEY
,
AUDIO_OPENAI_API_KEY
,
AUDIO_OPENAI_API_MODEL
,
AUDIO_OPENAI_API_MODEL
,
AUDIO_OPENAI_API_VOICE
,
AUDIO_OPENAI_API_VOICE
,
config_get
,
AppConfig
,
config_set
,
)
)
log
=
logging
.
getLogger
(
__name__
)
log
=
logging
.
getLogger
(
__name__
)
...
@@ -61,11 +60,11 @@ app.add_middleware(
...
@@ -61,11 +60,11 @@ app.add_middleware(
allow_headers
=
[
"*"
],
allow_headers
=
[
"*"
],
)
)
app
.
state
.
config
=
AppConfig
()
app
.
state
.
OPENAI_API_BASE_URL
=
AUDIO_OPENAI_API_BASE_URL
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
AUDIO_OPENAI_API_BASE_URL
app
.
state
.
OPENAI_API_KEY
=
AUDIO_OPENAI_API_KEY
app
.
state
.
config
.
OPENAI_API_KEY
=
AUDIO_OPENAI_API_KEY
app
.
state
.
OPENAI_API_MODEL
=
AUDIO_OPENAI_API_MODEL
app
.
state
.
config
.
OPENAI_API_MODEL
=
AUDIO_OPENAI_API_MODEL
app
.
state
.
OPENAI_API_VOICE
=
AUDIO_OPENAI_API_VOICE
app
.
state
.
config
.
OPENAI_API_VOICE
=
AUDIO_OPENAI_API_VOICE
# setting device type for whisper model
# setting device type for whisper model
whisper_device_type
=
DEVICE_TYPE
if
DEVICE_TYPE
and
DEVICE_TYPE
==
"cuda"
else
"cpu"
whisper_device_type
=
DEVICE_TYPE
if
DEVICE_TYPE
and
DEVICE_TYPE
==
"cuda"
else
"cpu"
...
@@ -85,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
...
@@ -85,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@
app
.
get
(
"/config"
)
@
app
.
get
(
"/config"
)
async
def
get_openai_config
(
user
=
Depends
(
get_admin_user
)):
async
def
get_openai_config
(
user
=
Depends
(
get_admin_user
)):
return
{
return
{
"OPENAI_API_BASE_URL"
:
config_get
(
app
.
state
.
OPENAI_API_BASE_URL
)
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
config_get
(
app
.
state
.
OPENAI_API_KEY
)
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
"OPENAI_API_MODEL"
:
config_get
(
app
.
state
.
OPENAI_API_MODEL
)
,
"OPENAI_API_MODEL"
:
app
.
state
.
config
.
OPENAI_API_MODEL
,
"OPENAI_API_VOICE"
:
config_get
(
app
.
state
.
OPENAI_API_VOICE
)
,
"OPENAI_API_VOICE"
:
app
.
state
.
config
.
OPENAI_API_VOICE
,
}
}
...
@@ -99,22 +98,17 @@ async def update_openai_config(
...
@@ -99,22 +98,17 @@ async def update_openai_config(
if
form_data
.
key
==
""
:
if
form_data
.
key
==
""
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
config_set
(
app
.
state
.
OPENAI_API_BASE_URL
,
form_data
.
url
)
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
form_data
.
url
config_set
(
app
.
state
.
OPENAI_API_KEY
,
form_data
.
key
)
app
.
state
.
config
.
OPENAI_API_KEY
=
form_data
.
key
config_set
(
app
.
state
.
OPENAI_API_MODEL
,
form_data
.
model
)
app
.
state
.
config
.
OPENAI_API_MODEL
=
form_data
.
model
config_set
(
app
.
state
.
OPENAI_API_VOICE
,
form_data
.
speaker
)
app
.
state
.
config
.
OPENAI_API_VOICE
=
form_data
.
speaker
app
.
state
.
OPENAI_API_BASE_URL
.
save
()
app
.
state
.
OPENAI_API_KEY
.
save
()
app
.
state
.
OPENAI_API_MODEL
.
save
()
app
.
state
.
OPENAI_API_VOICE
.
save
()
return
{
return
{
"status"
:
True
,
"status"
:
True
,
"OPENAI_API_BASE_URL"
:
config_get
(
app
.
state
.
OPENAI_API_BASE_URL
)
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
config_get
(
app
.
state
.
OPENAI_API_KEY
)
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
"OPENAI_API_MODEL"
:
config_get
(
app
.
state
.
OPENAI_API_MODEL
)
,
"OPENAI_API_MODEL"
:
app
.
state
.
config
.
OPENAI_API_MODEL
,
"OPENAI_API_VOICE"
:
config_get
(
app
.
state
.
OPENAI_API_VOICE
)
,
"OPENAI_API_VOICE"
:
app
.
state
.
config
.
OPENAI_API_VOICE
,
}
}
...
@@ -131,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
...
@@ -131,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return
FileResponse
(
file_path
)
return
FileResponse
(
file_path
)
headers
=
{}
headers
=
{}
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
OPENAI_API_KEY
}
"
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
config
.
OPENAI_API_KEY
}
"
headers
[
"Content-Type"
]
=
"application/json"
headers
[
"Content-Type"
]
=
"application/json"
r
=
None
r
=
None
try
:
try
:
r
=
requests
.
post
(
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
OPENAI_API_BASE_URL
}
/audio/speech"
,
url
=
f
"
{
app
.
state
.
config
.
OPENAI_API_BASE_URL
}
/audio/speech"
,
data
=
body
,
data
=
body
,
headers
=
headers
,
headers
=
headers
,
stream
=
True
,
stream
=
True
,
...
...
backend/apps/images/main.py
View file @
298e6848
...
@@ -42,8 +42,7 @@ from config import (
...
@@ -42,8 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL
,
IMAGE_GENERATION_MODEL
,
IMAGE_SIZE
,
IMAGE_SIZE
,
IMAGE_STEPS
,
IMAGE_STEPS
,
config_get
,
AppConfig
,
config_set
,
)
)
...
@@ -62,28 +61,30 @@ app.add_middleware(
...
@@ -62,28 +61,30 @@ app.add_middleware(
allow_headers
=
[
"*"
],
allow_headers
=
[
"*"
],
)
)
app
.
state
.
ENGINE
=
IMAGE_GENERATION_ENGINE
app
.
state
.
config
=
AppConfig
()
app
.
state
.
ENABLED
=
ENABLE_IMAGE_GENERATION
app
.
state
.
OPENAI_API_BASE_URL
=
IMAGES_OPENAI_API_BASE_URL
app
.
state
.
config
.
ENGINE
=
IMAGE_GENERATION_ENGINE
app
.
state
.
OPENAI_API_KEY
=
IMAGES_OPENAI_API_KEY
app
.
state
.
config
.
ENABLED
=
ENABLE_IMAGE_GENERATION
app
.
state
.
MODEL
=
IMAGE_GENERATION_MODEL
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
IMAGES_OPENAI_API_BASE_URL
app
.
state
.
config
.
OPENAI_API_KEY
=
IMAGES_OPENAI_API_KEY
app
.
state
.
config
.
MODEL
=
IMAGE_GENERATION_MODEL
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
config
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
app
.
state
.
IMAGE_SIZE
=
IMAGE_SIZE
app
.
state
.
IMAGE_STEPS
=
IMAGE_STEPS
app
.
state
.
config
.
IMAGE_SIZE
=
IMAGE_SIZE
app
.
state
.
config
.
IMAGE_STEPS
=
IMAGE_STEPS
@
app
.
get
(
"/config"
)
@
app
.
get
(
"/config"
)
async
def
get_config
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
async
def
get_config
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
{
return
{
"engine"
:
config_get
(
app
.
state
.
ENGINE
)
,
"engine"
:
app
.
state
.
config
.
ENGINE
,
"enabled"
:
config_get
(
app
.
state
.
ENABLED
)
,
"enabled"
:
app
.
state
.
config
.
ENABLED
,
}
}
...
@@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel):
...
@@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel):
@
app
.
post
(
"/config/update"
)
@
app
.
post
(
"/config/update"
)
async
def
update_config
(
form_data
:
ConfigUpdateForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_config
(
form_data
:
ConfigUpdateForm
,
user
=
Depends
(
get_admin_user
)):
config_set
(
app
.
state
.
ENGINE
,
form_data
.
engine
)
app
.
state
.
config
.
ENGINE
=
form_data
.
engine
config_set
(
app
.
state
.
ENABLED
,
form_data
.
enabled
)
app
.
state
.
config
.
ENABLED
=
form_data
.
enabled
return
{
return
{
"engine"
:
config_get
(
app
.
state
.
ENGINE
)
,
"engine"
:
app
.
state
.
config
.
ENGINE
,
"enabled"
:
config_get
(
app
.
state
.
ENABLED
)
,
"enabled"
:
app
.
state
.
config
.
ENABLED
,
}
}
...
@@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
...
@@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@
app
.
get
(
"/url"
)
@
app
.
get
(
"/url"
)
async
def
get_engine_url
(
user
=
Depends
(
get_admin_user
)):
async
def
get_engine_url
(
user
=
Depends
(
get_admin_user
)):
return
{
return
{
"AUTOMATIC1111_BASE_URL"
:
config_get
(
app
.
state
.
AUTOMATIC1111_BASE_URL
)
,
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
config_get
(
app
.
state
.
COMFYUI_BASE_URL
)
,
"COMFYUI_BASE_URL"
:
app
.
state
.
config
.
COMFYUI_BASE_URL
,
}
}
...
@@ -121,29 +122,29 @@ async def update_engine_url(
...
@@ -121,29 +122,29 @@ async def update_engine_url(
):
):
if
form_data
.
AUTOMATIC1111_BASE_URL
==
None
:
if
form_data
.
AUTOMATIC1111_BASE_URL
==
None
:
config_set
(
app
.
state
.
AUTOMATIC1111_BASE_URL
,
config_get
(
AUTOMATIC1111_BASE_URL
))
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
else
:
else
:
url
=
form_data
.
AUTOMATIC1111_BASE_URL
.
strip
(
"/"
)
url
=
form_data
.
AUTOMATIC1111_BASE_URL
.
strip
(
"/"
)
try
:
try
:
r
=
requests
.
head
(
url
)
r
=
requests
.
head
(
url
)
config_set
(
app
.
state
.
AUTOMATIC1111_BASE_URL
,
url
)
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
=
url
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
if
form_data
.
COMFYUI_BASE_URL
==
None
:
if
form_data
.
COMFYUI_BASE_URL
==
None
:
config_set
(
app
.
state
.
COMFYUI_BASE_URL
,
COMFYUI_BASE_URL
)
app
.
state
.
config
.
COMFYUI_BASE_URL
=
COMFYUI_BASE_URL
else
:
else
:
url
=
form_data
.
COMFYUI_BASE_URL
.
strip
(
"/"
)
url
=
form_data
.
COMFYUI_BASE_URL
.
strip
(
"/"
)
try
:
try
:
r
=
requests
.
head
(
url
)
r
=
requests
.
head
(
url
)
config_set
(
app
.
state
.
COMFYUI_BASE_URL
,
url
)
app
.
state
.
config
.
COMFYUI_BASE_URL
=
url
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
return
{
return
{
"AUTOMATIC1111_BASE_URL"
:
config_get
(
app
.
state
.
AUTOMATIC1111_BASE_URL
)
,
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
,
"COMFYUI_BASE_URL"
:
config_get
(
app
.
state
.
COMFYUI_BASE_URL
)
,
"COMFYUI_BASE_URL"
:
app
.
state
.
config
.
COMFYUI_BASE_URL
,
"status"
:
True
,
"status"
:
True
,
}
}
...
@@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
...
@@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@
app
.
get
(
"/openai/config"
)
@
app
.
get
(
"/openai/config"
)
async
def
get_openai_config
(
user
=
Depends
(
get_admin_user
)):
async
def
get_openai_config
(
user
=
Depends
(
get_admin_user
)):
return
{
return
{
"OPENAI_API_BASE_URL"
:
config_get
(
app
.
state
.
OPENAI_API_BASE_URL
)
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
config_get
(
app
.
state
.
OPENAI_API_KEY
)
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
}
}
...
@@ -168,13 +169,13 @@ async def update_openai_config(
...
@@ -168,13 +169,13 @@ async def update_openai_config(
if
form_data
.
key
==
""
:
if
form_data
.
key
==
""
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
config_set
(
app
.
state
.
OPENAI_API_BASE_URL
,
form_data
.
url
)
app
.
state
.
config
.
OPENAI_API_BASE_URL
=
form_data
.
url
config_set
(
app
.
state
.
OPENAI_API_KEY
,
form_data
.
key
)
app
.
state
.
config
.
OPENAI_API_KEY
=
form_data
.
key
return
{
return
{
"status"
:
True
,
"status"
:
True
,
"OPENAI_API_BASE_URL"
:
config_get
(
app
.
state
.
OPENAI_API_BASE_URL
)
,
"OPENAI_API_BASE_URL"
:
app
.
state
.
config
.
OPENAI_API_BASE_URL
,
"OPENAI_API_KEY"
:
config_get
(
app
.
state
.
OPENAI_API_KEY
)
,
"OPENAI_API_KEY"
:
app
.
state
.
config
.
OPENAI_API_KEY
,
}
}
...
@@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
...
@@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@
app
.
get
(
"/size"
)
@
app
.
get
(
"/size"
)
async
def
get_image_size
(
user
=
Depends
(
get_admin_user
)):
async
def
get_image_size
(
user
=
Depends
(
get_admin_user
)):
return
{
"IMAGE_SIZE"
:
config_get
(
app
.
state
.
IMAGE_SIZE
)
}
return
{
"IMAGE_SIZE"
:
app
.
state
.
config
.
IMAGE_SIZE
}
@
app
.
post
(
"/size/update"
)
@
app
.
post
(
"/size/update"
)
...
@@ -193,9 +194,9 @@ async def update_image_size(
...
@@ -193,9 +194,9 @@ async def update_image_size(
):
):
pattern
=
r
"^\d+x\d+$"
# Regular expression pattern
pattern
=
r
"^\d+x\d+$"
# Regular expression pattern
if
re
.
match
(
pattern
,
form_data
.
size
):
if
re
.
match
(
pattern
,
form_data
.
size
):
config_set
(
app
.
state
.
IMAGE_SIZE
,
form_data
.
size
)
app
.
state
.
config
.
IMAGE_SIZE
=
form_data
.
size
return
{
return
{
"IMAGE_SIZE"
:
config_get
(
app
.
state
.
IMAGE_SIZE
)
,
"IMAGE_SIZE"
:
app
.
state
.
config
.
IMAGE_SIZE
,
"status"
:
True
,
"status"
:
True
,
}
}
else
:
else
:
...
@@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
...
@@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@
app
.
get
(
"/steps"
)
@
app
.
get
(
"/steps"
)
async
def
get_image_size
(
user
=
Depends
(
get_admin_user
)):
async
def
get_image_size
(
user
=
Depends
(
get_admin_user
)):
return
{
"IMAGE_STEPS"
:
config_get
(
app
.
state
.
IMAGE_STEPS
)
}
return
{
"IMAGE_STEPS"
:
app
.
state
.
config
.
IMAGE_STEPS
}
@
app
.
post
(
"/steps/update"
)
@
app
.
post
(
"/steps/update"
)
...
@@ -219,9 +220,9 @@ async def update_image_size(
...
@@ -219,9 +220,9 @@ async def update_image_size(
form_data
:
ImageStepsUpdateForm
,
user
=
Depends
(
get_admin_user
)
form_data
:
ImageStepsUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
):
if
form_data
.
steps
>=
0
:
if
form_data
.
steps
>=
0
:
config_set
(
app
.
state
.
IMAGE_STEPS
,
form_data
.
steps
)
app
.
state
.
config
.
IMAGE_STEPS
=
form_data
.
steps
return
{
return
{
"IMAGE_STEPS"
:
config_get
(
app
.
state
.
IMAGE_STEPS
)
,
"IMAGE_STEPS"
:
app
.
state
.
config
.
IMAGE_STEPS
,
"status"
:
True
,
"status"
:
True
,
}
}
else
:
else
:
...
@@ -234,14 +235,14 @@ async def update_image_size(
...
@@ -234,14 +235,14 @@ async def update_image_size(
@
app
.
get
(
"/models"
)
@
app
.
get
(
"/models"
)
def
get_models
(
user
=
Depends
(
get_current_user
)):
def
get_models
(
user
=
Depends
(
get_current_user
)):
try
:
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
config
.
ENGINE
==
"openai"
:
return
[
return
[
{
"id"
:
"dall-e-2"
,
"name"
:
"DALL·E 2"
},
{
"id"
:
"dall-e-2"
,
"name"
:
"DALL·E 2"
},
{
"id"
:
"dall-e-3"
,
"name"
:
"DALL·E 3"
},
{
"id"
:
"dall-e-3"
,
"name"
:
"DALL·E 3"
},
]
]
elif
app
.
state
.
ENGINE
==
"comfyui"
:
elif
app
.
state
.
config
.
ENGINE
==
"comfyui"
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
COMFYUI_BASE_URL
}
/object_info"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
config
.
COMFYUI_BASE_URL
}
/object_info"
)
info
=
r
.
json
()
info
=
r
.
json
()
return
list
(
return
list
(
...
@@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
...
@@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else
:
else
:
r
=
requests
.
get
(
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
)
)
models
=
r
.
json
()
models
=
r
.
json
()
return
list
(
return
list
(
...
@@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)):
...
@@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)):
)
)
)
)
except
Exception
as
e
:
except
Exception
as
e
:
app
.
state
.
ENABLED
=
False
app
.
state
.
config
.
ENABLED
=
False
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
@
app
.
get
(
"/models/default"
)
@
app
.
get
(
"/models/default"
)
async
def
get_default_model
(
user
=
Depends
(
get_admin_user
)):
async
def
get_default_model
(
user
=
Depends
(
get_admin_user
)):
try
:
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
config
.
ENGINE
==
"openai"
:
return
{
"model"
:
(
config_get
(
app
.
state
.
MODEL
)
if
config_get
(
app
.
state
.
MODEL
)
else
"dall-e-2"
)
}
elif
app
.
state
.
ENGINE
==
"comfyui"
:
return
{
return
{
"model"
:
(
"model"
:
(
config_get
(
app
.
state
.
MODEL
)
if
config_get
(
app
.
state
.
MODEL
)
else
""
app
.
state
.
config
.
MODEL
if
app
.
state
.
config
.
MODEL
else
"
dall-e-2
"
)
)
}
}
elif
app
.
state
.
config
.
ENGINE
==
"comfyui"
:
return
{
"model"
:
(
app
.
state
.
config
.
MODEL
if
app
.
state
.
config
.
MODEL
else
""
)}
else
:
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
options
=
r
.
json
()
return
{
"model"
:
options
[
"sd_model_checkpoint"
]}
return
{
"model"
:
options
[
"sd_model_checkpoint"
]}
except
Exception
as
e
:
except
Exception
as
e
:
config_set
(
app
.
state
.
ENABLED
,
False
)
app
.
state
.
config
.
ENABLED
=
False
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
...
@@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel):
...
@@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel):
def
set_model_handler
(
model
:
str
):
def
set_model_handler
(
model
:
str
):
if
app
.
state
.
ENGINE
in
[
"openai"
,
"comfyui"
]:
if
app
.
state
.
config
.
ENGINE
in
[
"openai"
,
"comfyui"
]:
config_set
(
app
.
state
.
MODEL
,
model
)
app
.
state
.
config
.
MODEL
=
model
return
config_get
(
app
.
state
.
MODEL
)
return
app
.
state
.
config
.
MODEL
else
:
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
options
=
r
.
json
()
if
model
!=
options
[
"sd_model_checkpoint"
]:
if
model
!=
options
[
"sd_model_checkpoint"
]:
options
[
"sd_model_checkpoint"
]
=
model
options
[
"sd_model_checkpoint"
]
=
model
r
=
requests
.
post
(
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
,
json
=
options
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
,
json
=
options
,
)
)
return
options
return
options
...
@@ -397,30 +397,32 @@ def generate_image(
...
@@ -397,30 +397,32 @@ def generate_image(
user
=
Depends
(
get_current_user
),
user
=
Depends
(
get_current_user
),
):
):
width
,
height
=
tuple
(
map
(
int
,
config_get
(
app
.
state
.
IMAGE_SIZE
).
split
(
"x"
))
)
width
,
height
=
tuple
(
map
(
int
,
app
.
state
.
config
.
IMAGE_SIZE
).
split
(
"x"
))
r
=
None
r
=
None
try
:
try
:
if
app
.
state
.
ENGINE
==
"openai"
:
if
app
.
state
.
config
.
ENGINE
==
"openai"
:
headers
=
{}
headers
=
{}
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
OPENAI_API_KEY
}
"
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
config
.
OPENAI_API_KEY
}
"
headers
[
"Content-Type"
]
=
"application/json"
headers
[
"Content-Type"
]
=
"application/json"
data
=
{
data
=
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
!=
""
else
"dall-e-2"
,
"model"
:
(
app
.
state
.
config
.
MODEL
if
app
.
state
.
config
.
MODEL
!=
""
else
"dall-e-2"
),
"prompt"
:
form_data
.
prompt
,
"prompt"
:
form_data
.
prompt
,
"n"
:
form_data
.
n
,
"n"
:
form_data
.
n
,
"size"
:
(
"size"
:
(
form_data
.
size
form_data
.
size
if
form_data
.
size
else
app
.
state
.
config
.
IMAGE_SIZE
if
form_data
.
size
else
config_get
(
app
.
state
.
IMAGE_SIZE
)
),
),
"response_format"
:
"b64_json"
,
"response_format"
:
"b64_json"
,
}
}
r
=
requests
.
post
(
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
OPENAI_API_BASE_URL
}
/images/generations"
,
url
=
f
"
{
app
.
state
.
config
.
OPENAI_API_BASE_URL
}
/images/generations"
,
json
=
data
,
json
=
data
,
headers
=
headers
,
headers
=
headers
,
)
)
...
@@ -440,7 +442,7 @@ def generate_image(
...
@@ -440,7 +442,7 @@ def generate_image(
return
images
return
images
elif
app
.
state
.
ENGINE
==
"comfyui"
:
elif
app
.
state
.
config
.
ENGINE
==
"comfyui"
:
data
=
{
data
=
{
"prompt"
:
form_data
.
prompt
,
"prompt"
:
form_data
.
prompt
,
...
@@ -449,8 +451,8 @@ def generate_image(
...
@@ -449,8 +451,8 @@ def generate_image(
"n"
:
form_data
.
n
,
"n"
:
form_data
.
n
,
}
}
if
config_get
(
app
.
state
.
IMAGE_STEPS
)
is
not
None
:
if
app
.
state
.
config
.
IMAGE_STEPS
is
not
None
:
data
[
"steps"
]
=
config_get
(
app
.
state
.
IMAGE_STEPS
)
data
[
"steps"
]
=
app
.
state
.
config
.
IMAGE_STEPS
if
form_data
.
negative_prompt
is
not
None
:
if
form_data
.
negative_prompt
is
not
None
:
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
...
@@ -458,10 +460,10 @@ def generate_image(
...
@@ -458,10 +460,10 @@ def generate_image(
data
=
ImageGenerationPayload
(
**
data
)
data
=
ImageGenerationPayload
(
**
data
)
res
=
comfyui_generate_image
(
res
=
comfyui_generate_image
(
config_get
(
app
.
state
.
MODEL
)
,
app
.
state
.
config
.
MODEL
,
data
,
data
,
user
.
id
,
user
.
id
,
config_get
(
app
.
state
.
COMFYUI_BASE_URL
)
,
app
.
state
.
config
.
COMFYUI_BASE_URL
,
)
)
log
.
debug
(
f
"res:
{
res
}
"
)
log
.
debug
(
f
"res:
{
res
}
"
)
...
@@ -488,14 +490,14 @@ def generate_image(
...
@@ -488,14 +490,14 @@ def generate_image(
"height"
:
height
,
"height"
:
height
,
}
}
if
config_get
(
app
.
state
.
IMAGE_STEPS
)
is
not
None
:
if
app
.
state
.
config
.
IMAGE_STEPS
is
not
None
:
data
[
"steps"
]
=
config_get
(
app
.
state
.
IMAGE_STEPS
)
data
[
"steps"
]
=
app
.
state
.
config
.
IMAGE_STEPS
if
form_data
.
negative_prompt
is
not
None
:
if
form_data
.
negative_prompt
is
not
None
:
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
r
=
requests
.
post
(
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/txt2img"
,
url
=
f
"
{
app
.
state
.
config
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/txt2img"
,
json
=
data
,
json
=
data
,
)
)
...
...
backend/apps/ollama/main.py
View file @
298e6848
...
@@ -46,8 +46,7 @@ from config import (
...
@@ -46,8 +46,7 @@ from config import (
ENABLE_MODEL_FILTER
,
ENABLE_MODEL_FILTER
,
MODEL_FILTER_LIST
,
MODEL_FILTER_LIST
,
UPLOAD_DIR
,
UPLOAD_DIR
,
config_set
,
AppConfig
,
config_get
,
)
)
from
utils.misc
import
calculate_sha256
from
utils.misc
import
calculate_sha256
...
@@ -63,11 +62,12 @@ app.add_middleware(
...
@@ -63,11 +62,12 @@ app.add_middleware(
allow_headers
=
[
"*"
],
allow_headers
=
[
"*"
],
)
)
app
.
state
.
config
=
AppConfig
()
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
MODELS
=
{}
app
.
state
.
MODELS
=
{}
...
@@ -98,7 +98,7 @@ async def get_status():
...
@@ -98,7 +98,7 @@ async def get_status():
@
app
.
get
(
"/urls"
)
@
app
.
get
(
"/urls"
)
async
def
get_ollama_api_urls
(
user
=
Depends
(
get_admin_user
)):
async
def
get_ollama_api_urls
(
user
=
Depends
(
get_admin_user
)):
return
{
"OLLAMA_BASE_URLS"
:
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
}
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
class
UrlUpdateForm
(
BaseModel
):
class
UrlUpdateForm
(
BaseModel
):
...
@@ -107,10 +107,10 @@ class UrlUpdateForm(BaseModel):
...
@@ -107,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@
app
.
post
(
"/urls/update"
)
@
app
.
post
(
"/urls/update"
)
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
config_set
(
app
.
state
.
OLLAMA_BASE_URLS
,
form_data
.
urls
)
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
form_data
.
urls
log
.
info
(
f
"app.state.OLLAMA_BASE_URLS:
{
app
.
state
.
OLLAMA_BASE_URLS
}
"
)
log
.
info
(
f
"app.state.
config.
OLLAMA_BASE_URLS:
{
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
"
)
return
{
"OLLAMA_BASE_URLS"
:
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
}
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
@
app
.
get
(
"/cancel/{request_id}"
)
@
app
.
get
(
"/cancel/{request_id}"
)
...
@@ -155,9 +155,7 @@ def merge_models_lists(model_lists):
...
@@ -155,9 +155,7 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
async
def
get_all_models
():
log
.
info
(
"get_all_models()"
)
log
.
info
(
"get_all_models()"
)
tasks
=
[
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
config
.
OLLAMA_BASE_URLS
]
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
await
asyncio
.
gather
(
*
tasks
)
models
=
{
models
=
{
...
@@ -183,15 +181,14 @@ async def get_ollama_tags(
...
@@ -183,15 +181,14 @@ async def get_ollama_tags(
if
user
.
role
==
"user"
:
if
user
.
role
==
"user"
:
models
[
"models"
]
=
list
(
models
[
"models"
]
=
list
(
filter
(
filter
(
lambda
model
:
model
[
"name"
]
lambda
model
:
model
[
"name"
]
in
app
.
state
.
MODEL_FILTER_LIST
,
in
config_get
(
app
.
state
.
MODEL_FILTER_LIST
),
models
[
"models"
],
models
[
"models"
],
)
)
)
)
return
models
return
models
return
models
return
models
else
:
else
:
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/tags"
)
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/tags"
)
r
.
raise_for_status
()
r
.
raise_for_status
()
...
@@ -222,8 +219,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
...
@@ -222,8 +219,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
# returns lowest version
# returns lowest version
tasks
=
[
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/version"
)
fetch_url
(
f
"
{
url
}
/api/version"
)
for
url
in
app
.
state
.
config
.
OLLAMA_BASE_URLS
for
url
in
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
]
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
list
(
filter
(
lambda
x
:
x
is
not
None
,
responses
))
responses
=
list
(
filter
(
lambda
x
:
x
is
not
None
,
responses
))
...
@@ -243,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
...
@@ -243,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail
=
ERROR_MESSAGES
.
OLLAMA_NOT_FOUND
,
detail
=
ERROR_MESSAGES
.
OLLAMA_NOT_FOUND
,
)
)
else
:
else
:
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/version"
)
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/version"
)
r
.
raise_for_status
()
r
.
raise_for_status
()
...
@@ -275,7 +271,7 @@ class ModelNameForm(BaseModel):
...
@@ -275,7 +271,7 @@ class ModelNameForm(BaseModel):
async
def
pull_model
(
async
def
pull_model
(
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
):
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -363,7 +359,7 @@ async def push_model(
...
@@ -363,7 +359,7 @@ async def push_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
name
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
name
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
debug
(
f
"url:
{
url
}
"
)
log
.
debug
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -425,7 +421,7 @@ async def create_model(
...
@@ -425,7 +421,7 @@ async def create_model(
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
):
log
.
debug
(
f
"form_data:
{
form_data
}
"
)
log
.
debug
(
f
"form_data:
{
form_data
}
"
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -498,7 +494,7 @@ async def copy_model(
...
@@ -498,7 +494,7 @@ async def copy_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
source
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
source
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
...
@@ -545,7 +541,7 @@ async def delete_model(
...
@@ -545,7 +541,7 @@ async def delete_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
name
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
name
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
...
@@ -585,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
...
@@ -585,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
)
)
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
])
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
])
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
...
@@ -642,7 +638,7 @@ async def generate_embeddings(
...
@@ -642,7 +638,7 @@ async def generate_embeddings(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
...
@@ -692,7 +688,7 @@ def generate_ollama_embeddings(
...
@@ -692,7 +688,7 @@ def generate_ollama_embeddings(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
try
:
try
:
...
@@ -761,7 +757,7 @@ async def generate_completion(
...
@@ -761,7 +757,7 @@ async def generate_completion(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -864,7 +860,7 @@ async def generate_chat_completion(
...
@@ -864,7 +860,7 @@ async def generate_chat_completion(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -973,7 +969,7 @@ async def generate_openai_chat_completion(
...
@@ -973,7 +969,7 @@ async def generate_openai_chat_completion(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
model
),
)
)
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
r
=
None
...
@@ -1072,7 +1068,7 @@ async def get_openai_models(
...
@@ -1072,7 +1068,7 @@ async def get_openai_models(
}
}
else
:
else
:
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/tags"
)
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/tags"
)
r
.
raise_for_status
()
r
.
raise_for_status
()
...
@@ -1206,7 +1202,7 @@ async def download_model(
...
@@ -1206,7 +1202,7 @@ async def download_model(
if
url_idx
==
None
:
if
url_idx
==
None
:
url_idx
=
0
url_idx
=
0
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
file_name
=
parse_huggingface_url
(
form_data
.
url
)
file_name
=
parse_huggingface_url
(
form_data
.
url
)
...
@@ -1225,7 +1221,7 @@ async def download_model(
...
@@ -1225,7 +1221,7 @@ async def download_model(
def
upload_model
(
file
:
UploadFile
=
File
(...),
url_idx
:
Optional
[
int
]
=
None
):
def
upload_model
(
file
:
UploadFile
=
File
(...),
url_idx
:
Optional
[
int
]
=
None
):
if
url_idx
==
None
:
if
url_idx
==
None
:
url_idx
=
0
url_idx
=
0
ollama_url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
url_idx
]
ollama_url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
file
.
filename
}
"
file_path
=
f
"
{
UPLOAD_DIR
}
/
{
file
.
filename
}
"
...
@@ -1290,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
...
@@ -1290,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# if url_idx == None:
# url_idx = 0
# url_idx = 0
# url =
config_get(
app.state.OLLAMA_BASE_URLS
)
[url_idx]
# url = app.state.
config.
OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# total_size = file.size
...
@@ -1327,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
...
@@ -1327,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async
def
deprecated_proxy
(
async
def
deprecated_proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)
):
):
url
=
config_get
(
app
.
state
.
OLLAMA_BASE_URLS
)
[
0
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
0
]
target_url
=
f
"
{
url
}
/
{
path
}
"
target_url
=
f
"
{
url
}
/
{
path
}
"
body
=
await
request
.
body
()
body
=
await
request
.
body
()
...
...
backend/apps/openai/main.py
View file @
298e6848
...
@@ -26,8 +26,7 @@ from config import (
...
@@ -26,8 +26,7 @@ from config import (
CACHE_DIR
,
CACHE_DIR
,
ENABLE_MODEL_FILTER
,
ENABLE_MODEL_FILTER
,
MODEL_FILTER_LIST
,
MODEL_FILTER_LIST
,
config_set
,
AppConfig
,
config_get
,
)
)
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
...
@@ -47,11 +46,13 @@ app.add_middleware(
...
@@ -47,11 +46,13 @@ app.add_middleware(
allow_headers
=
[
"*"
],
allow_headers
=
[
"*"
],
)
)
app
.
state
.
config
=
AppConfig
()
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
OPENAI_API_BASE_URLS
=
OPENAI_API_BASE_URLS
app
.
state
.
config
.
OPENAI_API_BASE_URLS
=
OPENAI_API_BASE_URLS
app
.
state
.
OPENAI_API_KEYS
=
OPENAI_API_KEYS
app
.
state
.
config
.
OPENAI_API_KEYS
=
OPENAI_API_KEYS
app
.
state
.
MODELS
=
{}
app
.
state
.
MODELS
=
{}
...
@@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel):
...
@@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel):
@
app
.
get
(
"/urls"
)
@
app
.
get
(
"/urls"
)
async
def
get_openai_urls
(
user
=
Depends
(
get_admin_user
)):
async
def
get_openai_urls
(
user
=
Depends
(
get_admin_user
)):
return
{
"OPENAI_API_BASE_URLS"
:
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
)
}
return
{
"OPENAI_API_BASE_URLS"
:
app
.
state
.
config
.
OPENAI_API_BASE_URLS
}
@
app
.
post
(
"/urls/update"
)
@
app
.
post
(
"/urls/update"
)
async
def
update_openai_urls
(
form_data
:
UrlsUpdateForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_openai_urls
(
form_data
:
UrlsUpdateForm
,
user
=
Depends
(
get_admin_user
)):
await
get_all_models
()
await
get_all_models
()
config_set
(
app
.
state
.
OPENAI_API_BASE_URLS
,
form_data
.
urls
)
app
.
state
.
config
.
OPENAI_API_BASE_URLS
=
form_data
.
urls
return
{
"OPENAI_API_BASE_URLS"
:
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
)
}
return
{
"OPENAI_API_BASE_URLS"
:
app
.
state
.
config
.
OPENAI_API_BASE_URLS
}
@
app
.
get
(
"/keys"
)
@
app
.
get
(
"/keys"
)
async
def
get_openai_keys
(
user
=
Depends
(
get_admin_user
)):
async
def
get_openai_keys
(
user
=
Depends
(
get_admin_user
)):
return
{
"OPENAI_API_KEYS"
:
config_get
(
app
.
state
.
OPENAI_API_KEYS
)
}
return
{
"OPENAI_API_KEYS"
:
app
.
state
.
config
.
OPENAI_API_KEYS
}
@
app
.
post
(
"/keys/update"
)
@
app
.
post
(
"/keys/update"
)
async
def
update_openai_key
(
form_data
:
KeysUpdateForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_openai_key
(
form_data
:
KeysUpdateForm
,
user
=
Depends
(
get_admin_user
)):
config_set
(
app
.
state
.
OPENAI_API_KEYS
,
form_data
.
keys
)
app
.
state
.
config
.
OPENAI_API_KEYS
=
form_data
.
keys
return
{
"OPENAI_API_KEYS"
:
config_get
(
app
.
state
.
OPENAI_API_KEYS
)
}
return
{
"OPENAI_API_KEYS"
:
app
.
state
.
config
.
OPENAI_API_KEYS
}
@
app
.
post
(
"/audio/speech"
)
@
app
.
post
(
"/audio/speech"
)
async
def
speech
(
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
async
def
speech
(
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
idx
=
None
idx
=
None
try
:
try
:
idx
=
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
).
index
(
idx
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
.
index
(
"https://api.openai.com/v1"
)
"https://api.openai.com/v1"
)
body
=
await
request
.
body
()
body
=
await
request
.
body
()
name
=
hashlib
.
sha256
(
body
).
hexdigest
()
name
=
hashlib
.
sha256
(
body
).
hexdigest
()
...
@@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
...
@@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return
FileResponse
(
file_path
)
return
FileResponse
(
file_path
)
headers
=
{}
headers
=
{}
headers
[
"Authorization"
]
=
(
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
}
"
f
"Bearer
{
config_get
(
app
.
state
.
OPENAI_API_KEYS
)[
idx
]
}
"
)
headers
[
"Content-Type"
]
=
"application/json"
headers
[
"Content-Type"
]
=
"application/json"
r
=
None
r
=
None
try
:
try
:
r
=
requests
.
post
(
r
=
requests
.
post
(
url
=
f
"
{
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
)
[
idx
]
}
/audio/speech"
,
url
=
f
"
{
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
}
/audio/speech"
,
data
=
body
,
data
=
body
,
headers
=
headers
,
headers
=
headers
,
stream
=
True
,
stream
=
True
,
...
@@ -187,7 +184,7 @@ def merge_models_lists(model_lists):
...
@@ -187,7 +184,7 @@ def merge_models_lists(model_lists):
{
**
model
,
"urlIdx"
:
idx
}
{
**
model
,
"urlIdx"
:
idx
}
for
model
in
models
for
model
in
models
if
"api.openai.com"
if
"api.openai.com"
not
in
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
)
[
idx
]
not
in
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
or
"gpt"
in
model
[
"id"
]
or
"gpt"
in
model
[
"id"
]
]
]
)
)
...
@@ -199,14 +196,14 @@ async def get_all_models():
...
@@ -199,14 +196,14 @@ async def get_all_models():
log
.
info
(
"get_all_models()"
)
log
.
info
(
"get_all_models()"
)
if
(
if
(
len
(
config_get
(
app
.
state
.
OPENAI_API_KEYS
)
)
==
1
len
(
app
.
state
.
config
.
OPENAI_API_KEYS
)
==
1
and
config_get
(
app
.
state
.
OPENAI_API_KEYS
)
[
0
]
==
""
and
app
.
state
.
config
.
OPENAI_API_KEYS
[
0
]
==
""
):
):
models
=
{
"data"
:
[]}
models
=
{
"data"
:
[]}
else
:
else
:
tasks
=
[
tasks
=
[
fetch_url
(
f
"
{
url
}
/models"
,
config_get
(
app
.
state
.
OPENAI_API_KEYS
)
[
idx
])
fetch_url
(
f
"
{
url
}
/models"
,
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
])
for
idx
,
url
in
enumerate
(
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
)
)
for
idx
,
url
in
enumerate
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
)
]
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
await
asyncio
.
gather
(
*
tasks
)
...
@@ -238,19 +235,18 @@ async def get_all_models():
...
@@ -238,19 +235,18 @@ async def get_all_models():
async
def
get_models
(
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_current_user
)):
async
def
get_models
(
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_current_user
)):
if
url_idx
==
None
:
if
url_idx
==
None
:
models
=
await
get_all_models
()
models
=
await
get_all_models
()
if
config_get
(
app
.
state
.
ENABLE_MODEL_FILTER
)
:
if
app
.
state
.
ENABLE_MODEL_FILTER
:
if
user
.
role
==
"user"
:
if
user
.
role
==
"user"
:
models
[
"data"
]
=
list
(
models
[
"data"
]
=
list
(
filter
(
filter
(
lambda
model
:
model
[
"id"
]
lambda
model
:
model
[
"id"
]
in
app
.
state
.
MODEL_FILTER_LIST
,
in
config_get
(
app
.
state
.
MODEL_FILTER_LIST
),
models
[
"data"
],
models
[
"data"
],
)
)
)
)
return
models
return
models
return
models
return
models
else
:
else
:
url
=
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
)
[
url_idx
]
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
url_idx
]
r
=
None
r
=
None
...
@@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except
json
.
JSONDecodeError
as
e
:
except
json
.
JSONDecodeError
as
e
:
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
url
=
config_get
(
app
.
state
.
OPENAI_API_BASE_URLS
)
[
idx
]
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
config_get
(
app
.
state
.
OPENAI_API_KEYS
)
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
target_url
=
f
"
{
url
}
/
{
path
}
"
target_url
=
f
"
{
url
}
/
{
path
}
"
...
...
backend/apps/rag/main.py
View file @
298e6848
This diff is collapsed.
Click to expand it.
backend/apps/web/main.py
View file @
298e6848
...
@@ -22,21 +22,23 @@ from config import (
...
@@ -22,21 +22,23 @@ from config import (
WEBHOOK_URL
,
WEBHOOK_URL
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
JWT_EXPIRES_IN
,
JWT_EXPIRES_IN
,
c
onfig
_get
,
AppC
onfig
,
)
)
app
=
FastAPI
()
app
=
FastAPI
()
origins
=
[
"*"
]
origins
=
[
"*"
]
app
.
state
.
ENABLE_SIGNUP
=
ENABLE_SIGNUP
app
.
state
.
config
=
AppConfig
()
app
.
state
.
JWT_EXPIRES_IN
=
JWT_EXPIRES_IN
app
.
state
.
DEFAULT_MODELS
=
DEFAULT_MODELS
app
.
state
.
config
.
ENABLE_SIGNUP
=
ENABLE_SIGNUP
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
=
DEFAULT_PROMPT_SUGGESTIONS
app
.
state
.
config
.
JWT_EXPIRES_IN
=
JWT_EXPIRES_IN
app
.
state
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
config
.
DEFAULT_MODELS
=
DEFAULT_MODELS
app
.
state
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
=
DEFAULT_PROMPT_SUGGESTIONS
app
.
state
.
config
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
config
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
=
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
=
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app
.
add_middleware
(
app
.
add_middleware
(
...
@@ -63,6 +65,6 @@ async def get_status():
...
@@ -63,6 +65,6 @@ async def get_status():
return
{
return
{
"status"
:
True
,
"status"
:
True
,
"auth"
:
WEBUI_AUTH
,
"auth"
:
WEBUI_AUTH
,
"default_models"
:
config_get
(
app
.
state
.
DEFAULT_MODELS
)
,
"default_models"
:
app
.
state
.
config
.
DEFAULT_MODELS
,
"default_prompt_suggestions"
:
config_get
(
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
)
,
"default_prompt_suggestions"
:
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
,
}
}
backend/apps/web/routers/auths.py
View file @
298e6848
...
@@ -33,7 +33,7 @@ from utils.utils import (
...
@@ -33,7 +33,7 @@ from utils.utils import (
from
utils.misc
import
parse_duration
,
validate_email_format
from
utils.misc
import
parse_duration
,
validate_email_format
from
utils.webhook
import
post_webhook
from
utils.webhook
import
post_webhook
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
from
config
import
WEBUI_AUTH
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
config_get
,
config_set
from
config
import
WEBUI_AUTH
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
router
=
APIRouter
()
router
=
APIRouter
()
...
@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
...
@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if
user
:
if
user
:
token
=
create_token
(
token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
config_get
(
request
.
app
.
state
.
JWT_EXPIRES_IN
)
)
,
expires_delta
=
parse_duration
(
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
)
return
{
return
{
...
@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
...
@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@
router
.
post
(
"/signup"
,
response_model
=
SigninResponse
)
@
router
.
post
(
"/signup"
,
response_model
=
SigninResponse
)
async
def
signup
(
request
:
Request
,
form_data
:
SignupForm
):
async
def
signup
(
request
:
Request
,
form_data
:
SignupForm
):
if
not
config_get
(
request
.
app
.
state
.
ENABLE_SIGNUP
)
and
WEBUI_AUTH
:
if
not
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
and
WEBUI_AUTH
:
raise
HTTPException
(
raise
HTTPException
(
status
.
HTTP_403_FORBIDDEN
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
status
.
HTTP_403_FORBIDDEN
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
)
...
@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
...
@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role
=
(
role
=
(
"admin"
"admin"
if
Users
.
get_num_users
()
==
0
if
Users
.
get_num_users
()
==
0
else
config_get
(
request
.
app
.
state
.
DEFAULT_USER_ROLE
)
else
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
)
)
hashed
=
get_password_hash
(
form_data
.
password
)
hashed
=
get_password_hash
(
form_data
.
password
)
user
=
Auths
.
insert_new_auth
(
user
=
Auths
.
insert_new_auth
(
...
@@ -194,15 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
...
@@ -194,15 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
if
user
:
if
user
:
token
=
create_token
(
token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
expires_delta
=
parse_duration
(
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
),
config_get
(
request
.
app
.
state
.
JWT_EXPIRES_IN
)
),
)
)
# response.set_cookie(key='token', value=token, httponly=True)
# response.set_cookie(key='token', value=token, httponly=True)
if
config_get
(
request
.
app
.
state
.
WEBHOOK_URL
)
:
if
request
.
app
.
state
.
config
.
WEBHOOK_URL
:
post_webhook
(
post_webhook
(
config_get
(
request
.
app
.
state
.
WEBHOOK_URL
)
,
request
.
app
.
state
.
config
.
WEBHOOK_URL
,
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
{
{
"action"
:
"signup"
,
"action"
:
"signup"
,
...
@@ -278,15 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
...
@@ -278,15 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@
router
.
get
(
"/signup/enabled"
,
response_model
=
bool
)
@
router
.
get
(
"/signup/enabled"
,
response_model
=
bool
)
async
def
get_sign_up_status
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
async
def
get_sign_up_status
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
config_get
(
request
.
app
.
state
.
ENABLE_SIGNUP
)
return
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
@
router
.
get
(
"/signup/enabled/toggle"
,
response_model
=
bool
)
@
router
.
get
(
"/signup/enabled/toggle"
,
response_model
=
bool
)
async
def
toggle_sign_up
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
async
def
toggle_sign_up
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
config_set
(
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
=
not
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
request
.
app
.
state
.
ENABLE_SIGNUP
,
not
config_get
(
request
.
app
.
state
.
ENABLE_SIGNUP
)
return
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
)
return
config_get
(
request
.
app
.
state
.
ENABLE_SIGNUP
)
############################
############################
...
@@ -296,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
...
@@ -296,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@
router
.
get
(
"/signup/user/role"
)
@
router
.
get
(
"/signup/user/role"
)
async
def
get_default_user_role
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
async
def
get_default_user_role
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
config_get
(
request
.
app
.
state
.
DEFAULT_USER_ROLE
)
return
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
class
UpdateRoleForm
(
BaseModel
):
class
UpdateRoleForm
(
BaseModel
):
...
@@ -308,8 +304,8 @@ async def update_default_user_role(
...
@@ -308,8 +304,8 @@ async def update_default_user_role(
request
:
Request
,
form_data
:
UpdateRoleForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
form_data
:
UpdateRoleForm
,
user
=
Depends
(
get_admin_user
)
):
):
if
form_data
.
role
in
[
"pending"
,
"user"
,
"admin"
]:
if
form_data
.
role
in
[
"pending"
,
"user"
,
"admin"
]:
config_set
(
request
.
app
.
state
.
DEFAULT_USER_ROLE
,
form_data
.
role
)
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
=
form_data
.
role
return
config_get
(
request
.
app
.
state
.
DEFAULT_USER_ROLE
)
return
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
############################
############################
...
@@ -319,7 +315,7 @@ async def update_default_user_role(
...
@@ -319,7 +315,7 @@ async def update_default_user_role(
@
router
.
get
(
"/token/expires"
)
@
router
.
get
(
"/token/expires"
)
async
def
get_token_expires_duration
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
async
def
get_token_expires_duration
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
config_get
(
request
.
app
.
state
.
JWT_EXPIRES_IN
)
return
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
class
UpdateJWTExpiresDurationForm
(
BaseModel
):
class
UpdateJWTExpiresDurationForm
(
BaseModel
):
...
@@ -336,10 +332,10 @@ async def update_token_expires_duration(
...
@@ -336,10 +332,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern
# Check if the input string matches the pattern
if
re
.
match
(
pattern
,
form_data
.
duration
):
if
re
.
match
(
pattern
,
form_data
.
duration
):
config_set
(
request
.
app
.
state
.
JWT_EXPIRES_IN
,
form_data
.
duration
)
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
=
form_data
.
duration
return
config_get
(
request
.
app
.
state
.
JWT_EXPIRES_IN
)
return
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
else
:
else
:
return
config_get
(
request
.
app
.
state
.
JWT_EXPIRES_IN
)
return
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
############################
############################
...
...
backend/apps/web/routers/configs.py
View file @
298e6848
...
@@ -9,7 +9,6 @@ import time
...
@@ -9,7 +9,6 @@ import time
import
uuid
import
uuid
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
config
import
config_set
,
config_get
from
utils.utils
import
(
from
utils.utils
import
(
get_password_hash
,
get_password_hash
,
...
@@ -45,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
...
@@ -45,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
async
def
set_global_default_models
(
async
def
set_global_default_models
(
request
:
Request
,
form_data
:
SetDefaultModelsForm
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
form_data
:
SetDefaultModelsForm
,
user
=
Depends
(
get_admin_user
)
):
):
config_set
(
request
.
app
.
state
.
DEFAULT_MODELS
,
form_data
.
models
)
request
.
app
.
state
.
config
.
DEFAULT_MODELS
=
form_data
.
models
return
config_get
(
request
.
app
.
state
.
DEFAULT_MODELS
)
return
request
.
app
.
state
.
config
.
DEFAULT_MODELS
@
router
.
post
(
"/default/suggestions"
,
response_model
=
List
[
PromptSuggestion
])
@
router
.
post
(
"/default/suggestions"
,
response_model
=
List
[
PromptSuggestion
])
...
@@ -56,5 +55,5 @@ async def set_global_default_suggestions(
...
@@ -56,5 +55,5 @@ async def set_global_default_suggestions(
user
=
Depends
(
get_admin_user
),
user
=
Depends
(
get_admin_user
),
):
):
data
=
form_data
.
model_dump
()
data
=
form_data
.
model_dump
()
config_set
(
request
.
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
,
data
[
"suggestions"
]
)
request
.
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
=
data
[
"suggestions"
]
return
config_get
(
request
.
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
)
return
request
.
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
backend/apps/web/routers/users.py
View file @
298e6848
...
@@ -15,7 +15,7 @@ from apps.web.models.auths import Auths
...
@@ -15,7 +15,7 @@ from apps.web.models.auths import Auths
from
utils.utils
import
get_current_user
,
get_password_hash
,
get_admin_user
from
utils.utils
import
get_current_user
,
get_password_hash
,
get_admin_user
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
config
import
SRC_LOG_LEVELS
,
config_set
,
config_get
from
config
import
SRC_LOG_LEVELS
log
=
logging
.
getLogger
(
__name__
)
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
...
@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
...
@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
@
router
.
get
(
"/permissions/user"
)
@
router
.
get
(
"/permissions/user"
)
async
def
get_user_permissions
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
async
def
get_user_permissions
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
config_get
(
request
.
app
.
state
.
USER_PERMISSIONS
)
return
request
.
app
.
state
.
config
.
USER_PERMISSIONS
@
router
.
post
(
"/permissions/user"
)
@
router
.
post
(
"/permissions/user"
)
async
def
update_user_permissions
(
async
def
update_user_permissions
(
request
:
Request
,
form_data
:
dict
,
user
=
Depends
(
get_admin_user
)
request
:
Request
,
form_data
:
dict
,
user
=
Depends
(
get_admin_user
)
):
):
config_set
(
request
.
app
.
state
.
USER_PERMISSIONS
,
form_data
)
request
.
app
.
state
.
config
.
USER_PERMISSIONS
=
form_data
return
config_get
(
request
.
app
.
state
.
USER_PERMISSIONS
)
return
request
.
app
.
state
.
config
.
USER_PERMISSIONS
############################
############################
...
...
backend/config.py
View file @
298e6848
...
@@ -246,19 +246,21 @@ class WrappedConfig(Generic[T]):
...
@@ -246,19 +246,21 @@ class WrappedConfig(Generic[T]):
self
.
config_value
=
self
.
value
self
.
config_value
=
self
.
value
def
config_set
(
config
:
Union
[
WrappedConfig
[
T
],
T
],
value
:
T
,
save_config
=
True
):
class
AppConfig
:
if
isinstance
(
config
,
WrappedConfig
):
_state
:
dict
[
str
,
WrappedConfig
]
config
.
value
=
value
if
save_config
:
def
__init__
(
self
):
config
.
save
()
super
().
__setattr__
(
"_state"
,
{})
else
:
config
=
value
def
__setattr__
(
self
,
key
,
value
):
if
isinstance
(
value
,
WrappedConfig
):
self
.
_state
[
key
]
=
value
def
config_get
(
config
:
Union
[
WrappedConfig
[
T
],
T
])
->
T
:
else
:
if
isinstance
(
config
,
WrappedConfig
):
self
.
_state
[
key
].
value
=
value
return
config
.
value
self
.
_state
[
key
].
save
()
return
config
def
__getattr__
(
self
,
key
):
return
self
.
_state
[
key
].
value
####################################
####################################
...
...
backend/main.py
View file @
298e6848
...
@@ -58,8 +58,7 @@ from config import (
...
@@ -58,8 +58,7 @@ from config import (
SRC_LOG_LEVELS
,
SRC_LOG_LEVELS
,
WEBHOOK_URL
,
WEBHOOK_URL
,
ENABLE_ADMIN_EXPORT
,
ENABLE_ADMIN_EXPORT
,
config_get
,
AppConfig
,
config_set
,
)
)
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
...
@@ -96,10 +95,11 @@ https://github.com/open-webui/open-webui
...
@@ -96,10 +95,11 @@ https://github.com/open-webui/open-webui
app
=
FastAPI
(
docs_url
=
"/docs"
if
ENV
==
"dev"
else
None
,
redoc_url
=
None
)
app
=
FastAPI
(
docs_url
=
"/docs"
if
ENV
==
"dev"
else
None
,
redoc_url
=
None
)
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
=
AppConfig
()
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
origins
=
[
"*"
]
origins
=
[
"*"
]
...
@@ -245,11 +245,9 @@ async def get_app_config():
...
@@ -245,11 +245,9 @@ async def get_app_config():
"version"
:
VERSION
,
"version"
:
VERSION
,
"auth"
:
WEBUI_AUTH
,
"auth"
:
WEBUI_AUTH
,
"default_locale"
:
default_locale
,
"default_locale"
:
default_locale
,
"images"
:
config_get
(
images_app
.
state
.
ENABLED
),
"images"
:
images_app
.
state
.
config
.
ENABLED
,
"default_models"
:
config_get
(
webui_app
.
state
.
DEFAULT_MODELS
),
"default_models"
:
webui_app
.
state
.
config
.
DEFAULT_MODELS
,
"default_prompt_suggestions"
:
config_get
(
"default_prompt_suggestions"
:
webui_app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
,
webui_app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
),
"trusted_header_auth"
:
bool
(
webui_app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
),
"trusted_header_auth"
:
bool
(
webui_app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
),
"admin_export_enabled"
:
ENABLE_ADMIN_EXPORT
,
"admin_export_enabled"
:
ENABLE_ADMIN_EXPORT
,
}
}
...
@@ -258,8 +256,8 @@ async def get_app_config():
...
@@ -258,8 +256,8 @@ async def get_app_config():
@
app
.
get
(
"/api/config/model/filter"
)
@
app
.
get
(
"/api/config/model/filter"
)
async
def
get_model_filter_config
(
user
=
Depends
(
get_admin_user
)):
async
def
get_model_filter_config
(
user
=
Depends
(
get_admin_user
)):
return
{
return
{
"enabled"
:
config_get
(
app
.
state
.
ENABLE_MODEL_FILTER
)
,
"enabled"
:
app
.
state
.
config
.
ENABLE_MODEL_FILTER
,
"models"
:
config_get
(
app
.
state
.
MODEL_FILTER_LIST
)
,
"models"
:
app
.
state
.
config
.
MODEL_FILTER_LIST
,
}
}
...
@@ -272,28 +270,28 @@ class ModelFilterConfigForm(BaseModel):
...
@@ -272,28 +270,28 @@ class ModelFilterConfigForm(BaseModel):
async
def
update_model_filter_config
(
async
def
update_model_filter_config
(
form_data
:
ModelFilterConfigForm
,
user
=
Depends
(
get_admin_user
)
form_data
:
ModelFilterConfigForm
,
user
=
Depends
(
get_admin_user
)
):
):
config_set
(
app
.
state
.
ENABLE_MODEL_FILTER
,
form_data
.
enabled
)
app
.
state
.
config
.
ENABLE_MODEL_FILTER
,
form_data
.
enabled
config_set
(
app
.
state
.
MODEL_FILTER_LIST
,
form_data
.
models
)
app
.
state
.
config
.
MODEL_FILTER_LIST
,
form_data
.
models
ollama_app
.
state
.
ENABLE_MODEL_FILTER
=
config_get
(
app
.
state
.
ENABLE_MODEL_FILTER
)
ollama_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
config
.
ENABLE_MODEL_FILTER
ollama_app
.
state
.
MODEL_FILTER_LIST
=
config_get
(
app
.
state
.
MODEL_FILTER_LIST
)
ollama_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
config
.
MODEL_FILTER_LIST
openai_app
.
state
.
ENABLE_MODEL_FILTER
=
config_get
(
app
.
state
.
ENABLE_MODEL_FILTER
)
openai_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
config
.
ENABLE_MODEL_FILTER
openai_app
.
state
.
MODEL_FILTER_LIST
=
config_get
(
app
.
state
.
MODEL_FILTER_LIST
)
openai_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
config
.
MODEL_FILTER_LIST
litellm_app
.
state
.
ENABLE_MODEL_FILTER
=
config_get
(
app
.
state
.
ENABLE_MODEL_FILTER
)
litellm_app
.
state
.
ENABLE_MODEL_FILTER
=
app
.
state
.
config
.
ENABLE_MODEL_FILTER
litellm_app
.
state
.
MODEL_FILTER_LIST
=
config_get
(
app
.
state
.
MODEL_FILTER_LIST
)
litellm_app
.
state
.
MODEL_FILTER_LIST
=
app
.
state
.
config
.
MODEL_FILTER_LIST
return
{
return
{
"enabled"
:
config_get
(
app
.
state
.
ENABLE_MODEL_FILTER
)
,
"enabled"
:
app
.
state
.
config
.
ENABLE_MODEL_FILTER
,
"models"
:
config_get
(
app
.
state
.
MODEL_FILTER_LIST
)
,
"models"
:
app
.
state
.
config
.
MODEL_FILTER_LIST
,
}
}
@
app
.
get
(
"/api/webhook"
)
@
app
.
get
(
"/api/webhook"
)
async
def
get_webhook_url
(
user
=
Depends
(
get_admin_user
)):
async
def
get_webhook_url
(
user
=
Depends
(
get_admin_user
)):
return
{
return
{
"url"
:
config_get
(
app
.
state
.
WEBHOOK_URL
)
,
"url"
:
app
.
state
.
config
.
WEBHOOK_URL
,
}
}
...
@@ -303,12 +301,12 @@ class UrlForm(BaseModel):
...
@@ -303,12 +301,12 @@ class UrlForm(BaseModel):
@
app
.
post
(
"/api/webhook"
)
@
app
.
post
(
"/api/webhook"
)
async
def
update_webhook_url
(
form_data
:
UrlForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_webhook_url
(
form_data
:
UrlForm
,
user
=
Depends
(
get_admin_user
)):
config_set
(
app
.
state
.
WEBHOOK_URL
,
form_data
.
url
)
app
.
state
.
config
.
WEBHOOK_URL
=
form_data
.
url
webui_app
.
state
.
WEBHOOK_URL
=
config_get
(
app
.
state
.
WEBHOOK_URL
)
webui_app
.
state
.
WEBHOOK_URL
=
app
.
state
.
config
.
WEBHOOK_URL
return
{
return
{
"url"
:
config_get
(
app
.
state
.
WEBHOOK_URL
)
,
"url"
:
app
.
state
.
config
.
WEBHOOK_URL
,
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment