Unverified commit 92c98eda, authored by Timothy Jaeryang Baek and committed by GitHub.

Merge pull request #1781 from open-webui/dev

0.1.122
parents 1092ee9c 85df019c
@@ -4,6 +4,7 @@ module.exports = {
    'eslint:recommended',
    'plugin:@typescript-eslint/recommended',
    'plugin:svelte/recommended',
    'plugin:cypress/recommended',
    'prettier'
  ],
  parser: '@typescript-eslint/parser',
...
@@ -2,14 +2,16 @@
- [ ] **Description:** Briefly describe the changes in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
-- [ ] **Documentation:** Have you updated relevant documentation?
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
- [ ] **Testing:** Have you written and run sufficient tests for the changes?
- [ ] **Code Review:** Have you self-reviewed your code and addressed any coding standard issues?

---

## Description

-[Insert a brief description of the changes made in this pull request]
[Insert a brief description of the changes made in this pull request, including any relevant motivation and impact.]

---

@@ -17,16 +19,32 @@
### Added

-- [List any new features or additions]
- [List any new features, functionalities, or additions]

### Fixed

-- [List any fixes or corrections]
- [List any fixes, corrections, or bug fixes]

### Changed

-- [List any changes or updates]
- [List any changes, updates, refactorings, or optimizations]

### Removed

-- [List any removed features or files]
- [List any removed features, files, or deprecated functionalities]

### Security

- [List any new or updated security-related changes, including vulnerability fixes]

### Breaking Changes

- [List any breaking changes affecting compatibility or functionality]

---

### Additional Information

- [Insert any additional context, notes, or explanations for the changes]
- [Reference any related issues, commits, or other relevant information]
name: Integration Test

on:
  push:
    branches:
      - main
      - dev
  pull_request:
    branches:
      - main
      - dev

jobs:
  cypress-run:
    name: Run Cypress Integration Tests
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4

      - name: Build and run Compose Stack
        run: |
          docker compose up --detach --build

      - name: Preload Ollama model
        run: |
          docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K

      - name: Cypress run
        uses: cypress-io/github-action@v6
        with:
          browser: chrome
          wait-on: 'http://localhost:3000'
          config: baseUrl=http://localhost:3000

      - uses: actions/upload-artifact@v4
        if: always()
        name: Upload Cypress videos
        with:
          name: cypress-videos
          path: cypress/videos
          if-no-files-found: ignore

      - name: Extract Compose logs
        if: always()
        run: |
          docker compose logs > compose-logs.txt

      - uses: actions/upload-artifact@v4
        if: always()
        name: Upload Compose logs
        with:
          name: compose-logs
          path: compose-logs.txt
          if-no-files-found: ignore
@@ -297,4 +297,8 @@ dist
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
-.pnp.*
\ No newline at end of file
.pnp.*

# cypress artifacts
cypress/videos
cypress/screenshots
@@ -5,6 +5,30 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.1.122] - 2024-04-27

### Added

- **🌟 Enhanced RAG Pipeline**: Now with hybrid searching via 'BM25', reranking powered by 'CrossEncoder', and configurable relevance score thresholds.
- **🛢️ External Database Support**: Seamlessly connect to custom SQLite or Postgres databases using the 'DATABASE_URL' environment variable.
- **🌐 Remote ChromaDB Support**: Introducing the capability to connect to remote ChromaDB servers.
- **👨‍💼 Improved Admin Panel**: Admins can now conveniently check users' chat lists and last active status directly from the admin panel.
- **🎨 Splash Screen**: Introducing a loading splash screen for a smoother user experience.
- **🌍 Language Support Expansion**: Added support for Bangla (bn-BD), along with enhancements to Chinese, Spanish, and Ukrainian translations.
- **💻 Improved LaTeX Rendering Performance**: Enjoy faster rendering times for LaTeX equations.
- **🔧 More Environment Variables**: Explore additional environment variables in our documentation (https://docs.openwebui.com), including the 'ENABLE_LITELLM' option to manage memory usage.

### Fixed

- **🔧 Ollama Compatibility**: Resolved errors occurring when the Ollama server version isn't an integer, such as SHA builds or RCs.
- **🐛 Various OpenAI API Issues**: Addressed several issues related to the OpenAI API.
- **🛑 Stop Sequence Issue**: Fixed the problem where a stop sequence containing a backslash '\' was not functioning.
- **🔤 Font Fallback**: Corrected the font fallback issue.

### Changed

- **⌨️ Prompt Input Behavior on Mobile**: Enter key prompt submission is disabled on mobile devices for an improved user experience.

## [0.1.121] - 2024-04-24

### Fixed

...
@@ -8,8 +8,9 @@ ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL=""

######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build

@@ -30,6 +31,7 @@ ARG USE_CUDA
ARG USE_OLLAMA
ARG USE_CUDA_VER
ARG USE_EMBEDDING_MODEL
ARG USE_RERANKING_MODEL

## Basis ##
ENV ENV=prod \

@@ -38,7 +40,8 @@ ENV ENV=prod \
    USE_OLLAMA_DOCKER=${USE_OLLAMA} \
    USE_CUDA_DOCKER=${USE_CUDA} \
    USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
-   USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL}
    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
    USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}

## Basis URL Config ##
ENV OLLAMA_BASE_URL="/ollama" \

@@ -62,8 +65,11 @@ ENV WHISPER_MODEL="base" \
## RAG Embedding model settings ##
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
-   RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" \
    RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
    SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"

## Hugging Face download cache ##
ENV HF_HOME="/app/backend/data/cache/embedding/models"

#### Other models ##########################################################
WORKDIR /app/backend
...
ifneq ($(shell which docker-compose 2>/dev/null),)
    DOCKER_COMPOSE := docker-compose
else
    DOCKER_COMPOSE := docker compose
endif

install:
-	@docker-compose up -d
	$(DOCKER_COMPOSE) up -d

remove:
	@chmod +x confirm_remove.sh
	@./confirm_remove.sh

start:
-	@docker-compose start
	$(DOCKER_COMPOSE) start

startAndBuild:
-	docker-compose up -d --build
	$(DOCKER_COMPOSE) up -d --build

stop:
-	@docker-compose stop
	$(DOCKER_COMPOSE) stop

update:
	# Calls the LLM update script
	chmod +x update_ollama_models.sh
	@./update_ollama_models.sh
	@git pull
-	@docker-compose down
	$(DOCKER_COMPOSE) down
	# Make sure the ollama-webui container is stopped before rebuilding
	@docker stop open-webui || true
-	@docker-compose up --build -d
	$(DOCKER_COMPOSE) up --build -d
-	@docker-compose start
	$(DOCKER_COMPOSE) start
@@ -32,11 +32,15 @@ import logging
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
    IMAGE_GENERATION_ENGINE,
    ENABLE_IMAGE_GENERATION,
    AUTOMATIC1111_BASE_URL,
    COMFYUI_BASE_URL,
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
    IMAGE_GENERATION_MODEL,
    IMAGE_SIZE,
    IMAGE_STEPS,
)

@@ -55,21 +59,21 @@ app.add_middleware(
    allow_headers=["*"],
)

-app.state.ENGINE = ""
app.state.ENGINE = IMAGE_GENERATION_ENGINE
app.state.ENABLED = ENABLE_IMAGE_GENERATION

app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

-app.state.MODEL = ""
app.state.MODEL = IMAGE_GENERATION_MODEL

app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL

-app.state.IMAGE_SIZE = "512x512"
app.state.IMAGE_SIZE = IMAGE_SIZE
-app.state.IMAGE_STEPS = 50
app.state.IMAGE_STEPS = IMAGE_STEPS

@app.get("/config")
...
@@ -21,12 +21,15 @@ from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES

import os

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

from config import (
-    MODEL_FILTER_ENABLED,
    ENABLE_LITELLM,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    DATA_DIR,
    LITELLM_PROXY_PORT,

@@ -57,11 +60,20 @@ LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
with open(LITELLM_CONFIG_DIR, "r") as file:
    litellm_config = yaml.safe_load(file)

app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config

# Global variable to store the subprocess reference
background_process = None

CONFLICT_ENV_VARS = [
    # Uvicorn uses PORT, so LiteLLM might use it as well
    "PORT",
    # LiteLLM uses DATABASE_URL for Prisma connections
    "DATABASE_URL",
]


async def run_background_process(command):
    global background_process

@@ -70,9 +82,11 @@ async def run_background_process(command):
    try:
        # Log the command to be executed
        log.info(f"Executing command: {command}")
        # Filter environment variables known to conflict with litellm
        env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
        # Execute the command and create a subprocess
        process = await asyncio.create_subprocess_exec(
-            *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )
        background_process = process
        log.info("Subprocess started successfully.")

@@ -130,7 +144,7 @@ async def startup_event():
    asyncio.create_task(start_litellm_background())

-app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

@@ -198,49 +212,56 @@ async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_use
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
-    while not background_process:
-        await asyncio.sleep(0.1)
-
-    url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
-    r = None
-    try:
-        r = requests.request(method="GET", url=f"{url}/models")
-        r.raise_for_status()
-        data = r.json()
-        if app.state.MODEL_FILTER_ENABLED:
-            if user and user.role == "user":
-                data["data"] = list(
-                    filter(
-                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
-                        data["data"],
-                    )
-                )
-        return data
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"External: {res['error']}"
-            except:
-                error_detail = f"External: {e}"
-        return {
-            "data": [
-                {
-                    "id": model["model_name"],
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "openai",
-                }
-                for model in app.state.CONFIG["model_list"]
-            ],
-            "object": "list",
-        }
    if app.state.ENABLE:
        while not background_process:
            await asyncio.sleep(0.1)

        url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/models")
            r.raise_for_status()
            data = r.json()

            if app.state.ENABLE_MODEL_FILTER:
                if user and user.role == "user":
                    data["data"] = list(
                        filter(
                            lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                            data["data"],
                        )
                    )

            return data
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']}"
                except:
                    error_detail = f"External: {e}"
            return {
                "data": [
                    {
                        "id": model["model_name"],
                        "object": "model",
                        "created": int(time.time()),
                        "owned_by": "openai",
                    }
                    for model in app.state.CONFIG["model_list"]
                ],
                "object": "list",
            }
    else:
        return {
            "data": [],
            "object": "list",
        }
...
@@ -16,6 +16,7 @@ from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict

import os
import re
import copy
import random
import requests

@@ -36,7 +37,7 @@ from utils.utils import decode_token, get_current_user, get_admin_user
from config import (
    SRC_LOG_LEVELS,
    OLLAMA_BASE_URLS,
-    MODEL_FILTER_ENABLED,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    UPLOAD_DIR,
)

@@ -55,7 +56,7 @@ app.add_middleware(
)

-app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS

@@ -168,7 +169,7 @@ async def get_ollama_tags(
    if url_idx == None:
        models = await get_all_models()

-        if app.state.MODEL_FILTER_ENABLED:
        if app.state.ENABLE_MODEL_FILTER:
            if user.role == "user":
                models["models"] = list(
                    filter(

@@ -216,7 +217,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
    if len(responses) > 0:
        lowest_version = min(
            responses,
-            key=lambda x: tuple(map(int, x["version"].split("-")[0].split("."))),
            key=lambda x: tuple(
                map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
            ),
        )

        return {"version": lowest_version["version"]}
...
@@ -24,7 +24,7 @@ from config import (
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
-    MODEL_FILTER_ENABLED,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
)
from typing import List, Optional

@@ -45,7 +45,7 @@ app.add_middleware(
    allow_headers=["*"],
)

-app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS

@@ -225,7 +225,7 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    if url_idx == None:
        models = await get_all_models()
-        if app.state.MODEL_FILTER_ENABLED:
        if app.state.ENABLE_MODEL_FILTER:
            if user.role == "user":
                models["data"] = list(
                    filter(
...
@@ -39,8 +39,6 @@ import json
import sentence_transformers

-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm

from apps.web.models.documents import (
    Documents,
    DocumentForm,

@@ -48,9 +46,12 @@ from apps.web.models.documents import (
)

from apps.rag.utils import (
-    query_embeddings_doc,
-    query_embeddings_collection,
-    generate_openai_embeddings,
    get_model_path,
    get_embedding_function,
    query_doc,
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
)

from utils.misc import (

@@ -60,13 +61,22 @@ from utils.misc import (
    extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user

from config import (
    SRC_LOG_LEVELS,
    UPLOAD_DIR,
    DOCS_DIR,
    RAG_TOP_K,
    RAG_RELEVANCE_THRESHOLD,
    RAG_EMBEDDING_ENGINE,
    RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    ENABLE_RAG_HYBRID_SEARCH,
    RAG_RERANKING_MODEL,
    PDF_EXTRACT_IMAGES,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
    DEVICE_TYPE,

@@ -83,31 +93,75 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()

-app.state.TOP_K = 4
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH

app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

-app.state.PDF_EXTRACT_IMAGES = False
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

-if app.state.RAG_EMBEDDING_ENGINE == "":
-    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-        app.state.RAG_EMBEDDING_MODEL,
-        device=DEVICE_TYPE,
-        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    )


def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
            get_model_path(embedding_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )
    else:
        app.state.sentence_transformer_ef = None


def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    if reranking_model:
        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
            get_model_path(reranking_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
        )
    else:
        app.state.sentence_transformer_rf = None


update_embedding_model(
    app.state.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    app.state.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)

app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.RAG_EMBEDDING_ENGINE,
    app.state.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.OPENAI_API_KEY,
    app.state.OPENAI_API_BASE_URL,
)

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,

@@ -134,6 +188,7 @@ async def get_status():
        "template": app.state.RAG_TEMPLATE,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "reranking_model": app.state.RAG_RERANKING_MODEL,
    }

@@ -150,6 +205,11 @@ async def get_embedding_config(user=Depends(get_admin_user)):
    }
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
class OpenAIConfigForm(BaseModel):
    url: str
    key: str

@@ -170,22 +230,22 @@ async def update_embedding_config(
    )
    try:
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-            app.state.sentence_transformer_ef = None
            if form_data.openai_config != None:
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.OPENAI_API_KEY = form_data.openai_config.key
-        else:
-            sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-                app.state.RAG_EMBEDDING_MODEL,
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-            )
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-            app.state.sentence_transformer_ef = sentence_transformer_ef

        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        app.state.EMBEDDING_FUNCTION = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        return {
            "status": True,

@@ -196,7 +256,6 @@ async def update_embedding_config(
                "key": app.state.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(

@@ -205,6 +264,34 @@ async def update_embedding_config(
        )


class RerankingModelUpdateForm(BaseModel):
    reranking_model: str


@app.post("/reranking/update")
async def update_reranking_config(
    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    log.info(
        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )
    try:
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model

        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
            "reranking_model": app.state.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
@app.get("/config") @app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)): async def get_rag_config(user=Depends(get_admin_user)):
return { return {
...@@ -257,12 +344,16 @@ async def get_query_settings(user=Depends(get_admin_user)): ...@@ -257,12 +344,16 @@ async def get_query_settings(user=Depends(get_admin_user)):
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K, "k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
} }
class QuerySettingsForm(BaseModel): class QuerySettingsForm(BaseModel):
k: Optional[int] = None k: Optional[int] = None
r: Optional[float] = None
template: Optional[str] = None template: Optional[str] = None
hybrid: Optional[bool] = None
@app.post("/query/settings/update") @app.post("/query/settings/update")
...@@ -271,13 +362,23 @@ async def update_query_settings( ...@@ -271,13 +362,23 @@ async def update_query_settings(
): ):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4 app.state.TOP_K = form_data.k if form_data.k else 4
return {"status": True, "template": app.state.RAG_TEMPLATE} app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
}
class QueryDocForm(BaseModel): class QueryDocForm(BaseModel):
collection_name: str collection_name: str
query: str query: str
k: Optional[int] = None k: Optional[int] = None
r: Optional[float] = None
hybrid: Optional[bool] = None
@app.post("/query/doc") @app.post("/query/doc")
...@@ -286,34 +387,22 @@ def query_doc_handler( ...@@ -286,34 +387,22 @@ def query_doc_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            query_embeddings = app.state.sentence_transformer_ef.encode(
-                form_data.query
-            ).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            query_embeddings = generate_openai_embeddings(
-                model=app.state.RAG_EMBEDDING_MODEL,
-                text=form_data.query,
-                key=app.state.OPENAI_API_KEY,
-                url=app.state.OPENAI_API_BASE_URL,
-            )
-
-        return query_embeddings_doc(
-            collection_name=form_data.collection_name,
-            query=form_data.query,
-            query_embeddings=query_embeddings,
-            k=form_data.k if form_data.k else app.state.TOP_K,
-        )
        if app.state.ENABLE_RAG_HYBRID_SEARCH:
            return query_doc_with_hybrid_search(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                reranking_function=app.state.sentence_transformer_rf,
                k=form_data.k if form_data.k else app.state.TOP_K,
                r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
            )
        else:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.TOP_K,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(

@@ -326,6 +415,8 @@ class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    k: Optional[int] = None
    r: Optional[float] = None
    hybrid: Optional[bool] = None


@app.post("/query/collection")

@@ -334,33 +425,23 @@ def query_collection_handler(
    user=Depends(get_current_user),
):
    try:
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            query_embeddings = app.state.sentence_transformer_ef.encode(
-                form_data.query
-            ).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            query_embeddings = generate_openai_embeddings(
-                model=app.state.RAG_EMBEDDING_MODEL,
-                text=form_data.query,
-                key=app.state.OPENAI_API_KEY,
-                url=app.state.OPENAI_API_BASE_URL,
-            )
-
-        return query_embeddings_collection(
-            collection_names=form_data.collection_names,
-            query_embeddings=query_embeddings,
-            k=form_data.k if form_data.k else app.state.TOP_K,
-        )
        if app.state.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection_with_hybrid_search(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                reranking_function=app.state.sentence_transformer_rf,
                k=form_data.k if form_data.k else app.state.TOP_K,
                r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
                k=form_data.k if form_data.k else app.state.TOP_K,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(

@@ -427,8 +508,6 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
-    texts = list(map(lambda x: x.replace("\n", " "), texts))
    metadatas = [doc.metadata for doc in docs]

    try:

@@ -440,27 +519,16 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
        collection = CHROMA_CLIENT.create_collection(name=collection_name)

-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            embeddings = [
-                generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
-                    )
-                )
-                for text in texts
-            ]
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            embeddings = [
-                generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=text,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
-                for text in texts
-            ]
        embedding_func = get_embedding_function(
            app.state.RAG_EMBEDDING_ENGINE,
            app.state.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.OPENAI_API_KEY,
            app.state.OPENAI_API_BASE_URL,
        )

        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
        embeddings = embedding_func(embedding_texts)

        for batch in create_batches(
            api=CHROMA_CLIENT,
...
import os
import logging
import requests

@@ -8,6 +9,16 @@ from apps.ollama.main import (
    GenerateEmbeddingsForm,
)

from huggingface_hub import snapshot_download

from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import (
    ContextualCompressionRetriever,
    EnsembleRetriever,
)

from typing import Optional

from config import SRC_LOG_LEVELS, CHROMA_CLIENT

@@ -15,88 +26,164 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


-def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
-    try:
-        # if you use docker use the model from the environment variable
-        log.info(f"query_embeddings_doc {query_embeddings}")
-        collection = CHROMA_CLIENT.get_collection(name=collection_name)
def query_doc(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
):
    try:
        collection = CHROMA_CLIENT.get_collection(name=collection_name)
        query_embeddings = embedding_function(query)

        result = collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )

-        log.info(f"query_embeddings_doc:result {result}")
        log.info(f"query_doc:result {result}")
        return result
    except Exception as e:
        raise e
def query_doc_with_hybrid_search(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    try:
        collection = CHROMA_CLIENT.get_collection(name=collection_name)
        documents = collection.get()  # get all documents

        bm25_retriever = BM25Retriever.from_texts(
            texts=documents.get("documents"),
            metadatas=documents.get("metadatas"),
        )
        bm25_retriever.k = k

        chroma_retriever = ChromaRetriever(
            collection=collection,
            embedding_function=embedding_function,
            top_n=k,
        )

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
        )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            reranking_function=reranking_function,
            r_score=r,
            top_n=k,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)
        result = {
            "distances": [[d.metadata.get("score") for d in result]],
            "documents": [[d.page_content for d in result]],
            "metadatas": [[d.metadata for d in result]],
        }

        log.info(f"query_doc_with_hybrid_search:result {result}")
        return result
    except Exception as e:
        raise e
-def merge_and_sort_query_results(query_results, k):
def merge_and_sort_query_results(query_results, k, reverse=False):
    # Initialize lists to store combined data
-    combined_ids = []
    combined_distances = []
-    combined_metadatas = []
    combined_documents = []
    combined_metadatas = []

-    # Combine data from each dictionary
    for data in query_results:
-        combined_ids.extend(data["ids"][0])
        combined_distances.extend(data["distances"][0])
-        combined_metadatas.extend(data["metadatas"][0])
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])

-    # Create a list of tuples (distance, id, metadata, document)
-    combined = list(
-        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
-    )
    # Create a list of tuples (distance, document, metadata)
    combined = list(zip(combined_distances, combined_documents, combined_metadatas))

    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0])
    combined.sort(key=lambda x: x[0], reverse=reverse)

-    # Unzip the sorted list
-    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
    # We don't have anything :-(
    if not combined:
        sorted_distances = []
        sorted_documents = []
        sorted_metadatas = []
    else:
        # Unzip the sorted list
        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)

-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_ids = list(sorted_ids)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
-    sorted_documents = list(sorted_documents)[:k]
        # Slicing the lists to include only k elements
        sorted_distances = list(sorted_distances)[:k]
        sorted_documents = list(sorted_documents)[:k]
        sorted_metadatas = list(sorted_metadatas)[:k]

    # Create the output dictionary
-    merged_query_results = {
-        "ids": [sorted_ids],
-        "distances": [sorted_distances],
-        "metadatas": [sorted_metadatas],
-        "documents": [sorted_documents],
-        "embeddings": None,
-        "uris": None,
-        "data": None,
-    }
    result = {
        "distances": [sorted_distances],
        "documents": [sorted_documents],
        "metadatas": [sorted_metadatas],
    }

-    return merged_query_results
    return result
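For concreteness, a small illustrative call (values invented): merging two single-hit result dicts with k=1 keeps only the closest match, since plain vector search sorts by ascending distance.

    results = [
        {"distances": [[0.4]], "documents": [["doc A"]], "metadatas": [[{}]]},
        {"distances": [[0.1]], "documents": [["doc B"]], "metadatas": [[{}]]},
    ]
    merge_and_sort_query_results(results, k=1)
    # -> {"distances": [[0.1]], "documents": [["doc B"]], "metadatas": [[{}]]}
    # With reverse=True (used for hybrid-search scores), higher values sort first.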
-def query_embeddings_collection(
-    collection_names: List[str], query: str, query_embeddings, k: int
-):
-    results = []
-    log.info(f"query_embeddings_collection {query_embeddings}")
-
-    for collection_name in collection_names:
-        try:
-            result = query_embeddings_doc(
-                collection_name=collection_name,
-                query=query,
-                query_embeddings=query_embeddings,
-                k=k,
-            )
-            results.append(result)
-        except:
-            pass
-
-    return merge_and_sort_query_results(results, k)
def query_collection(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
):
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc(
                collection_name=collection_name,
                query=query,
                k=k,
                embedding_function=embedding_function,
            )
            results.append(result)
        except:
            pass
    return merge_and_sort_query_results(results, k=k)


def query_collection_with_hybrid_search(
    collection_names: List[str],
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
):
    results = []
    for collection_name in collection_names:
        try:
            result = query_doc_with_hybrid_search(
                collection_name=collection_name,
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                r=r,
            )
            results.append(result)
        except:
            pass
    return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):

@@ -105,20 +192,53 @@ def rag_template(template: str, context: str, query: str):
    return template
-def rag_messages(
-    docs,
-    messages,
-    template,
-    k,
-    embedding_engine,
-    embedding_model,
-    embedding_function,
-    openai_key,
-    openai_url,
-):
-    log.debug(
-        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
-    )
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
):
    if embedding_engine == "":
        return lambda query: embedding_function.encode(query).tolist()
    elif embedding_engine in ["ollama", "openai"]:
        if embedding_engine == "ollama":
            func = lambda query: generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": embedding_model,
                        "prompt": query,
                    }
                )
            )
        elif embedding_engine == "openai":
            func = lambda query: generate_openai_embeddings(
                model=embedding_model,
                text=query,
                key=openai_key,
                url=openai_url,
            )

        def generate_multiple(query, f):
            if isinstance(query, list):
                return [f(q) for q in query]
            else:
                return f(query)

        return lambda query: generate_multiple(query, func)
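A usage sketch for the factory above (model name and inputs illustrative): the returned callable accepts a single string or a list of strings; the "" engine relies on SentenceTransformer.encode for both, while generate_multiple loops per item for the "ollama" and "openai" engines.

    ef = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2")
    embed = get_embedding_function("", "all-MiniLM-L6-v2", ef, None, None)
    single = embed("hello world")   # one embedding vector
    batch = embed(["foo", "bar"])   # list of embedding vectors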
def rag_messages(
    docs,
    messages,
    template,
    embedding_function,
    k,
    reranking_function,
    r,
    hybrid_search,
):
    log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
    last_user_message_idx = None
    for i in range(len(messages) - 1, -1, -1):

@@ -145,62 +265,65 @@ def rag_messages(
    content_type = None
    query = ""

    extracted_collections = []
    relevant_contexts = []

    for doc in docs:
        context = None

        collection = doc.get("collection_name")
        if collection:
            collection = [collection]
        else:
            collection = doc.get("collection_names", [])

        collection = set(collection).difference(extracted_collections)
        if not collection:
            log.debug(f"skipping {doc} as it has already been extracted")
            continue

        try:
            if doc["type"] == "text":
                context = doc["content"]
            else:
-                if embedding_engine == "":
-                    query_embeddings = embedding_function.encode(query).tolist()
-                elif embedding_engine == "ollama":
-                    query_embeddings = generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{
-                                "model": embedding_model,
-                                "prompt": query,
-                            }
-                        )
-                    )
-                elif embedding_engine == "openai":
-                    query_embeddings = generate_openai_embeddings(
-                        model=embedding_model,
-                        text=query,
-                        key=openai_key,
-                        url=openai_url,
-                    )
-
-                if doc["type"] == "collection":
-                    context = query_embeddings_collection(
-                        collection_names=doc["collection_names"],
-                        query=query,
-                        query_embeddings=query_embeddings,
-                        k=k,
-                    )
-                else:
-                    context = query_embeddings_doc(
-                        collection_name=doc["collection_name"],
-                        query=query,
-                        query_embeddings=query_embeddings,
-                        k=k,
-                    )
                if hybrid_search:
                    context = query_collection_with_hybrid_search(
                        collection_names=(
                            doc["collection_names"]
                            if doc["type"] == "collection"
                            else [doc["collection_name"]]
                        ),
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                        reranking_function=reranking_function,
                        r=r,
                    )
                else:
                    context = query_collection(
                        collection_names=(
                            doc["collection_names"]
                            if doc["type"] == "collection"
                            else [doc["collection_name"]]
                        ),
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
                    )
        except Exception as e:
            log.exception(e)
            context = None

-        relevant_contexts.append(context)
        if context:
            relevant_contexts.append(context)

        extracted_collections.extend(collection)

-    log.debug(f"relevant_contexts: {relevant_contexts}")
-
    context_string = ""
    for context in relevant_contexts:
-        if context:
-            context_string += " ".join(context["documents"][0]) + "\n"
        items = context["documents"][0]
        context_string += "\n\n".join(items)
    context_string = context_string.strip()

    ra_content = rag_template(
        template=template,

@@ -208,6 +331,8 @@ def rag_messages(
        query=query,
    )

    log.debug(f"ra_content: {ra_content}")

    if content_type == "list":
        new_content = []
        for content_item in user_message["content"]:

@@ -229,6 +354,44 @@ def rag_messages(
    return messages
def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    if (
        os.path.exists(model)
        or ("\\" in model or model.count("/") > 1)
        and local_files_only
    ):
        # If fully qualified path exists, return input, else set repo_id
        return model
    elif "/" not in model:
        # Set valid repo_id for model short-name
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
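Illustrative resolutions (paths hypothetical): a bare short name is expanded into the sentence-transformers namespace, a qualified repo id is passed to snapshot_download, and an existing local path is returned unchanged.

    get_model_path("all-MiniLM-L6-v2")
    # -> local snapshot path for "sentence-transformers/all-MiniLM-L6-v2"
    get_model_path("BAAI/bge-reranker-base", update_model=True)
    # -> downloads or refreshes the snapshot, then returns its local path
    get_model_path("/app/backend/data/cache/embedding/models/my-model")
    # -> returned as-is when the path exists on disk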
def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):

@@ -250,3 +413,99 @@ def generate_openai_embeddings(
    except Exception as e:
        print(e)
        return None
from typing import Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
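ChromaRetriever below adapts a Chroma collection to LangChain's BaseRetriever interface, so it can serve as the dense half of the BM25-plus-vector EnsembleRetriever built in query_doc_with_hybrid_search above.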
class ChromaRetriever(BaseRetriever):
    collection: Any
    embedding_function: Any
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        query_embeddings = self.embedding_function(query)

        results = self.collection.query(
            query_embeddings=[query_embeddings],
            n_results=self.top_n,
        )

        ids = results["ids"][0]
        metadatas = results["metadatas"][0]
        documents = results["documents"][0]

        return [
            Document(
                metadata=metadatas[idx],
                page_content=documents[idx],
            )
            for idx in range(len(ids))
        ]
import operator

from typing import Optional, Sequence

from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra

from sentence_transformers import util


class RerankCompressor(BaseDocumentCompressor):
    embedding_function: Any
    reranking_function: Any
    r_score: float
    top_n: int

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
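        # Score each (query, document) pair with the CrossEncoder when a
        # reranking model is loaded; otherwise fall back to cosine similarity
        # between the query embedding and the document embeddings.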
        if self.reranking_function:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            query_embedding = self.embedding_function(query)
            document_embedding = self.embedding_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(zip(documents, scores.tolist()))
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        reverse = self.reranking_function is not None
        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)

        final_results = []
        for doc, doc_score in result[: self.top_n]:
            metadata = doc.metadata
            metadata["score"] = doc_score
            doc = Document(
                page_content=doc.page_content,
                metadata=metadata,
            )
            final_results.append(doc)
        return final_results
from peewee import *
from peewee_migrate import Router
from playhouse.db_url import connect
-from config import SRC_LOG_LEVELS, DATA_DIR
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
import os
import logging

@@ -11,12 +12,12 @@ log.setLevel(SRC_LOG_LEVELS["DB"])
if os.path.exists(f"{DATA_DIR}/ollama.db"):
    # Rename the file
    os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
-    log.info("File renamed successfully.")
    log.info("Database migrated from Ollama-WebUI successfully.")
else:
    pass

-DB = SqliteDatabase(f"{DATA_DIR}/webui.db")
DB = connect(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
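For reference, playhouse.db_url.connect resolves the URL scheme to a Peewee database class; example values (illustrative):

    # DATABASE_URL="sqlite:///data/webui.db"            -> SqliteDatabase
    # DATABASE_URL="postgres://user:pass@db:5432/webui" -> PostgresqlDatabase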
router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
router.run()
DB.connect(reuse_if_open=True)
@@ -37,6 +37,18 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    # We perform different migrations for SQLite and other databases.
    # This is because SQLite is very loose about enforcing its schema, and
    # migrating other databases the way we migrate SQLite would require
    # per-database SQL queries.
    # Instead, because external DB support was added at a later date, it is
    # safe to assume a newer base schema instead of migrating from an older one.
    if isinstance(database, pw.SqliteDatabase):
        migrate_sqlite(migrator, database, fake=fake)
    else:
        migrate_external(migrator, database, fake=fake)


def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
    @migrator.create_model
    class Auth(pw.Model):
        id = pw.CharField(max_length=255, unique=True)

@@ -129,6 +141,99 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
            table_name = "user"
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
    @migrator.create_model
    class Auth(pw.Model):
        id = pw.CharField(max_length=255, unique=True)
        email = pw.CharField(max_length=255)
        password = pw.TextField()
        active = pw.BooleanField()

        class Meta:
            table_name = "auth"

    @migrator.create_model
    class Chat(pw.Model):
        id = pw.CharField(max_length=255, unique=True)
        user_id = pw.CharField(max_length=255)
        title = pw.TextField()
        chat = pw.TextField()
        timestamp = pw.BigIntegerField()

        class Meta:
            table_name = "chat"

    @migrator.create_model
    class ChatIdTag(pw.Model):
        id = pw.CharField(max_length=255, unique=True)
        tag_name = pw.CharField(max_length=255)
        chat_id = pw.CharField(max_length=255)
        user_id = pw.CharField(max_length=255)
        timestamp = pw.BigIntegerField()

        class Meta:
            table_name = "chatidtag"

    @migrator.create_model
    class Document(pw.Model):
        id = pw.AutoField()
        collection_name = pw.CharField(max_length=255, unique=True)
        name = pw.CharField(max_length=255, unique=True)
        title = pw.TextField()
        filename = pw.TextField()
        content = pw.TextField(null=True)
        user_id = pw.CharField(max_length=255)
        timestamp = pw.BigIntegerField()

        class Meta:
            table_name = "document"

    @migrator.create_model
    class Modelfile(pw.Model):
        id = pw.AutoField()
        tag_name = pw.CharField(max_length=255, unique=True)
        user_id = pw.CharField(max_length=255)
        modelfile = pw.TextField()
        timestamp = pw.BigIntegerField()

        class Meta:
            table_name = "modelfile"

    @migrator.create_model
    class Prompt(pw.Model):
        id = pw.AutoField()
        command = pw.CharField(max_length=255, unique=True)
        user_id = pw.CharField(max_length=255)
        title = pw.TextField()
        content = pw.TextField()
        timestamp = pw.BigIntegerField()

        class Meta:
            table_name = "prompt"

    @migrator.create_model
    class Tag(pw.Model):
        id = pw.CharField(max_length=255, unique=True)
        name = pw.CharField(max_length=255)
        user_id = pw.CharField(max_length=255)
        data = pw.TextField(null=True)

        class Meta:
            table_name = "tag"

    @migrator.create_model
    class User(pw.Model):
        id = pw.CharField(max_length=255, unique=True)
        name = pw.CharField(max_length=255)
        email = pw.CharField(max_length=255)
        role = pw.CharField(max_length=255)
        profile_image_url = pw.TextField()
        timestamp = pw.BigIntegerField()

        class Meta:
            table_name = "user"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""
...
...@@ -37,6 +37,13 @@ with suppress(ImportError): ...@@ -37,6 +37,13 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table # Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields( migrator.add_fields(
"chat", "chat",
...@@ -60,9 +67,40 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): ...@@ -60,9 +67,40 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
) )
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
rollback_sqlite(migrator, database, fake=fake)
else:
rollback_external(migrator, database, fake=fake)
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
    # Recreate the timestamp field initially allowing null values for safe transition
    migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
...
@@ -75,3 +113,18 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    # Finally, alter the timestamp field to not allow nulls if that was the original setting
    migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
    # Copy created_at back into the recreated timestamp field
    # (valid because created_at was originally populated from timestamp)
migrator.sql("UPDATE chat SET timestamp = created_at")
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"document",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"user",
timestamp=pw.BigIntegerField(),
)
# Alter the tables with varchar to text where necessary
migrator.change_fields(
"auth",
password=pw.TextField(),
)
migrator.change_fields(
"chat",
title=pw.TextField(),
)
migrator.change_fields(
"document",
title=pw.TextField(),
filename=pw.TextField(),
)
migrator.change_fields(
"prompt",
title=pw.TextField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.TextField(),
)
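After this migration every timestamp column is an epoch integer rather than an engine-native date type, which keeps stored values identical across SQLite and external databases. A small sketch of the convention application code would then follow (the helper names are illustrative, not from the diff):

import time
from datetime import datetime, timezone

def now_epoch() -> int:
    # What gets written into the BigIntegerField timestamp columns.
    return int(time.time())

def epoch_to_datetime(ts: int) -> datetime:
    # For display: convert a stored epoch value back to an aware datetime.
    return datetime.fromtimestamp(ts, tz=timezone.utc)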
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.DateField(),
)
migrator.change_fields(
"document",
timestamp=pw.DateField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.DateField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.DateField(),
)
migrator.change_fields(
"user",
timestamp=pw.DateField(),
)
migrator.change_fields(
"auth",
password=pw.CharField(max_length=255),
)
migrator.change_fields(
"chat",
title=pw.CharField(),
)
migrator.change_fields(
"document",
title=pw.CharField(),
filename=pw.CharField(),
)
migrator.change_fields(
"prompt",
title=pw.CharField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.CharField(),
)
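This rollback narrows the TEXT columns back to VARCHAR and, on SQLite only, reverts the timestamp columns to DateField. Narrowing a TEXT column can truncate or fail once longer values have been written, so a pre-check is prudent before running it. A sketch of such a check, assuming direct access to the peewee connection (the helper itself is an addition, not part of the migration):

import peewee as pw

def would_truncate(
    database: pw.Database, table: str, column: str, limit: int = 255
) -> bool:
    # True if any value is too long to survive narrowing to VARCHAR(limit).
    cursor = database.execute_sql(
        f"SELECT COUNT(*) FROM {table} WHERE LENGTH({column}) > {limit}"
    )
    (count,) = cursor.fetchone()
    return count > 0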
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
    # Adding fields created_at, updated_at, and last_active_at to the 'user' table
migrator.add_fields(
"user",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE user SET created_at = timestamp, updated_at = timestamp, last_active_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("user", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"user",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
last_active_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
    # Copy created_at back into the recreated timestamp field
    # (valid because created_at was originally populated from timestamp)
migrator.sql("UPDATE user SET timestamp = created_at")
    # Remove the created_at, updated_at, and last_active_at fields
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))
...
@@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Auth(Model):
    id = CharField(unique=True)
    email = CharField()
    password = TextField()
    active = BooleanField()

    class Meta:
...
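Widening auth.password from CharField(255) to TextField removes the length ceiling on stored hashes, so a future change of hashing scheme never needs another schema change. A quick illustration of why hash length varies by scheme (using passlib here is an assumption for the sketch, not confirmed by the diff):

# Hedged sketch: hash lengths differ by algorithm and parameters, which is
# the usual reason to store them in an unbounded TEXT column.
from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

hashed = pwd_context.hash("correct horse battery staple")
print(len(hashed))  # ~60 chars for bcrypt; other schemes can be much longer
assert pwd_context.verify("correct horse battery staple", hashed)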
...
@@ -17,11 +17,11 @@ from apps.web.internal.db import DB
class Chat(Model):
    id = CharField(unique=True)
    user_id = CharField()
    title = TextField()
    chat = TextField()  # Save Chat JSON as Text
    created_at = BigIntegerField()
    updated_at = BigIntegerField()
    share_id = CharField(null=True, unique=True)
    archived = BooleanField(default=False)
...
@@ -191,7 +191,7 @@ class ChatTable:
        except:
            return None

    def get_archived_chat_list_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        return [
...
@@ -204,7 +204,7 @@ class ChatTable:
            # .offset(skip)
        ]

    def get_chat_list_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        return [
...
@@ -217,7 +217,7 @@ class ChatTable:
            # .offset(skip)
        ]

    def get_chat_list_by_chat_ids(
        self, chat_ids: List[str], skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        return [
...
@@ -228,20 +228,6 @@ class ChatTable:
            .order_by(Chat.updated_at.desc())
        ]
def get_all_chats(self) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select().order_by(Chat.updated_at.desc())
]
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
]
    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
        try:
            chat = Chat.get(Chat.id == id)
...
@@ -271,9 +257,28 @@ class ChatTable:
    def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
        return [
            ChatModel(**model_to_dict(chat))
            for chat in Chat.select().order_by(Chat.updated_at.desc())
            # .limit(limit).offset(skip)
        ]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
query.execute() # Remove the rows, return number of rows removed.
            return self.delete_shared_chat_by_chat_id(id)
except:
return False
    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
        try:
            query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
...
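The renamed get_chat_list_* methods keep their skip/limit parameters, but the .limit()/.offset() calls are commented out in this diff, so each currently returns the user's entire chat list ordered by updated_at. If pagination were re-enabled, it would look roughly like this (a sketch of the commented-out behavior, not part of the change; Chat, ChatModel, and model_to_dict are the identifiers already used above):

    def get_chat_list_by_user_id_paginated(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        # Same query as get_chat_list_by_user_id, with paging applied.
        return [
            ChatModel(**model_to_dict(chat))
            for chat in Chat.select()
            .where(Chat.user_id == user_id)
            .order_by(Chat.updated_at.desc())
            .limit(limit)
            .offset(skip)
        ]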