Commit 298e6848 authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

feat: switch to config proxy, remove config_get/set

parent f712c900
...@@ -45,8 +45,7 @@ from config import ( ...@@ -45,8 +45,7 @@ from config import (
AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE, AUDIO_OPENAI_API_VOICE,
config_get, AppConfig,
config_set,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -61,11 +60,11 @@ app.add_middleware( ...@@ -61,11 +60,11 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
# setting device type for whisper model # setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
...@@ -85,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel): ...@@ -85,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/config") @app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
} }
...@@ -99,22 +98,17 @@ async def update_openai_config( ...@@ -99,22 +98,17 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
config_set(app.state.OPENAI_API_BASE_URL, form_data.url) app.state.config.OPENAI_API_BASE_URL = form_data.url
config_set(app.state.OPENAI_API_KEY, form_data.key) app.state.config.OPENAI_API_KEY = form_data.key
config_set(app.state.OPENAI_API_MODEL, form_data.model) app.state.config.OPENAI_API_MODEL = form_data.model
config_set(app.state.OPENAI_API_VOICE, form_data.speaker) app.state.config.OPENAI_API_VOICE = form_data.speaker
app.state.OPENAI_API_BASE_URL.save()
app.state.OPENAI_API_KEY.save()
app.state.OPENAI_API_MODEL.save()
app.state.OPENAI_API_VOICE.save()
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": config_get(app.state.OPENAI_API_MODEL), "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": config_get(app.state.OPENAI_API_VOICE), "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
} }
...@@ -131,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -131,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech", url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
......
...@@ -42,8 +42,7 @@ from config import ( ...@@ -42,8 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL, IMAGE_GENERATION_MODEL,
IMAGE_SIZE, IMAGE_SIZE,
IMAGE_STEPS, IMAGE_STEPS,
config_get, AppConfig,
config_set,
) )
...@@ -62,28 +61,30 @@ app.add_middleware( ...@@ -62,28 +61,30 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.ENGINE = IMAGE_GENERATION_ENGINE app.state.config = AppConfig()
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.MODEL = IMAGE_GENERATION_MODEL app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.IMAGE_SIZE = IMAGE_SIZE
app.state.IMAGE_STEPS = IMAGE_STEPS app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config") @app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
return { return {
"engine": config_get(app.state.ENGINE), "engine": app.state.config.ENGINE,
"enabled": config_get(app.state.ENABLED), "enabled": app.state.config.ENABLED,
} }
...@@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel): ...@@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update") @app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
config_set(app.state.ENGINE, form_data.engine) app.state.config.ENGINE = form_data.engine
config_set(app.state.ENABLED, form_data.enabled) app.state.config.ENABLED = form_data.enabled
return { return {
"engine": config_get(app.state.ENGINE), "engine": app.state.config.ENGINE,
"enabled": config_get(app.state.ENABLED), "enabled": app.state.config.ENABLED,
} }
...@@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel): ...@@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url") @app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)): async def get_engine_url(user=Depends(get_admin_user)):
return { return {
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
} }
...@@ -121,29 +122,29 @@ async def update_engine_url( ...@@ -121,29 +122,29 @@ async def update_engine_url(
): ):
if form_data.AUTOMATIC1111_BASE_URL == None: if form_data.AUTOMATIC1111_BASE_URL == None:
config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL)) app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else: else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/") url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
config_set(app.state.AUTOMATIC1111_BASE_URL, url) app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None: if form_data.COMFYUI_BASE_URL == None:
config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL) app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else: else:
url = form_data.COMFYUI_BASE_URL.strip("/") url = form_data.COMFYUI_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
config_set(app.state.COMFYUI_BASE_URL, url) app.state.config.COMFYUI_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return { return {
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL), "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL), "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True, "status": True,
} }
...@@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel): ...@@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config") @app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
} }
...@@ -168,13 +169,13 @@ async def update_openai_config( ...@@ -168,13 +169,13 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
config_set(app.state.OPENAI_API_BASE_URL, form_data.url) app.state.config.OPENAI_API_BASE_URL = form_data.url
config_set(app.state.OPENAI_API_KEY, form_data.key) app.state.config.OPENAI_API_KEY = form_data.key
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL), "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY), "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
} }
...@@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel): ...@@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size") @app.get("/size")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)} return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
@app.post("/size/update") @app.post("/size/update")
...@@ -193,9 +194,9 @@ async def update_image_size( ...@@ -193,9 +194,9 @@ async def update_image_size(
): ):
pattern = r"^\d+x\d+$" # Regular expression pattern pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size): if re.match(pattern, form_data.size):
config_set(app.state.IMAGE_SIZE, form_data.size) app.state.config.IMAGE_SIZE = form_data.size
return { return {
"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE), "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"status": True, "status": True,
} }
else: else:
...@@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel): ...@@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps") @app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)} return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
@app.post("/steps/update") @app.post("/steps/update")
...@@ -219,9 +220,9 @@ async def update_image_size( ...@@ -219,9 +220,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.steps >= 0: if form_data.steps >= 0:
config_set(app.state.IMAGE_STEPS, form_data.steps) app.state.config.IMAGE_STEPS = form_data.steps
return { return {
"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS), "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"status": True, "status": True,
} }
else: else:
...@@ -234,14 +235,14 @@ async def update_image_size( ...@@ -234,14 +235,14 @@ async def update_image_size(
@app.get("/models") @app.get("/models")
def get_models(user=Depends(get_current_user)): def get_models(user=Depends(get_current_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
return [ return [
{"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"}, {"id": "dall-e-3", "name": "DALL·E 3"},
] ]
elif app.state.ENGINE == "comfyui": elif app.state.config.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
info = r.json() info = r.json()
return list( return list(
...@@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)): ...@@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else: else:
r = requests.get( r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
) )
models = r.json() models = r.json()
return list( return list(
...@@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)): ...@@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)):
) )
) )
except Exception as e: except Exception as e:
app.state.ENABLED = False app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default") @app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)): async def get_default_model(user=Depends(get_admin_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
return {
"model": (
config_get(app.state.MODEL)
if config_get(app.state.MODEL)
else "dall-e-2"
)
}
elif app.state.ENGINE == "comfyui":
return { return {
"model": ( "model": (
config_get(app.state.MODEL) if config_get(app.state.MODEL) else "" app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
) )
} }
elif app.state.config.ENGINE == "comfyui":
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json() options = r.json()
return {"model": options["sd_model_checkpoint"]} return {"model": options["sd_model_checkpoint"]}
except Exception as e: except Exception as e:
config_set(app.state.ENABLED, False) app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
...@@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel): ...@@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str): def set_model_handler(model: str):
if app.state.ENGINE in ["openai", "comfyui"]: if app.state.config.ENGINE in ["openai", "comfyui"]:
config_set(app.state.MODEL, model) app.state.config.MODEL = model
return config_get(app.state.MODEL) return app.state.config.MODEL
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json() options = r.json()
if model != options["sd_model_checkpoint"]: if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model options["sd_model_checkpoint"] = model
r = requests.post( r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
) )
return options return options
...@@ -397,30 +397,32 @@ def generate_image( ...@@ -397,30 +397,32 @@ def generate_image(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x"))) width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
r = None r = None
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
data = { data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "model": (
app.state.config.MODEL
if app.state.config.MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt, "prompt": form_data.prompt,
"n": form_data.n, "n": form_data.n,
"size": ( "size": (
form_data.size form_data.size if form_data.size else app.state.config.IMAGE_SIZE
if form_data.size
else config_get(app.state.IMAGE_SIZE)
), ),
"response_format": "b64_json", "response_format": "b64_json",
} }
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations", url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data, json=data,
headers=headers, headers=headers,
) )
...@@ -440,7 +442,7 @@ def generate_image( ...@@ -440,7 +442,7 @@ def generate_image(
return images return images
elif app.state.ENGINE == "comfyui": elif app.state.config.ENGINE == "comfyui":
data = { data = {
"prompt": form_data.prompt, "prompt": form_data.prompt,
...@@ -449,8 +451,8 @@ def generate_image( ...@@ -449,8 +451,8 @@ def generate_image(
"n": form_data.n, "n": form_data.n,
} }
if config_get(app.state.IMAGE_STEPS) is not None: if app.state.config.IMAGE_STEPS is not None:
data["steps"] = config_get(app.state.IMAGE_STEPS) data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
...@@ -458,10 +460,10 @@ def generate_image( ...@@ -458,10 +460,10 @@ def generate_image(
data = ImageGenerationPayload(**data) data = ImageGenerationPayload(**data)
res = comfyui_generate_image( res = comfyui_generate_image(
config_get(app.state.MODEL), app.state.config.MODEL,
data, data,
user.id, user.id,
config_get(app.state.COMFYUI_BASE_URL), app.state.config.COMFYUI_BASE_URL,
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
...@@ -488,14 +490,14 @@ def generate_image( ...@@ -488,14 +490,14 @@ def generate_image(
"height": height, "height": height,
} }
if config_get(app.state.IMAGE_STEPS) is not None: if app.state.config.IMAGE_STEPS is not None:
data["steps"] = config_get(app.state.IMAGE_STEPS) data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
r = requests.post( r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data, json=data,
) )
......
...@@ -46,8 +46,7 @@ from config import ( ...@@ -46,8 +46,7 @@ from config import (
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
UPLOAD_DIR, UPLOAD_DIR,
config_set, AppConfig,
config_get,
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
...@@ -63,11 +62,12 @@ app.add_middleware( ...@@ -63,11 +62,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {} app.state.MODELS = {}
...@@ -98,7 +98,7 @@ async def get_status(): ...@@ -98,7 +98,7 @@ async def get_status():
@app.get("/urls") @app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)): async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
...@@ -107,10 +107,10 @@ class UrlUpdateForm(BaseModel): ...@@ -107,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/urls/update") @app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
config_set(app.state.OLLAMA_BASE_URLS, form_data.urls) app.state.config.OLLAMA_BASE_URLS = form_data.urls
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": config_get(app.state.OLLAMA_BASE_URLS)} return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}") @app.get("/cancel/{request_id}")
...@@ -155,9 +155,7 @@ def merge_models_lists(model_lists): ...@@ -155,9 +155,7 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
tasks = [ tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
fetch_url(f"{url}/api/tags") for url in config_get(app.state.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
models = { models = {
...@@ -183,15 +181,14 @@ async def get_ollama_tags( ...@@ -183,15 +181,14 @@ async def get_ollama_tags(
if user.role == "user": if user.role == "user":
models["models"] = list( models["models"] = list(
filter( filter(
lambda model: model["name"] lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
in config_get(app.state.MODEL_FILTER_LIST),
models["models"], models["models"],
) )
) )
return models return models
return models return models
else: else:
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
...@@ -222,8 +219,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): ...@@ -222,8 +219,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
# returns lowest version # returns lowest version
tasks = [ tasks = [
fetch_url(f"{url}/api/version") fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
for url in config_get(app.state.OLLAMA_BASE_URLS)
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses)) responses = list(filter(lambda x: x is not None, responses))
...@@ -243,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): ...@@ -243,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
) )
else: else:
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/version") r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status() r.raise_for_status()
...@@ -275,7 +271,7 @@ class ModelNameForm(BaseModel): ...@@ -275,7 +271,7 @@ class ModelNameForm(BaseModel):
async def pull_model( async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -363,7 +359,7 @@ async def push_model( ...@@ -363,7 +359,7 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}") log.debug(f"url: {url}")
r = None r = None
...@@ -425,7 +421,7 @@ async def create_model( ...@@ -425,7 +421,7 @@ async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -498,7 +494,7 @@ async def copy_model( ...@@ -498,7 +494,7 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -545,7 +541,7 @@ async def delete_model( ...@@ -545,7 +541,7 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -585,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ...@@ -585,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
) )
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -642,7 +638,7 @@ async def generate_embeddings( ...@@ -642,7 +638,7 @@ async def generate_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -692,7 +688,7 @@ def generate_ollama_embeddings( ...@@ -692,7 +688,7 @@ def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -761,7 +757,7 @@ async def generate_completion( ...@@ -761,7 +757,7 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -864,7 +860,7 @@ async def generate_chat_completion( ...@@ -864,7 +860,7 @@ async def generate_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -973,7 +969,7 @@ async def generate_openai_chat_completion( ...@@ -973,7 +969,7 @@ async def generate_openai_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -1072,7 +1068,7 @@ async def get_openai_models( ...@@ -1072,7 +1068,7 @@ async def get_openai_models(
} }
else: else:
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
...@@ -1206,7 +1202,7 @@ async def download_model( ...@@ -1206,7 +1202,7 @@ async def download_model(
if url_idx == None: if url_idx == None:
url_idx = 0 url_idx = 0
url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url) file_name = parse_huggingface_url(form_data.url)
...@@ -1225,7 +1221,7 @@ async def download_model( ...@@ -1225,7 +1221,7 @@ async def download_model(
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx == None: if url_idx == None:
url_idx = 0 url_idx = 0
ollama_url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}" file_path = f"{UPLOAD_DIR}/{file.filename}"
...@@ -1290,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): ...@@ -1290,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None: # if url_idx == None:
# url_idx = 0 # url_idx = 0
# url = config_get(app.state.OLLAMA_BASE_URLS)[url_idx] # url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename) # file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size # total_size = file.size
...@@ -1327,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): ...@@ -1327,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async def deprecated_proxy( async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user) path: str, request: Request, user=Depends(get_verified_user)
): ):
url = config_get(app.state.OLLAMA_BASE_URLS)[0] url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"
body = await request.body() body = await request.body()
......
...@@ -26,8 +26,7 @@ from config import ( ...@@ -26,8 +26,7 @@ from config import (
CACHE_DIR, CACHE_DIR,
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
config_set, AppConfig,
config_get,
) )
from typing import List, Optional from typing import List, Optional
...@@ -47,11 +46,13 @@ app.add_middleware( ...@@ -47,11 +46,13 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {} app.state.MODELS = {}
...@@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel): ...@@ -77,34 +78,32 @@ class KeysUpdateForm(BaseModel):
@app.get("/urls") @app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)): async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update") @app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models() await get_all_models()
config_set(app.state.OPENAI_API_BASE_URLS, form_data.urls) app.state.config.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": config_get(app.state.OPENAI_API_BASE_URLS)} return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys") @app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)): async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update") @app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
config_set(app.state.OPENAI_API_KEYS, form_data.keys) app.state.config.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": config_get(app.state.OPENAI_API_KEYS)} return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/audio/speech") @app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
idx = None idx = None
try: try:
idx = config_get(app.state.OPENAI_API_BASE_URLS).index( idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
"https://api.openai.com/v1"
)
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
...@@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -118,15 +117,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = ( headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
f"Bearer {config_get(app.state.OPENAI_API_KEYS)[idx]}"
)
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{config_get(app.state.OPENAI_API_BASE_URLS)[idx]}/audio/speech", url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
...@@ -187,7 +184,7 @@ def merge_models_lists(model_lists): ...@@ -187,7 +184,7 @@ def merge_models_lists(model_lists):
{**model, "urlIdx": idx} {**model, "urlIdx": idx}
for model in models for model in models
if "api.openai.com" if "api.openai.com"
not in config_get(app.state.OPENAI_API_BASE_URLS)[idx] not in app.state.config.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"] or "gpt" in model["id"]
] ]
) )
...@@ -199,14 +196,14 @@ async def get_all_models(): ...@@ -199,14 +196,14 @@ async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
if ( if (
len(config_get(app.state.OPENAI_API_KEYS)) == 1 len(app.state.config.OPENAI_API_KEYS) == 1
and config_get(app.state.OPENAI_API_KEYS)[0] == "" and app.state.config.OPENAI_API_KEYS[0] == ""
): ):
models = {"data": []} models = {"data": []}
else: else:
tasks = [ tasks = [
fetch_url(f"{url}/models", config_get(app.state.OPENAI_API_KEYS)[idx]) fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(config_get(app.state.OPENAI_API_BASE_URLS)) for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
...@@ -238,19 +235,18 @@ async def get_all_models(): ...@@ -238,19 +235,18 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None: if url_idx == None:
models = await get_all_models() models = await get_all_models()
if config_get(app.state.ENABLE_MODEL_FILTER): if app.state.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models["data"] = list( models["data"] = list(
filter( filter(
lambda model: model["id"] lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
in config_get(app.state.MODEL_FILTER_LIST),
models["data"], models["data"],
) )
) )
return models return models
return models return models
else: else:
url = config_get(app.state.OPENAI_API_BASE_URLS)[url_idx] url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
r = None r = None
...@@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -314,8 +310,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e) log.error("Error loading request body into a dictionary:", e)
url = config_get(app.state.OPENAI_API_BASE_URLS)[idx] url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = config_get(app.state.OPENAI_API_KEYS)[idx] key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"
......
This diff is collapsed.
...@@ -22,21 +22,23 @@ from config import ( ...@@ -22,21 +22,23 @@ from config import (
WEBHOOK_URL, WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
config_get, AppConfig,
) )
app = FastAPI() app = FastAPI()
origins = ["*"] origins = ["*"]
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config = AppConfig()
app.state.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.WEBHOOK_URL = WEBHOOK_URL app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware( app.add_middleware(
...@@ -63,6 +65,6 @@ async def get_status(): ...@@ -63,6 +65,6 @@ async def get_status():
return { return {
"status": True, "status": True,
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"default_models": config_get(app.state.DEFAULT_MODELS), "default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": config_get(app.state.DEFAULT_PROMPT_SUGGESTIONS), "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
} }
...@@ -33,7 +33,7 @@ from utils.utils import ( ...@@ -33,7 +33,7 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
router = APIRouter() router = APIRouter()
...@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm): ...@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(config_get(request.app.state.JWT_EXPIRES_IN)), expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
) )
return { return {
...@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm): ...@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm): async def signup(request: Request, form_data: SignupForm):
if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH: if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
...@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0 if Users.get_num_users() == 0
else config_get(request.app.state.DEFAULT_USER_ROLE) else request.app.state.config.DEFAULT_USER_ROLE
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
...@@ -194,15 +194,13 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -194,15 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration( expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
config_get(request.app.state.JWT_EXPIRES_IN)
),
) )
# response.set_cookie(key='token', value=token, httponly=True) # response.set_cookie(key='token', value=token, httponly=True)
if config_get(request.app.state.WEBHOOK_URL): if request.app.state.config.WEBHOOK_URL:
post_webhook( post_webhook(
config_get(request.app.state.WEBHOOK_URL), request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
"action": "signup", "action": "signup",
...@@ -278,15 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): ...@@ -278,15 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@router.get("/signup/enabled", response_model=bool) @router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return config_get(request.app.state.ENABLE_SIGNUP) return request.app.state.config.ENABLE_SIGNUP
@router.get("/signup/enabled/toggle", response_model=bool) @router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
config_set( request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
request.app.state.ENABLE_SIGNUP, not config_get(request.app.state.ENABLE_SIGNUP) return request.app.state.config.ENABLE_SIGNUP
)
return config_get(request.app.state.ENABLE_SIGNUP)
############################ ############################
...@@ -296,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): ...@@ -296,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@router.get("/signup/user/role") @router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)): async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return config_get(request.app.state.DEFAULT_USER_ROLE) return request.app.state.config.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel): class UpdateRoleForm(BaseModel):
...@@ -308,8 +304,8 @@ async def update_default_user_role( ...@@ -308,8 +304,8 @@ async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
): ):
if form_data.role in ["pending", "user", "admin"]: if form_data.role in ["pending", "user", "admin"]:
config_set(request.app.state.DEFAULT_USER_ROLE, form_data.role) request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return config_get(request.app.state.DEFAULT_USER_ROLE) return request.app.state.config.DEFAULT_USER_ROLE
############################ ############################
...@@ -319,7 +315,7 @@ async def update_default_user_role( ...@@ -319,7 +315,7 @@ async def update_default_user_role(
@router.get("/token/expires") @router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return config_get(request.app.state.JWT_EXPIRES_IN) return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel): class UpdateJWTExpiresDurationForm(BaseModel):
...@@ -336,10 +332,10 @@ async def update_token_expires_duration( ...@@ -336,10 +332,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern # Check if the input string matches the pattern
if re.match(pattern, form_data.duration): if re.match(pattern, form_data.duration):
config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration) request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return config_get(request.app.state.JWT_EXPIRES_IN) return request.app.state.config.JWT_EXPIRES_IN
else: else:
return config_get(request.app.state.JWT_EXPIRES_IN) return request.app.state.config.JWT_EXPIRES_IN
############################ ############################
......
...@@ -9,7 +9,6 @@ import time ...@@ -9,7 +9,6 @@ import time
import uuid import uuid
from apps.web.models.users import Users from apps.web.models.users import Users
from config import config_set, config_get
from utils.utils import ( from utils.utils import (
get_password_hash, get_password_hash,
...@@ -45,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel): ...@@ -45,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
async def set_global_default_models( async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
): ):
config_set(request.app.state.DEFAULT_MODELS, form_data.models) request.app.state.config.DEFAULT_MODELS = form_data.models
return config_get(request.app.state.DEFAULT_MODELS) return request.app.state.config.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion]) @router.post("/default/suggestions", response_model=List[PromptSuggestion])
...@@ -56,5 +55,5 @@ async def set_global_default_suggestions( ...@@ -56,5 +55,5 @@ async def set_global_default_suggestions(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
data = form_data.model_dump() data = form_data.model_dump()
config_set(request.app.state.DEFAULT_PROMPT_SUGGESTIONS, data["suggestions"]) request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return config_get(request.app.state.DEFAULT_PROMPT_SUGGESTIONS) return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
...@@ -15,7 +15,7 @@ from apps.web.models.auths import Auths ...@@ -15,7 +15,7 @@ from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash, get_admin_user from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS, config_set, config_get from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
...@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user) ...@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
@router.get("/permissions/user") @router.get("/permissions/user")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)): async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return config_get(request.app.state.USER_PERMISSIONS) return request.app.state.config.USER_PERMISSIONS
@router.post("/permissions/user") @router.post("/permissions/user")
async def update_user_permissions( async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user) request: Request, form_data: dict, user=Depends(get_admin_user)
): ):
config_set(request.app.state.USER_PERMISSIONS, form_data) request.app.state.config.USER_PERMISSIONS = form_data
return config_get(request.app.state.USER_PERMISSIONS) return request.app.state.config.USER_PERMISSIONS
############################ ############################
......
...@@ -246,19 +246,21 @@ class WrappedConfig(Generic[T]): ...@@ -246,19 +246,21 @@ class WrappedConfig(Generic[T]):
self.config_value = self.value self.config_value = self.value
def config_set(config: Union[WrappedConfig[T], T], value: T, save_config=True): class AppConfig:
if isinstance(config, WrappedConfig): _state: dict[str, WrappedConfig]
config.value = value
if save_config: def __init__(self):
config.save() super().__setattr__("_state", {})
else:
config = value def __setattr__(self, key, value):
if isinstance(value, WrappedConfig):
self._state[key] = value
def config_get(config: Union[WrappedConfig[T], T]) -> T: else:
if isinstance(config, WrappedConfig): self._state[key].value = value
return config.value self._state[key].save()
return config
def __getattr__(self, key):
return self._state[key].value
#################################### ####################################
......
...@@ -58,8 +58,7 @@ from config import ( ...@@ -58,8 +58,7 @@ from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
WEBHOOK_URL, WEBHOOK_URL,
ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_EXPORT,
config_get, AppConfig,
config_set,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -96,10 +95,11 @@ https://github.com/open-webui/open-webui ...@@ -96,10 +95,11 @@ https://github.com/open-webui/open-webui
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config = AppConfig()
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
origins = ["*"] origins = ["*"]
...@@ -245,11 +245,9 @@ async def get_app_config(): ...@@ -245,11 +245,9 @@ async def get_app_config():
"version": VERSION, "version": VERSION,
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"default_locale": default_locale, "default_locale": default_locale,
"images": config_get(images_app.state.ENABLED), "images": images_app.state.config.ENABLED,
"default_models": config_get(webui_app.state.DEFAULT_MODELS), "default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": config_get( "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
webui_app.state.DEFAULT_PROMPT_SUGGESTIONS
),
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT, "admin_export_enabled": ENABLE_ADMIN_EXPORT,
} }
...@@ -258,8 +256,8 @@ async def get_app_config(): ...@@ -258,8 +256,8 @@ async def get_app_config():
@app.get("/api/config/model/filter") @app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)): async def get_model_filter_config(user=Depends(get_admin_user)):
return { return {
"enabled": config_get(app.state.ENABLE_MODEL_FILTER), "enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": config_get(app.state.MODEL_FILTER_LIST), "models": app.state.config.MODEL_FILTER_LIST,
} }
...@@ -272,28 +270,28 @@ class ModelFilterConfigForm(BaseModel): ...@@ -272,28 +270,28 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config( async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user) form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
): ):
config_set(app.state.ENABLE_MODEL_FILTER, form_data.enabled) app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
config_set(app.state.MODEL_FILTER_LIST, form_data.models) app.state.config.MODEL_FILTER_LIST, form_data.models
ollama_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = config_get(app.state.ENABLE_MODEL_FILTER) litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = config_get(app.state.MODEL_FILTER_LIST) litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
return { return {
"enabled": config_get(app.state.ENABLE_MODEL_FILTER), "enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": config_get(app.state.MODEL_FILTER_LIST), "models": app.state.config.MODEL_FILTER_LIST,
} }
@app.get("/api/webhook") @app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)): async def get_webhook_url(user=Depends(get_admin_user)):
return { return {
"url": config_get(app.state.WEBHOOK_URL), "url": app.state.config.WEBHOOK_URL,
} }
...@@ -303,12 +301,12 @@ class UrlForm(BaseModel): ...@@ -303,12 +301,12 @@ class UrlForm(BaseModel):
@app.post("/api/webhook") @app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
config_set(app.state.WEBHOOK_URL, form_data.url) app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = config_get(app.state.WEBHOOK_URL) webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return { return {
"url": config_get(app.state.WEBHOOK_URL), "url": app.state.config.WEBHOOK_URL,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment