Unverified Commit 9f0c9d97 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #4597 from michaelpoluektov/cleanup

refactor: search and replace-able cleanup
parents 1597e33a 0470146d
...@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse ...@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
from typing import List
import uuid import uuid
import requests import requests
import hashlib import hashlib
...@@ -244,7 +244,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -244,7 +244,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message']}" error_detail = f"External: {res['error']['message']}"
except: except Exception:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException( raise HTTPException(
...@@ -299,7 +299,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -299,7 +299,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message']}" error_detail = f"External: {res['error']['message']}"
except: except Exception:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException( raise HTTPException(
...@@ -353,7 +353,7 @@ def transcribe( ...@@ -353,7 +353,7 @@ def transcribe(
try: try:
model = WhisperModel(**whisper_kwargs) model = WhisperModel(**whisper_kwargs)
except: except Exception:
log.warning( log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False" "WhisperModel initialization failed, attempting download with local_files_only=False"
) )
...@@ -421,7 +421,7 @@ def transcribe( ...@@ -421,7 +421,7 @@ def transcribe(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message']}" error_detail = f"External: {res['error']['message']}"
except: except Exception:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException( raise HTTPException(
...@@ -438,7 +438,7 @@ def transcribe( ...@@ -438,7 +438,7 @@ def transcribe(
) )
def get_available_models() -> List[dict]: def get_available_models() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai": if app.state.config.TTS_ENGINE == "openai":
return [{"id": "tts-1"}, {"id": "tts-1-hd"}] return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif app.state.config.TTS_ENGINE == "elevenlabs": elif app.state.config.TTS_ENGINE == "elevenlabs":
...@@ -466,7 +466,7 @@ async def get_models(user=Depends(get_verified_user)): ...@@ -466,7 +466,7 @@ async def get_models(user=Depends(get_verified_user)):
return {"models": get_available_models()} return {"models": get_available_models()}
def get_available_voices() -> List[dict]: def get_available_voices() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai": if app.state.config.TTS_ENGINE == "openai":
return [ return [
{"name": "alloy", "id": "alloy"}, {"name": "alloy", "id": "alloy"},
......
...@@ -94,7 +94,7 @@ app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP ...@@ -94,7 +94,7 @@ app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
def get_automatic1111_api_auth(): def get_automatic1111_api_auth():
if app.state.config.AUTOMATIC1111_API_AUTH == None: if app.state.config.AUTOMATIC1111_API_AUTH is None:
return "" return ""
else: else:
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
...@@ -145,7 +145,7 @@ async def get_engine_url(user=Depends(get_admin_user)): ...@@ -145,7 +145,7 @@ async def get_engine_url(user=Depends(get_admin_user)):
async def update_engine_url( async def update_engine_url(
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.AUTOMATIC1111_BASE_URL == None: if form_data.AUTOMATIC1111_BASE_URL is None:
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else: else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/") url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
...@@ -156,7 +156,7 @@ async def update_engine_url( ...@@ -156,7 +156,7 @@ async def update_engine_url(
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
if form_data.COMFYUI_BASE_URL == None: if form_data.COMFYUI_BASE_URL is None:
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else: else:
url = form_data.COMFYUI_BASE_URL.strip("/") url = form_data.COMFYUI_BASE_URL.strip("/")
...@@ -168,7 +168,7 @@ async def update_engine_url( ...@@ -168,7 +168,7 @@ async def update_engine_url(
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
if form_data.AUTOMATIC1111_API_AUTH == None: if form_data.AUTOMATIC1111_API_AUTH is None:
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
else: else:
app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
......
...@@ -21,7 +21,7 @@ import asyncio ...@@ -21,7 +21,7 @@ import asyncio
import logging import logging
import time import time
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Optional, List, Union from typing import Optional, Union
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
...@@ -114,7 +114,7 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)): ...@@ -114,7 +114,7 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
urls: List[str] urls: list[str]
@app.post("/urls/update") @app.post("/urls/update")
...@@ -646,7 +646,7 @@ def generate_ollama_embeddings( ...@@ -646,7 +646,7 @@ def generate_ollama_embeddings(
class GenerateCompletionForm(BaseModel): class GenerateCompletionForm(BaseModel):
model: str model: str
prompt: str prompt: str
images: Optional[List[str]] = None images: Optional[list[str]] = None
format: Optional[str] = None format: Optional[str] = None
options: Optional[dict] = None options: Optional[dict] = None
system: Optional[str] = None system: Optional[str] = None
...@@ -689,12 +689,12 @@ async def generate_completion( ...@@ -689,12 +689,12 @@ async def generate_completion(
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str
content: str content: str
images: Optional[List[str]] = None images: Optional[list[str]] = None
class GenerateChatCompletionForm(BaseModel): class GenerateChatCompletionForm(BaseModel):
model: str model: str
messages: List[ChatMessage] messages: list[ChatMessage]
format: Optional[str] = None format: Optional[str] = None
options: Optional[dict] = None options: Optional[dict] = None
template: Optional[str] = None template: Optional[str] = None
...@@ -772,7 +772,7 @@ class OpenAIChatMessage(BaseModel): ...@@ -772,7 +772,7 @@ class OpenAIChatMessage(BaseModel):
class OpenAIChatCompletionForm(BaseModel): class OpenAIChatCompletionForm(BaseModel):
model: str model: str
messages: List[OpenAIChatMessage] messages: list[OpenAIChatMessage]
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
......
...@@ -33,7 +33,7 @@ from config import ( ...@@ -33,7 +33,7 @@ from config import (
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
AppConfig, AppConfig,
) )
from typing import List, Optional, Literal, overload from typing import Optional, Literal, overload
import hashlib import hashlib
...@@ -89,11 +89,11 @@ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user ...@@ -89,11 +89,11 @@ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user
class UrlsUpdateForm(BaseModel): class UrlsUpdateForm(BaseModel):
urls: List[str] urls: list[str]
class KeysUpdateForm(BaseModel): class KeysUpdateForm(BaseModel):
keys: List[str] keys: list[str]
@app.get("/urls") @app.get("/urls")
......
...@@ -13,7 +13,7 @@ import os, shutil, logging, re ...@@ -13,7 +13,7 @@ import os, shutil, logging, re
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Union, Sequence, Iterator, Any from typing import Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document from langchain_core.documents import Document
...@@ -439,7 +439,7 @@ class ChunkParamUpdateForm(BaseModel): ...@@ -439,7 +439,7 @@ class ChunkParamUpdateForm(BaseModel):
class YoutubeLoaderConfig(BaseModel): class YoutubeLoaderConfig(BaseModel):
language: List[str] language: list[str]
translation: Optional[str] = None translation: Optional[str] = None
...@@ -642,7 +642,7 @@ def query_doc_handler( ...@@ -642,7 +642,7 @@ def query_doc_handler(
class QueryCollectionsForm(BaseModel): class QueryCollectionsForm(BaseModel):
collection_names: List[str] collection_names: list[str]
query: str query: str
k: Optional[int] = None k: Optional[int] = None
r: Optional[float] = None r: Optional[float] = None
...@@ -1021,7 +1021,7 @@ class TikaLoader: ...@@ -1021,7 +1021,7 @@ class TikaLoader:
self.file_path = file_path self.file_path = file_path
self.mime_type = mime_type self.mime_type = mime_type
def load(self) -> List[Document]: def load(self) -> list[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, "rb") as f:
data = f.read() data = f.read()
...@@ -1185,7 +1185,7 @@ def store_doc( ...@@ -1185,7 +1185,7 @@ def store_doc(
f.close() f.close()
f = open(file_path, "rb") f = open(file_path, "rb")
if collection_name == None: if collection_name is None:
collection_name = calculate_sha256(f)[:63] collection_name = calculate_sha256(f)[:63]
f.close() f.close()
...@@ -1238,7 +1238,7 @@ def process_doc( ...@@ -1238,7 +1238,7 @@ def process_doc(
f = open(file_path, "rb") f = open(file_path, "rb")
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == None: if collection_name is None:
collection_name = calculate_sha256(f)[:63] collection_name = calculate_sha256(f)[:63]
f.close() f.close()
...@@ -1296,7 +1296,7 @@ def store_text( ...@@ -1296,7 +1296,7 @@ def store_text(
): ):
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == None: if collection_name is None:
collection_name = calculate_sha256_string(form_data.content) collection_name = calculate_sha256_string(form_data.content)
result = store_text_in_vector_db( result = store_text_in_vector_db(
...@@ -1339,7 +1339,7 @@ def scan_docs_dir(user=Depends(get_admin_user)): ...@@ -1339,7 +1339,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
sanitized_filename = sanitize_filename(filename) sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename) doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None: if doc is None:
doc = Documents.insert_new_doc( doc = Documents.insert_new_doc(
user.id, user.id,
DocumentForm( DocumentForm(
......
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -10,7 +10,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -10,7 +10,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave( def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects. """Search using Brave's Search API and return the results as a list of SearchResult objects.
......
import logging import logging
from typing import List, Optional from typing import Optional
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -9,7 +9,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -9,7 +9,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo( def search_duckduckgo(
query: str, count: int, filter_list: Optional[List[str]] = None query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]: ) -> list[SearchResult]:
""" """
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
...@@ -18,7 +18,7 @@ def search_duckduckgo( ...@@ -18,7 +18,7 @@ def search_duckduckgo(
count (int): The number of results to return count (int): The number of results to return
Returns: Returns:
List[SearchResult]: A list of search results list[SearchResult]: A list of search results
""" """
# Use the DDGS context manager to create a DDGS object # Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs: with DDGS() as ddgs:
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -15,7 +15,7 @@ def search_google_pse( ...@@ -15,7 +15,7 @@ def search_google_pse(
search_engine_id: str, search_engine_id: str,
query: str, query: str,
count: int, count: int,
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects. """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
......
...@@ -17,7 +17,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]: ...@@ -17,7 +17,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
count (int): The number of results to return count (int): The number of results to return
Returns: Returns:
List[SearchResult]: A list of search results list[SearchResult]: A list of search results
""" """
jina_search_endpoint = "https://s.jina.ai/" jina_search_endpoint = "https://s.jina.ai/"
headers = { headers = {
......
import logging import logging
import requests import requests
from typing import List, Optional from typing import Optional
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -14,9 +14,9 @@ def search_searxng( ...@@ -14,9 +14,9 @@ def search_searxng(
query_url: str, query_url: str,
query: str, query: str,
count: int, count: int,
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
**kwargs, **kwargs,
) -> List[SearchResult]: ) -> list[SearchResult]:
""" """
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects. Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
...@@ -31,10 +31,10 @@ def search_searxng( ...@@ -31,10 +31,10 @@ def search_searxng(
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string. language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate). safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''. time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided. categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
Returns: Returns:
List[SearchResult]: A list of SearchResults sorted by relevance score in descending order. list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
Raise: Raise:
requests.exceptions.RequestException: If a request error occurs during the search process. requests.exceptions.RequestException: If a request error occurs during the search process.
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper( def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from urllib.parse import urlencode from urllib.parse import urlencode
...@@ -19,7 +19,7 @@ def search_serply( ...@@ -19,7 +19,7 @@ def search_serply(
limit: int = 10, limit: int = 10,
device_type: str = "desktop", device_type: str = "desktop",
proxy_location: str = "US", proxy_location: str = "US",
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -14,7 +14,7 @@ def search_serpstack( ...@@ -14,7 +14,7 @@ def search_serpstack(
api_key: str, api_key: str,
query: str, query: str,
count: int, count: int,
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
https_enabled: bool = True, https_enabled: bool = True,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects. """Search using serpstack.com's and return the results as a list of SearchResult objects.
......
...@@ -17,7 +17,7 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: ...@@ -17,7 +17,7 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
query (str): The query to search for query (str): The query to search for
Returns: Returns:
List[SearchResult]: A list of search results list[SearchResult]: A list of search results
""" """
url = "https://api.tavily.com/search" url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key} data = {"query": query, "api_key": api_key}
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import logging import logging
import requests import requests
from typing import List, Union from typing import Union
from apps.ollama.main import ( from apps.ollama.main import (
generate_ollama_embeddings, generate_ollama_embeddings,
...@@ -142,7 +142,7 @@ def merge_and_sort_query_results(query_results, k, reverse=False): ...@@ -142,7 +142,7 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
def query_collection( def query_collection(
collection_names: List[str], collection_names: list[str],
query: str, query: str,
embedding_function, embedding_function,
k: int, k: int,
...@@ -157,13 +157,13 @@ def query_collection( ...@@ -157,13 +157,13 @@ def query_collection(
embedding_function=embedding_function, embedding_function=embedding_function,
) )
results.append(result) results.append(result)
except: except Exception:
pass pass
return merge_and_sort_query_results(results, k=k) return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search( def query_collection_with_hybrid_search(
collection_names: List[str], collection_names: list[str],
query: str, query: str,
embedding_function, embedding_function,
k: int, k: int,
...@@ -182,7 +182,7 @@ def query_collection_with_hybrid_search( ...@@ -182,7 +182,7 @@ def query_collection_with_hybrid_search(
r=r, r=r,
) )
results.append(result) results.append(result)
except: except Exception:
pass pass
return merge_and_sort_query_results(results, k=k, reverse=True) return merge_and_sort_query_results(results, k=k, reverse=True)
...@@ -411,7 +411,7 @@ class ChromaRetriever(BaseRetriever): ...@@ -411,7 +411,7 @@ class ChromaRetriever(BaseRetriever):
query: str, query: str,
*, *,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]: ) -> list[Document]:
query_embeddings = self.embedding_function(query) query_embeddings = self.embedding_function(query)
results = self.collection.query( results = self.collection.query(
......
...@@ -140,7 +140,7 @@ class AuthsTable: ...@@ -140,7 +140,7 @@ class AuthsTable:
return None return None
else: else:
return None return None
except: except Exception:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
...@@ -152,7 +152,7 @@ class AuthsTable: ...@@ -152,7 +152,7 @@ class AuthsTable:
try: try:
user = Users.get_user_by_api_key(api_key) user = Users.get_user_by_api_key(api_key)
return user if user else None return user if user else None
except: except Exception:
return False return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
...@@ -163,7 +163,7 @@ class AuthsTable: ...@@ -163,7 +163,7 @@ class AuthsTable:
if auth: if auth:
user = Users.get_user_by_id(auth.id) user = Users.get_user_by_id(auth.id)
return user return user
except: except Exception:
return None return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
...@@ -174,7 +174,7 @@ class AuthsTable: ...@@ -174,7 +174,7 @@ class AuthsTable:
) )
db.commit() db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except Exception:
return False return False
def update_email_by_id(self, id: str, email: str) -> bool: def update_email_by_id(self, id: str, email: str) -> bool:
...@@ -183,7 +183,7 @@ class AuthsTable: ...@@ -183,7 +183,7 @@ class AuthsTable:
result = db.query(Auth).filter_by(id=id).update({"email": email}) result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit() db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except Exception:
return False return False
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
...@@ -200,7 +200,7 @@ class AuthsTable: ...@@ -200,7 +200,7 @@ class AuthsTable:
return True return True
else: else:
return False return False
except: except Exception:
return False return False
......
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import Union, Optional
import json import json
import uuid import uuid
...@@ -164,7 +164,7 @@ class ChatTable: ...@@ -164,7 +164,7 @@ class ChatTable:
db.refresh(chat) db.refresh(chat)
return self.get_chat_by_id(chat.share_id) return self.get_chat_by_id(chat.share_id)
except: except Exception:
return None return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
...@@ -175,7 +175,7 @@ class ChatTable: ...@@ -175,7 +175,7 @@ class ChatTable:
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
def update_chat_share_id_by_id( def update_chat_share_id_by_id(
...@@ -189,7 +189,7 @@ class ChatTable: ...@@ -189,7 +189,7 @@ class ChatTable:
db.commit() db.commit()
db.refresh(chat) db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except Exception:
return None return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
...@@ -201,7 +201,7 @@ class ChatTable: ...@@ -201,7 +201,7 @@ class ChatTable:
db.commit() db.commit()
db.refresh(chat) db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except Exception:
return None return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool: def archive_all_chats_by_user_id(self, user_id: str) -> bool:
...@@ -210,12 +210,12 @@ class ChatTable: ...@@ -210,12 +210,12 @@ class ChatTable:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
...@@ -233,7 +233,7 @@ class ChatTable: ...@@ -233,7 +233,7 @@ class ChatTable:
include_archived: bool = False, include_archived: bool = False,
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
) -> List[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived: if not include_archived:
...@@ -251,7 +251,7 @@ class ChatTable: ...@@ -251,7 +251,7 @@ class ChatTable:
include_archived: bool = False, include_archived: bool = False,
skip: int = 0, skip: int = 0,
limit: int = -1, limit: int = -1,
) -> List[ChatTitleIdResponse]: ) -> list[ChatTitleIdResponse]:
with get_db() as db: with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id) query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived: if not include_archived:
...@@ -279,8 +279,8 @@ class ChatTable: ...@@ -279,8 +279,8 @@ class ChatTable:
] ]
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: list[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
db.query(Chat) db.query(Chat)
...@@ -297,7 +297,7 @@ class ChatTable: ...@@ -297,7 +297,7 @@ class ChatTable:
chat = db.get(Chat, id) chat = db.get(Chat, id)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except Exception:
return None return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
...@@ -319,10 +319,10 @@ class ChatTable: ...@@ -319,10 +319,10 @@ class ChatTable:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except Exception:
return None return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
...@@ -332,7 +332,7 @@ class ChatTable: ...@@ -332,7 +332,7 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
...@@ -342,7 +342,7 @@ class ChatTable: ...@@ -342,7 +342,7 @@ class ChatTable:
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db: with get_db() as db:
all_chats = ( all_chats = (
...@@ -360,7 +360,7 @@ class ChatTable: ...@@ -360,7 +360,7 @@ class ChatTable:
db.commit() db.commit()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except Exception:
return False return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
...@@ -371,7 +371,7 @@ class ChatTable: ...@@ -371,7 +371,7 @@ class ChatTable:
db.commit() db.commit()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except Exception:
return False return False
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
...@@ -385,7 +385,7 @@ class ChatTable: ...@@ -385,7 +385,7 @@ class ChatTable:
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
...@@ -400,7 +400,7 @@ class ChatTable: ...@@ -400,7 +400,7 @@ class ChatTable:
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
......
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing import List, Optional from typing import Optional
import time import time
import logging import logging
...@@ -93,7 +93,7 @@ class DocumentsTable: ...@@ -93,7 +93,7 @@ class DocumentsTable:
return DocumentModel.model_validate(result) return DocumentModel.model_validate(result)
else: else:
return None return None
except: except Exception:
return None return None
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
...@@ -102,10 +102,10 @@ class DocumentsTable: ...@@ -102,10 +102,10 @@ class DocumentsTable:
document = db.query(Document).filter_by(name=name).first() document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None return DocumentModel.model_validate(document) if document else None
except: except Exception:
return None return None
def get_docs(self) -> List[DocumentModel]: def get_docs(self) -> list[DocumentModel]:
with get_db() as db: with get_db() as db:
return [ return [
...@@ -160,7 +160,7 @@ class DocumentsTable: ...@@ -160,7 +160,7 @@ class DocumentsTable:
db.query(Document).filter_by(name=name).delete() db.query(Document).filter_by(name=name).delete()
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
......
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import Union, Optional
import time import time
import logging import logging
...@@ -90,10 +90,10 @@ class FilesTable: ...@@ -90,10 +90,10 @@ class FilesTable:
try: try:
file = db.get(File, id) file = db.get(File, id)
return FileModel.model_validate(file) return FileModel.model_validate(file)
except: except Exception:
return None return None
def get_files(self) -> List[FileModel]: def get_files(self) -> list[FileModel]:
with get_db() as db: with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
...@@ -107,7 +107,7 @@ class FilesTable: ...@@ -107,7 +107,7 @@ class FilesTable:
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
def delete_all_files(self) -> bool: def delete_all_files(self) -> bool:
...@@ -119,7 +119,7 @@ class FilesTable: ...@@ -119,7 +119,7 @@ class FilesTable:
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
......
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import Union, Optional
import time import time
import logging import logging
...@@ -122,10 +122,10 @@ class FunctionsTable: ...@@ -122,10 +122,10 @@ class FunctionsTable:
function = db.get(Function, id) function = db.get(Function, id)
return FunctionModel.model_validate(function) return FunctionModel.model_validate(function)
except: except Exception:
return None return None
def get_functions(self, active_only=False) -> List[FunctionModel]: def get_functions(self, active_only=False) -> list[FunctionModel]:
with get_db() as db: with get_db() as db:
if active_only: if active_only:
...@@ -141,7 +141,7 @@ class FunctionsTable: ...@@ -141,7 +141,7 @@ class FunctionsTable:
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> List[FunctionModel]: ) -> list[FunctionModel]:
with get_db() as db: with get_db() as db:
if active_only: if active_only:
...@@ -157,7 +157,7 @@ class FunctionsTable: ...@@ -157,7 +157,7 @@ class FunctionsTable:
for function in db.query(Function).filter_by(type=type).all() for function in db.query(Function).filter_by(type=type).all()
] ]
def get_global_filter_functions(self) -> List[FunctionModel]: def get_global_filter_functions(self) -> list[FunctionModel]:
with get_db() as db: with get_db() as db:
return [ return [
...@@ -167,7 +167,7 @@ class FunctionsTable: ...@@ -167,7 +167,7 @@ class FunctionsTable:
.all() .all()
] ]
def get_global_action_functions(self) -> List[FunctionModel]: def get_global_action_functions(self) -> list[FunctionModel]:
with get_db() as db: with get_db() as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
...@@ -198,7 +198,7 @@ class FunctionsTable: ...@@ -198,7 +198,7 @@ class FunctionsTable:
db.commit() db.commit()
db.refresh(function) db.refresh(function)
return self.get_function_by_id(id) return self.get_function_by_id(id)
except: except Exception:
return None return None
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(
...@@ -256,7 +256,7 @@ class FunctionsTable: ...@@ -256,7 +256,7 @@ class FunctionsTable:
) )
db.commit() db.commit()
return self.get_function_by_id(id) return self.get_function_by_id(id)
except: except Exception:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self) -> Optional[bool]:
...@@ -271,7 +271,7 @@ class FunctionsTable: ...@@ -271,7 +271,7 @@ class FunctionsTable:
) )
db.commit() db.commit()
return True return True
except: except Exception:
return None return None
def delete_function_by_id(self, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
...@@ -281,7 +281,7 @@ class FunctionsTable: ...@@ -281,7 +281,7 @@ class FunctionsTable:
db.commit() db.commit()
return True return True
except: except Exception:
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment