Unverified Commit 38ff3209 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #1881 from open-webui/dev

0.1.123
parents c9589e21 2789102d
...@@ -10,7 +10,8 @@ OPENAI_API_KEY='' ...@@ -10,7 +10,8 @@ OPENAI_API_KEY=''
# DO NOT TRACK # DO NOT TRACK
SCARF_NO_ANALYTICS=true SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true DO_NOT_TRACK=true
ANONYMIZED_TELEMETRY=false
# Use locally bundled version of the LiteLLM cost map json # Use locally bundled version of the LiteLLM cost map json
# to avoid repetitive startup connections # to avoid repetitive startup connections
LITELLM_LOCAL_MODEL_COST_MAP="True" LITELLM_LOCAL_MODEL_COST_MAP="True"
\ No newline at end of file
version: 2
updates:
- package-ecosystem: pip
directory: "/backend"
schedule:
interval: daily
time: "13:00"
groups:
python-packages:
patterns:
- "*"
...@@ -29,6 +29,9 @@ jobs: ...@@ -29,6 +29,9 @@ jobs:
- name: Format Frontend - name: Format Frontend
run: npm run format run: npm run format
- name: Run i18next
run: npm run i18n:parse
- name: Check for Changes After Format - name: Check for Changes After Format
run: git diff --exit-code run: git diff --exit-code
......
...@@ -53,3 +53,134 @@ jobs: ...@@ -53,3 +53,134 @@ jobs:
name: compose-logs name: compose-logs
path: compose-logs.txt path: compose-logs.txt
if-no-files-found: ignore if-no-files-found: ignore
migration_test:
name: Run Migration Tests
runs-on: ubuntu-latest
services:
postgres:
image: postgres
env:
POSTGRES_PASSWORD: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
# mysql:
# image: mysql
# env:
# MYSQL_ROOT_PASSWORD: mysql
# MYSQL_DATABASE: mysql
# options: >-
# --health-cmd "mysqladmin ping -h localhost"
# --health-interval 10s
# --health-timeout 5s
# --health-retries 5
# ports:
# - 3306:3306
steps:
- name: Checkout Repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Set up uv
uses: yezz123/setup-uv@v4
with:
uv-venv: venv
- name: Activate virtualenv
run: |
. venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: Install dependencies
run: |
uv pip install -r backend/requirements.txt
- name: Test backend with SQLite
id: sqlite
env:
WEBUI_SECRET_KEY: secret-key
GLOBAL_LOG_LEVEL: debug
run: |
cd backend
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
UVICORN_PID=$!
# Wait up to 20 seconds for the server to start
for i in {1..20}; do
curl -s http://localhost:8080/api/config > /dev/null && break
sleep 1
if [ $i -eq 20 ]; then
echo "Server failed to start"
kill -9 $UVICORN_PID
exit 1
fi
done
# Check that the server is still running after 5 seconds
sleep 5
if ! kill -0 $UVICORN_PID; then
echo "Server has stopped"
exit 1
fi
- name: Test backend with Postgres
if: success() || steps.sqlite.conclusion == 'failure'
env:
WEBUI_SECRET_KEY: secret-key
GLOBAL_LOG_LEVEL: debug
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/postgres
run: |
cd backend
uvicorn main:app --port "8081" --forwarded-allow-ips '*' &
UVICORN_PID=$!
# Wait up to 20 seconds for the server to start
for i in {1..20}; do
curl -s http://localhost:8081/api/config > /dev/null && break
sleep 1
if [ $i -eq 20 ]; then
echo "Server failed to start"
kill -9 $UVICORN_PID
exit 1
fi
done
# Check that the server is still running after 5 seconds
sleep 5
if ! kill -0 $UVICORN_PID; then
echo "Server has stopped"
exit 1
fi
# - name: Test backend with MySQL
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
# env:
# WEBUI_SECRET_KEY: secret-key
# GLOBAL_LOG_LEVEL: debug
# DATABASE_URL: mysql://root:mysql@localhost:3306/mysql
# run: |
# cd backend
# uvicorn main:app --port "8083" --forwarded-allow-ips '*' &
# UVICORN_PID=$!
# # Wait up to 20 seconds for the server to start
# for i in {1..20}; do
# curl -s http://localhost:8083/api/config > /dev/null && break
# sleep 1
# if [ $i -eq 20 ]; then
# echo "Server failed to start"
# kill -9 $UVICORN_PID
# exit 1
# fi
# done
# # Check that the server is still running after 5 seconds
# sleep 5
# if ! kill -0 $UVICORN_PID; then
# echo "Server has stopped"
# exit 1
# fi
...@@ -5,6 +5,32 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,32 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.123] - 2024-05-02
### Added
- **🎨 New Landing Page Design**: Refreshed design for a more modern look and optimized use of screen space.
- **📹 Youtube RAG Pipeline**: Introduces dedicated RAG pipeline for Youtube videos, enabling interaction with video transcriptions directly.
- **🔧 Enhanced Admin Panel**: Streamlined user management with options to add users directly or in bulk via CSV import.
- **👥 '@' Model Integration**: Easily switch to specific models during conversations; old collaborative chat feature phased out.
- **🌐 Language Enhancements**: Swedish translation added, plus improvements to German, Spanish, and the addition of Doge translation.
### Fixed
- **🗑️ Delete Chat Shortcut**: Addressed issue where shortcut wasn't functioning.
- **🖼️ Modal Closing Bug**: Resolved unexpected closure of modal when dragging from within.
- **✏️ Edit Button Styling**: Fixed styling inconsistency with edit buttons.
- **🌐 Image Generation Compatibility Issue**: Rectified image generation compatibility issue with third-party APIs.
- **📱 iOS PWA Icon Fix**: Corrected iOS PWA home screen icon shape.
- **🔍 Scroll Gesture Bug**: Adjusted gesture sensitivity to prevent accidental activation when scrolling through code on mobile; now requires scrolling from the leftmost side to open the sidebar.
### Changed
- **🔄 Unlimited Context Length**: Advanced settings now allow unlimited max context length (previously limited to 16000).
- **👑 Super Admin Assignment**: The first signup is automatically assigned a super admin role, unchangeable by other admins.
- **🛡️ Admin User Restrictions**: User action buttons from the admin panel are now disabled for users with admin roles.
- **🔝 Default Model Selector**: Set as default model option now exclusively available on the landing page.
## [0.1.122] - 2024-04-27 ## [0.1.122] - 2024-04-27
### Added ### Added
......
...@@ -51,7 +51,8 @@ ENV OLLAMA_BASE_URL="/ollama" \ ...@@ -51,7 +51,8 @@ ENV OLLAMA_BASE_URL="/ollama" \
ENV OPENAI_API_KEY="" \ ENV OPENAI_API_KEY="" \
WEBUI_SECRET_KEY="" \ WEBUI_SECRET_KEY="" \
SCARF_NO_ANALYTICS=true \ SCARF_NO_ANALYTICS=true \
DO_NOT_TRACK=true DO_NOT_TRACK=true \
ANONYMIZED_TELEMETRY=false
# Use locally bundled version of the LiteLLM cost map json # Use locally bundled version of the LiteLLM cost map json
# to avoid repetitive startup connections # to avoid repetitive startup connections
...@@ -74,6 +75,10 @@ ENV HF_HOME="/app/backend/data/cache/embedding/models" ...@@ -74,6 +75,10 @@ ENV HF_HOME="/app/backend/data/cache/embedding/models"
WORKDIR /app/backend WORKDIR /app/backend
ENV HOME /root
RUN mkdir -p $HOME/.cache/chroma
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
RUN if [ "$USE_OLLAMA" = "true" ]; then \ RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \ apt-get update && \
# Install pandoc and netcat # Install pandoc and netcat
...@@ -129,4 +134,4 @@ COPY ./backend . ...@@ -129,4 +134,4 @@ COPY ./backend .
EXPOSE 8080 EXPOSE 8080
CMD [ "bash", "start.sh"] CMD [ "bash", "start.sh"]
\ No newline at end of file
...@@ -24,6 +24,7 @@ from utils.misc import calculate_sha256 ...@@ -24,6 +24,7 @@ from utils.misc import calculate_sha256
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from pathlib import Path from pathlib import Path
import mimetypes
import uuid import uuid
import base64 import base64
import json import json
...@@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel): ...@@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel):
def save_b64_image(b64_str): def save_b64_image(b64_str):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
# Split the base64 string to get the actual image data header, encoded = b64_str.split(",", 1)
img_data = base64.b64decode(b64_str) mime_type = header.split(";")[0]
img_data = base64.b64decode(encoded)
# Write the image data to a file image_id = str(uuid.uuid4())
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(img_data) f.write(img_data)
return image_filename
return image_id
except Exception as e: except Exception as e:
log.error(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None
def save_url_image(url): def save_url_image(url):
image_id = str(uuid.uuid4()) image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
r = requests.get(url) r = requests.get(url)
r.raise_for_status() r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
with open(file_path, "wb") as image_file: file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
image_file.write(r.content) with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_id, image_format
else:
log.error(f"Url does not point to an image.")
return None, None
return image_id
except Exception as e: except Exception as e:
log.exception(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None, None
@app.post("/generations") @app.post("/generations")
...@@ -385,8 +398,8 @@ def generate_image( ...@@ -385,8 +398,8 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_b64_image(image["b64_json"]) image_filename = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
...@@ -422,8 +435,10 @@ def generate_image( ...@@ -422,8 +435,10 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_url_image(image["url"]) image_id, image_format = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append(
{"url": f"/cache/image/generations/{image_id}{image_format}"}
)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
...@@ -460,8 +475,8 @@ def generate_image( ...@@ -460,8 +475,8 @@ def generate_image(
images = [] images = []
for image in res["images"]: for image in res["images"]:
image_id = save_b64_image(image) image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
......
...@@ -171,6 +171,7 @@ async def fetch_url(url, key): ...@@ -171,6 +171,7 @@ async def fetch_url(url, key):
def merge_models_lists(model_lists): def merge_models_lists(model_lists):
log.info(f"merge_models_lists {model_lists}")
merged_list = [] merged_list = []
for idx, models in enumerate(model_lists): for idx, models in enumerate(model_lists):
...@@ -199,14 +200,16 @@ async def get_all_models(): ...@@ -199,14 +200,16 @@ async def get_all_models():
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
log.info(f"get_all_models:responses() {responses}")
models = { models = {
"data": merge_models_lists( "data": merge_models_lists(
list( list(
map( map(
lambda response: ( lambda response: (
response["data"] response["data"]
if response and "data" in response if (response and "data" in response)
else None else (response if isinstance(response, list) else None)
), ),
responses, responses,
) )
......
...@@ -28,9 +28,15 @@ from langchain_community.document_loaders import ( ...@@ -28,9 +28,15 @@ from langchain_community.document_loaders import (
UnstructuredXMLLoader, UnstructuredXMLLoader,
UnstructuredRSTLoader, UnstructuredRSTLoader,
UnstructuredExcelLoader, UnstructuredExcelLoader,
YoutubeLoader,
) )
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
import validators
import urllib.parse
import socket
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
import mimetypes import mimetypes
...@@ -84,6 +90,7 @@ from config import ( ...@@ -84,6 +90,7 @@ from config import (
CHUNK_SIZE, CHUNK_SIZE,
CHUNK_OVERLAP, CHUNK_OVERLAP,
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -175,7 +182,7 @@ class CollectionNameForm(BaseModel): ...@@ -175,7 +182,7 @@ class CollectionNameForm(BaseModel):
collection_name: Optional[str] = "test" collection_name: Optional[str] = "test"
class StoreWebForm(CollectionNameForm): class UrlForm(CollectionNameForm):
url: str url: str
...@@ -391,16 +398,16 @@ def query_doc_handler( ...@@ -391,16 +398,16 @@ def query_doc_handler(
return query_doc_with_hybrid_search( return query_doc_with_hybrid_search(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
reranking_function=app.state.sentence_transformer_rf,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
) )
else: else:
return query_doc( return query_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
) )
except Exception as e: except Exception as e:
...@@ -429,16 +436,16 @@ def query_collection_handler( ...@@ -429,16 +436,16 @@ def query_collection_handler(
return query_collection_with_hybrid_search( return query_collection_with_hybrid_search(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
reranking_function=app.state.sentence_transformer_rf,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
) )
else: else:
return query_collection( return query_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
) )
...@@ -450,11 +457,35 @@ def query_collection_handler( ...@@ -450,11 +457,35 @@ def query_collection_handler(
) )
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
try:
loader = YoutubeLoader.from_youtube_url(form_data.url, add_video_info=False)
data = loader.load()
collection_name = form_data.collection_name
if collection_name == "":
collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name, overwrite=True)
return {
"status": True,
"collection_name": collection_name,
"filename": form_data.url,
}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.post("/web") @app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = WebBaseLoader(form_data.url) loader = get_web_loader(form_data.url)
data = loader.load() data = loader.load()
collection_name = form_data.collection_name collection_name = form_data.collection_name
...@@ -475,6 +506,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): ...@@ -475,6 +506,37 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
) )
def get_web_loader(url: str):
# Check if the URL is valid
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url)
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
......
...@@ -35,6 +35,7 @@ def query_doc( ...@@ -35,6 +35,7 @@ def query_doc(
try: try:
collection = CHROMA_CLIENT.get_collection(name=collection_name) collection = CHROMA_CLIENT.get_collection(name=collection_name)
query_embeddings = embedding_function(query) query_embeddings = embedding_function(query)
result = collection.query( result = collection.query(
query_embeddings=[query_embeddings], query_embeddings=[query_embeddings],
n_results=k, n_results=k,
...@@ -52,7 +53,7 @@ def query_doc_with_hybrid_search( ...@@ -52,7 +53,7 @@ def query_doc_with_hybrid_search(
embedding_function, embedding_function,
k: int, k: int,
reranking_function, reranking_function,
r: int, r: float,
): ):
try: try:
collection = CHROMA_CLIENT.get_collection(name=collection_name) collection = CHROMA_CLIENT.get_collection(name=collection_name)
...@@ -76,9 +77,9 @@ def query_doc_with_hybrid_search( ...@@ -76,9 +77,9 @@ def query_doc_with_hybrid_search(
compressor = RerankCompressor( compressor = RerankCompressor(
embedding_function=embedding_function, embedding_function=embedding_function,
top_n=k,
reranking_function=reranking_function, reranking_function=reranking_function,
r_score=r, r_score=r,
top_n=k,
) )
compression_retriever = ContextualCompressionRetriever( compression_retriever = ContextualCompressionRetriever(
...@@ -91,6 +92,7 @@ def query_doc_with_hybrid_search( ...@@ -91,6 +92,7 @@ def query_doc_with_hybrid_search(
"documents": [[d.page_content for d in result]], "documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]], "metadatas": [[d.metadata for d in result]],
} }
log.info(f"query_doc_with_hybrid_search:result {result}") log.info(f"query_doc_with_hybrid_search:result {result}")
return result return result
except Exception as e: except Exception as e:
...@@ -167,7 +169,6 @@ def query_collection_with_hybrid_search( ...@@ -167,7 +169,6 @@ def query_collection_with_hybrid_search(
reranking_function, reranking_function,
r: float, r: float,
): ):
results = [] results = []
for collection_name in collection_names: for collection_name in collection_names:
try: try:
...@@ -182,7 +183,6 @@ def query_collection_with_hybrid_search( ...@@ -182,7 +183,6 @@ def query_collection_with_hybrid_search(
results.append(result) results.append(result)
except: except:
pass pass
return merge_and_sort_query_results(results, k=k, reverse=True) return merge_and_sort_query_results(results, k=k, reverse=True)
...@@ -321,8 +321,12 @@ def rag_messages( ...@@ -321,8 +321,12 @@ def rag_messages(
context_string = "" context_string = ""
for context in relevant_contexts: for context in relevant_contexts:
items = context["documents"][0] try:
context_string += "\n\n".join(items) if "documents" in context:
items = [item for item in context["documents"][0] if item is not None]
context_string += "\n\n".join(items)
except Exception as e:
log.exception(e)
context_string = context_string.strip() context_string = context_string.strip()
ra_content = rag_template( ra_content = rag_template(
...@@ -443,13 +447,15 @@ class ChromaRetriever(BaseRetriever): ...@@ -443,13 +447,15 @@ class ChromaRetriever(BaseRetriever):
metadatas = results["metadatas"][0] metadatas = results["metadatas"][0]
documents = results["documents"][0] documents = results["documents"][0]
return [ results = []
Document( for idx in range(len(ids)):
metadata=metadatas[idx], results.append(
page_content=documents[idx], Document(
metadata=metadatas[idx],
page_content=documents[idx],
)
) )
for idx in range(len(ids)) return results
]
import operator import operator
...@@ -465,9 +471,9 @@ from sentence_transformers import util ...@@ -465,9 +471,9 @@ from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor): class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any embedding_function: Any
top_n: int
reranking_function: Any reranking_function: Any
r_score: float r_score: float
top_n: int
class Config: class Config:
extra = Extra.forbid extra = Extra.forbid
...@@ -479,7 +485,9 @@ class RerankCompressor(BaseDocumentCompressor): ...@@ -479,7 +485,9 @@ class RerankCompressor(BaseDocumentCompressor):
query: str, query: str,
callbacks: Optional[Callbacks] = None, callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]: ) -> Sequence[Document]:
if self.reranking_function: reranking = self.reranking_function is not None
if reranking:
scores = self.reranking_function.predict( scores = self.reranking_function.predict(
[(query, doc.page_content) for doc in documents] [(query, doc.page_content) for doc in documents]
) )
...@@ -496,9 +504,7 @@ class RerankCompressor(BaseDocumentCompressor): ...@@ -496,9 +504,7 @@ class RerankCompressor(BaseDocumentCompressor):
(d, s) for d, s in docs_with_scores if s >= self.r_score (d, s) for d, s in docs_with_scores if s >= self.r_score
] ]
reverse = self.reranking_function is not None result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
final_results = [] final_results = []
for doc, doc_score in result[: self.top_n]: for doc, doc_score in result[: self.top_n]:
metadata = doc.metadata metadata = doc.metadata
......
...@@ -89,6 +89,10 @@ class SignupForm(BaseModel): ...@@ -89,6 +89,10 @@ class SignupForm(BaseModel):
profile_image_url: Optional[str] = "/user.png" profile_image_url: Optional[str] = "/user.png"
class AddUserForm(SignupForm):
role: Optional[str] = "pending"
class AuthsTable: class AuthsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
......
...@@ -123,6 +123,13 @@ class UsersTable: ...@@ -123,6 +123,13 @@ class UsersTable:
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
return User.select().count() return User.select().count()
def get_first_user(self) -> UserModel:
try:
user = User.select().order_by(User.created_at).first()
return UserModel(**model_to_dict(user))
except:
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
query = User.update(role=role).where(User.id == id) query = User.update(role=role).where(User.id == id)
......
import logging import logging
from fastapi import Request from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import re import re
import uuid import uuid
import csv
from apps.web.models.auths import ( from apps.web.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
AddUserForm,
UpdateProfileForm, UpdateProfileForm,
UpdatePasswordForm, UpdatePasswordForm,
UserResponse, UserResponse,
...@@ -205,6 +208,51 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -205,6 +208,51 @@ async def signup(request: Request, form_data: SignupForm):
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
############################
# AddUser
############################
@router.post("/add", response_model=SigninResponse)
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
)
if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
print(form_data)
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(),
hashed,
form_data.name,
form_data.profile_image_url,
form_data.role,
)
if user:
token = create_token(data={"id": user.id})
return {
"token": token,
"token_type": "Bearer",
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
############################ ############################
# ToggleSignUp # ToggleSignUp
############################ ############################
......
...@@ -58,7 +58,7 @@ async def update_user_permissions( ...@@ -58,7 +58,7 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
if user.id != form_data.id: if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException( raise HTTPException(
......
...@@ -168,7 +168,11 @@ except: ...@@ -168,7 +168,11 @@ except:
STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve()) STATIC_DIR = str(Path(os.getenv("STATIC_DIR", "./static")).resolve())
shutil.copyfile(f"{FRONTEND_BUILD_DIR}/favicon.png", f"{STATIC_DIR}/favicon.png") frontend_favicon = f"{FRONTEND_BUILD_DIR}/favicon.png"
if os.path.exists(frontend_favicon):
shutil.copyfile(frontend_favicon, f"{STATIC_DIR}/favicon.png")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
#################################### ####################################
# CUSTOM_NAME # CUSTOM_NAME
...@@ -363,6 +367,17 @@ DEFAULT_PROMPT_SUGGESTIONS = ( ...@@ -363,6 +367,17 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title": ["Show me a code snippet", "of a website's sticky header"], "title": ["Show me a code snippet", "of a website's sticky header"],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.", "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
}, },
{
"title": [
"Explain options trading",
"if I'm familiar with buying and selling stocks",
],
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
},
{
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
},
] ]
) )
...@@ -516,6 +531,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) ...@@ -516,6 +531,8 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
#################################### ####################################
# Transcribe # Transcribe
#################################### ####################################
......
...@@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum): ...@@ -71,3 +71,7 @@ class ERROR_MESSAGES(str, Enum):
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases." DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
INVALID_URL = (
"Oops! The URL you provided is invalid. Please double-check and try again."
)
...@@ -18,6 +18,18 @@ ...@@ -18,6 +18,18 @@
{ {
"title": ["Show me a code snippet", "of a website's sticky header"], "title": ["Show me a code snippet", "of a website's sticky header"],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript." "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
},
{
"title": ["Explain options trading", "if I'm familiar with buying and selling stocks"],
"content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks."
},
{
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?"
},
{
"title": ["Grammar check", "rewrite it for better readability "],
"content": "Check the following sentence for grammar and clarity: \"[sentence]\". Rewrite it for better readability while maintaining its original meaning."
} }
] ]
} }
......
File mode changed from 100644 to 100755
...@@ -311,18 +311,23 @@ async def get_manifest_json(): ...@@ -311,18 +311,23 @@ async def get_manifest_json():
"background_color": "#343541", "background_color": "#343541",
"theme_color": "#343541", "theme_color": "#343541",
"orientation": "portrait-primary", "orientation": "portrait-primary",
"icons": [{"src": "/favicon.png", "type": "image/png", "sizes": "844x884"}], "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
} }
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
app.mount( if os.path.exists(FRONTEND_BUILD_DIR):
"/", app.mount(
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), "/",
name="spa-static-files", SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
) name="spa-static-files",
)
else:
log.warning(
f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
)
@app.on_event("shutdown") @app.on_event("shutdown")
......
fastapi fastapi==0.109.2
uvicorn[standard] uvicorn[standard]==0.22.0
pydantic pydantic==2.7.1
python-multipart python-multipart==0.0.9
flask Flask==3.0.3
flask_cors Flask-Cors==4.0.0
python-socketio python-socketio==5.11.2
python-jose python-jose==3.3.0
passlib[bcrypt] passlib[bcrypt]==1.7.4
uuid uuid==1.30
requests requests==2.31.0
aiohttp aiohttp==3.9.5
peewee peewee==3.17.3
peewee-migrate peewee-migrate==1.12.2
psycopg2-binary psycopg2-binary==2.9.9
pymysql PyMySQL==1.1.0
bcrypt bcrypt==4.1.2
litellm==1.35.17 litellm==1.35.28
litellm[proxy]==1.35.17 litellm[proxy]==1.35.28
boto3 boto3==1.34.95
argon2-cffi argon2-cffi==23.1.0
apscheduler APScheduler==3.10.4
google-generativeai google-generativeai==0.5.2
langchain langchain==0.1.16
langchain-chroma langchain-community==0.0.34
langchain-community langchain-chroma==0.1.0
fake_useragent
chromadb fake-useragent==1.5.1
sentence_transformers chromadb==0.4.24
pypdf sentence-transformers==2.7.0
docx2txt pypdf==4.2.0
unstructured docx2txt==0.8
markdown unstructured==0.11.8
pypandoc Markdown==3.6
pandas pypandoc==1.13
openpyxl pandas==2.2.2
pyxlsb openpyxl==3.1.2
xlrd pyxlsb==1.0.10
xlrd==2.0.1
opencv-python-headless validators==0.28.1
rapidocr-onnxruntime
opencv-python-headless==4.9.0.80
fpdf2 rapidocr-onnxruntime==1.2.3
rank_bm25
fpdf2==2.7.8
faster-whisper rank-bm25==0.2.2
PyJWT faster-whisper==1.0.1
pyjwt[crypto]
PyJWT==2.8.0
black PyJWT[crypto]==2.8.0
langfuse
black==24.4.2
langfuse==2.27.3
youtube-transcript-api
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment