Unverified Commit 13b0e7d6 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #4434 from open-webui/dev

0.3.13
parents 8d257ed5 c8badfe2
......@@ -17,7 +17,7 @@ from utils.misc import calculate_sha256, get_gravatar_url
from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT
from constants import ERROR_MESSAGES
from typing import List
router = APIRouter()
......@@ -57,7 +57,7 @@ async def get_html_from_markdown(
class ChatForm(BaseModel):
title: str
messages: List[dict]
messages: list[dict]
@router.post("/pdf")
......
from importlib import util
import os
import re
import sys
import subprocess
from config import TOOLS_DIR, FUNCTIONS_DIR
......@@ -52,6 +54,7 @@ def load_toolkit_module_by_id(toolkit_id):
frontmatter = extract_frontmatter(toolkit_path)
try:
install_frontmatter_requirements(frontmatter.get("requirements", ""))
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
......@@ -73,6 +76,7 @@ def load_function_module_by_id(function_id):
frontmatter = extract_frontmatter(function_path)
try:
install_frontmatter_requirements(frontmatter.get("requirements", ""))
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipe"):
......@@ -88,3 +92,13 @@ def load_function_module_by_id(function_id):
# Move the file to the error folder
os.rename(function_path, f"{function_path}.error")
raise e
def install_frontmatter_requirements(requirements):
if requirements:
req_list = [req.strip() for req in requirements.split(",")]
for req in req_list:
print(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
else:
print("No requirements found in frontmatter.")
......@@ -104,7 +104,7 @@ ENV = os.environ.get("ENV", "dev")
try:
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except:
except Exception:
try:
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
except importlib.metadata.PackageNotFoundError:
......@@ -137,7 +137,7 @@ try:
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except:
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
......@@ -202,12 +202,12 @@ if RESET_CONFIG_ON_START:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except:
except Exception:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except:
except Exception:
CONFIG_DATA = {}
......@@ -433,6 +433,12 @@ OAUTH_PICTURE_CLAIM = PersistentConfig(
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
OAUTH_EMAIL_CLAIM = PersistentConfig(
"OAUTH_EMAIL_CLAIM",
"oauth.oidc.email_claim",
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
......@@ -641,7 +647,7 @@ if AIOHTTP_CLIENT_TIMEOUT == "":
else:
try:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
except:
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
......@@ -721,7 +727,7 @@ try:
OPENAI_API_KEY = OPENAI_API_KEYS.value[
OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
]
except:
except Exception:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
......@@ -1037,7 +1043,7 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
"rag.embedding_model",
os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
......@@ -1059,7 +1065,7 @@ RAG_RERANKING_MODEL = PersistentConfig(
os.environ.get("RAG_RERANKING_MODEL", ""),
)
if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
......
......@@ -51,7 +51,7 @@ from apps.webui.internal.db import Session
from pydantic import BaseModel
from typing import List, Optional
from typing import Optional
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
......@@ -1883,7 +1883,7 @@ async def get_pipeline_valves(
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
except Exception:
pass
raise HTTPException(
......@@ -2027,7 +2027,7 @@ async def get_model_filter_config(user=Depends(get_admin_user)):
class ModelFilterConfigForm(BaseModel):
enabled: bool
models: List[str]
models: list[str]
@app.post("/api/config/model/filter")
......@@ -2158,7 +2158,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email = user_data.get("email", "").lower()
email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
......
......@@ -11,7 +11,7 @@ python-jose==3.3.0
passlib[bcrypt]==1.7.4
requests==2.32.3
aiohttp==3.9.5
aiohttp==3.10.2
sqlalchemy==2.0.31
alembic==1.13.2
......@@ -34,12 +34,12 @@ anthropic
google-generativeai==0.7.2
tiktoken
langchain==0.2.11
langchain==0.2.12
langchain-community==0.2.10
langchain-chroma==0.1.2
fake-useragent==1.5.1
chromadb==0.5.4
chromadb==0.5.5
sentence-transformers==3.0.1
pypdf==4.3.1
docx2txt==0.8
......@@ -62,11 +62,11 @@ rank-bm25==0.2.2
faster-whisper==1.0.2
PyJWT[crypto]==2.8.0
PyJWT[crypto]==2.9.0
authlib==1.3.1
black==24.8.0
langfuse==2.39.2
langfuse==2.43.3
youtube-transcript-api==0.6.2
pytube==15.0.0
......@@ -76,5 +76,5 @@ duckduckgo-search~=6.2.1
## Tests
docker~=7.1.0
pytest~=8.2.2
pytest~=8.3.2
pytest-docker~=3.1.1
......@@ -30,7 +30,6 @@ if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi
# Check if SPACE_ID is set, if so, configure for space
if [ -n "$SPACE_ID" ]; then
echo "Configuring for HuggingFace Space deployment"
......
......@@ -2,14 +2,14 @@ from pathlib import Path
import hashlib
import re
from datetime import timedelta
from typing import Optional, List, Tuple
from typing import Optional, Callable
import uuid
import time
from utils.task import prompt_template
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "user":
return message
......@@ -26,7 +26,7 @@ def get_content_from_message(message: dict) -> Optional[str]:
return None
def get_last_user_message(messages: List[dict]) -> Optional[str]:
def get_last_user_message(messages: list[dict]) -> Optional[str]:
message = get_last_user_message_item(messages)
if message is None:
return None
......@@ -34,31 +34,31 @@ def get_last_user_message(messages: List[dict]) -> Optional[str]:
return get_content_from_message(message)
def get_last_assistant_message(messages: List[dict]) -> Optional[str]:
def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
for message in reversed(messages):
if message["role"] == "assistant":
return get_content_from_message(message)
return None
def get_system_message(messages: List[dict]) -> Optional[dict]:
def get_system_message(messages: list[dict]) -> Optional[dict]:
for message in messages:
if message["role"] == "system":
return message
return None
def remove_system_message(messages: List[dict]) -> List[dict]:
def remove_system_message(messages: list[dict]) -> list[dict]:
return [message for message in messages if message["role"] != "system"]
def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
return get_system_message(messages), remove_system_message(messages)
def prepend_to_first_user_message_content(
content: str, messages: List[dict]
) -> List[dict]:
content: str, messages: list[dict]
) -> list[dict]:
for message in messages:
if message["role"] == "user":
if isinstance(message["content"], list):
......@@ -71,7 +71,7 @@ def prepend_to_first_user_message_content(
return messages
def add_or_update_system_message(content: str, messages: List[dict]):
def add_or_update_system_message(content: str, messages: list[dict]):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
......@@ -135,10 +135,21 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
def apply_model_params_to_body(
params: dict, form_data: dict, mappings: dict[str, Callable]
) -> dict:
if not params:
return form_data
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
return form_data
# inplace function: form_data is modified
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
mappings = {
"temperature": float,
"top_p": int,
......@@ -147,10 +158,40 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
return apply_model_params_to_body(params, form_data, mappings)
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
opts = [
"temperature",
"top_p",
"seed",
"mirostat",
"mirostat_eta",
"mirostat_tau",
"num_ctx",
"num_batch",
"num_keep",
"repeat_last_n",
"tfs_z",
"top_k",
"min_p",
"use_mmap",
"use_mlock",
"num_thread",
"num_gpu",
]
mappings = {i: lambda x: x for i in opts}
form_data = apply_model_params_to_body(params, form_data, mappings)
name_differences = {
"max_tokens": "num_predict",
"frequency_penalty": "repeat_penalty",
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
for key, value in name_differences.items():
if (param := params.get(key, None)) is not None:
form_data[value] = param
return form_data
......
import inspect
from typing import get_type_hints, List, Dict, Any
from typing import get_type_hints
def doc_to_dict(docstring):
......@@ -16,7 +16,7 @@ def doc_to_dict(docstring):
return ret_dict
def get_tools_specs(tools) -> List[dict]:
def get_tools_specs(tools) -> list[dict]:
function_list = [
{"name": func, "function": getattr(tools, func)}
for func in dir(tools)
......
......@@ -38,9 +38,10 @@ describe('Settings', () => {
// User's message should be visible
cy.get('.chat-user').should('exist');
// Wait for the response
cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
.should('exist');
// .chat-assistant is created after the first token is received
cy.get('.chat-assistant', { timeout: 10_000 }).should('exist');
// Generation Info is created after the stop token is received
cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist');
});
it('user can share chat', () => {
......@@ -57,21 +58,24 @@ describe('Settings', () => {
// User's message should be visible
cy.get('.chat-user').should('exist');
// Wait for the response
cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
.should('exist');
// .chat-assistant is created after the first token is received
cy.get('.chat-assistant', { timeout: 10_000 }).should('exist');
// Generation Info is created after the stop token is received
cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist');
// spy on requests
const spy = cy.spy();
cy.intercept('GET', '/api/v1/chats/*', spy);
cy.intercept('POST', '/api/v1/chats/**/share', spy);
// Open context menu
cy.get('#chat-context-menu-button').click();
// Click share button
cy.get('#chat-share-button').click();
// Check if the share dialog is visible
cy.get('#copy-and-share-chat-button').should('exist');
cy.wrap({}, { timeout: 5000 }).should(() => {
// Check if the request was made twice (once for to replace chat object and once more due to change event)
expect(spy).to.be.callCount(2);
// Click the copy button
cy.get('#copy-and-share-chat-button').click();
cy.wrap({}, { timeout: 5_000 }).should(() => {
// Check if the share request was made
expect(spy).to.be.callCount(1);
});
});
......@@ -89,9 +93,10 @@ describe('Settings', () => {
// User's message should be visible
cy.get('.chat-user').should('exist');
// Wait for the response
cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
.should('exist');
// .chat-assistant is created after the first token is received
cy.get('.chat-assistant', { timeout: 10_000 }).should('exist');
// Generation Info is created after the stop token is received
cy.get('div[aria-label="Generation Info"]', { timeout: 120_000 }).should('exist');
// Click on the generate image button
cy.get('[aria-label="Generate Image"]').click();
// Wait for image to be visible
......
......@@ -22,7 +22,6 @@ Noticed something off? Have an idea? Check our [Issues tab](https://github.com/o
> [!IMPORTANT]
>
> - **Template Compliance:** Please be aware that failure to follow the provided issue template, or not providing the requested information at all, will likely result in your issue being closed without further consideration. This approach is critical for maintaining the manageability and integrity of issue tracking.
>
> - **Detail is Key:** To ensure your issue is understood and can be effectively addressed, it's imperative to include comprehensive details. Descriptions should be clear, including steps to reproduce, expected outcomes, and actual results. Lack of sufficient detail may hinder our ability to resolve your issue.
### 🧭 Scope of Support
......
This diff is collapsed.
{
"name": "open-webui",
"version": "0.3.12",
"version": "0.3.13",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
......@@ -20,30 +20,31 @@
"pyodide:fetch": "node scripts/prepare-pyodide.js"
},
"devDependencies": {
"@sveltejs/adapter-auto": "^2.0.0",
"@sveltejs/adapter-static": "^2.0.3",
"@sveltejs/kit": "^1.30.0",
"@tailwindcss/typography": "^0.5.10",
"@sveltejs/adapter-auto": "3.2.2",
"@sveltejs/adapter-static": "^3.0.2",
"@sveltejs/kit": "^2.5.20",
"@sveltejs/vite-plugin-svelte": "^3.1.1",
"@tailwindcss/typography": "^0.5.13",
"@types/bun": "latest",
"@typescript-eslint/eslint-plugin": "^6.17.0",
"@typescript-eslint/parser": "^6.17.0",
"autoprefixer": "^10.4.16",
"cypress": "^13.8.1",
"eslint": "^8.56.0",
"eslint-config-prettier": "^8.5.0",
"eslint-plugin-cypress": "^3.0.2",
"eslint-plugin-svelte": "^2.30.0",
"i18next-parser": "^8.13.0",
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-cypress": "^3.4.0",
"eslint-plugin-svelte": "^2.43.0",
"i18next-parser": "^9.0.1",
"postcss": "^8.4.31",
"prettier": "^2.8.0",
"prettier-plugin-svelte": "^2.10.1",
"svelte": "^4.0.5",
"svelte-check": "^3.4.3",
"prettier": "^3.3.3",
"prettier-plugin-svelte": "^3.2.6",
"svelte": "^4.2.18",
"svelte-check": "^3.8.5",
"svelte-confetti": "^1.3.2",
"tailwindcss": "^3.3.3",
"tslib": "^2.4.1",
"typescript": "^5.0.0",
"vite": "^4.4.2",
"typescript": "^5.5.4",
"vite": "^5.3.5",
"vitest": "^1.6.0"
},
"type": "module",
......@@ -52,7 +53,7 @@
"@codemirror/lang-python": "^6.1.6",
"@codemirror/theme-one-dark": "^6.1.2",
"@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^1.3.1",
"@sveltejs/adapter-node": "^2.0.0",
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
......@@ -69,6 +70,7 @@
"js-sha256": "^0.10.1",
"katex": "^0.16.9",
"marked": "^9.1.0",
"marked-katex-extension": "^5.1.1",
"mermaid": "^10.9.1",
"pyodide": "^0.26.1",
"socket.io-client": "^4.2.0",
......@@ -77,5 +79,9 @@
"tippy.js": "^6.3.7",
"turndown": "^7.2.0",
"uuid": "^9.0.1"
},
"engines": {
"node": ">=18.13.0 <=21.x.x",
"npm": ">=6.0.0"
}
}
[project]
name = "open-webui"
description = "Open WebUI (Formerly Ollama WebUI)"
description = "Open WebUI"
authors = [
{ name = "Timothy Jaeryang Baek", email = "tim@openwebui.com" }
]
......@@ -19,7 +19,7 @@ dependencies = [
"passlib[bcrypt]==1.7.4",
"requests==2.32.3",
"aiohttp==3.9.5",
"aiohttp==3.10.2",
"sqlalchemy==2.0.31",
"alembic==1.13.2",
......@@ -41,12 +41,12 @@ dependencies = [
"google-generativeai==0.7.2",
"tiktoken",
"langchain==0.2.11",
"langchain==0.2.12",
"langchain-community==0.2.10",
"langchain-chroma==0.1.2",
"fake-useragent==1.5.1",
"chromadb==0.5.4",
"chromadb==0.5.5",
"sentence-transformers==3.0.1",
"pypdf==4.3.1",
"docx2txt==0.8",
......@@ -69,11 +69,11 @@ dependencies = [
"faster-whisper==1.0.2",
"PyJWT[crypto]==2.8.0",
"PyJWT[crypto]==2.9.0",
"authlib==1.3.1",
"black==24.8.0",
"langfuse==2.39.2",
"langfuse==2.43.3",
"youtube-transcript-api==0.6.2",
"pytube==15.0.0",
......
......@@ -10,7 +10,9 @@
# universal: false
-e file:.
aiohttp==3.9.5
aiohappyeyeballs==2.3.5
# via aiohttp
aiohttp==3.10.2
# via langchain
# via langchain-community
# via open-webui
......@@ -84,9 +86,9 @@ chardet==5.2.0
charset-normalizer==3.3.2
# via requests
# via unstructured-client
chroma-hnswlib==0.7.5
chroma-hnswlib==0.7.6
# via chromadb
chromadb==0.5.4
chromadb==0.5.5
# via langchain-chroma
# via open-webui
click==8.1.7
......@@ -269,7 +271,7 @@ jsonpointer==2.4
# via jsonpatch
kubernetes==29.0.0
# via chromadb
langchain==0.2.11
langchain==0.2.12
# via langchain-community
# via open-webui
langchain-chroma==0.1.2
......@@ -285,7 +287,7 @@ langchain-text-splitters==0.2.0
# via langchain
langdetect==1.0.9
# via unstructured
langfuse==2.39.2
langfuse==2.43.3
# via open-webui
langsmith==0.1.96
# via langchain
......@@ -491,7 +493,7 @@ pydub==0.25.1
# via open-webui
pygments==2.18.0
# via rich
pyjwt==2.8.0
pyjwt==2.9.0
# via open-webui
pymongo==4.8.0
# via open-webui
......
......@@ -10,7 +10,9 @@
# universal: false
-e file:.
aiohttp==3.9.5
aiohappyeyeballs==2.3.5
# via aiohttp
aiohttp==3.10.2
# via langchain
# via langchain-community
# via open-webui
......@@ -84,9 +86,9 @@ chardet==5.2.0
charset-normalizer==3.3.2
# via requests
# via unstructured-client
chroma-hnswlib==0.7.5
chroma-hnswlib==0.7.6
# via chromadb
chromadb==0.5.4
chromadb==0.5.5
# via langchain-chroma
# via open-webui
click==8.1.7
......@@ -269,7 +271,7 @@ jsonpointer==2.4
# via jsonpatch
kubernetes==29.0.0
# via chromadb
langchain==0.2.11
langchain==0.2.12
# via langchain-community
# via open-webui
langchain-chroma==0.1.2
......@@ -285,7 +287,7 @@ langchain-text-splitters==0.2.0
# via langchain
langdetect==1.0.9
# via unstructured
langfuse==2.39.2
langfuse==2.43.3
# via open-webui
langsmith==0.1.96
# via langchain
......@@ -491,7 +493,7 @@ pydub==0.25.1
# via open-webui
pygments==2.18.0
# via rich
pyjwt==2.8.0
pyjwt==2.9.0
# via open-webui
pymongo==4.8.0
# via open-webui
......
<!DOCTYPE html>
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
......
......@@ -69,6 +69,7 @@ type ChatCompletedForm = {
model: string;
messages: string[];
chat_id: string;
session_id: string;
};
export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
......
import { OLLAMA_API_BASE_URL } from '$lib/constants';
import { titleGenerationTemplate } from '$lib/utils';
export const getOllamaConfig = async (token: string = '') => {
let error = null;
......@@ -203,55 +202,6 @@ export const getOllamaModels = async (token: string = '') => {
});
};
// TODO: migrate to backend
export const generateTitle = async (
token: string = '',
template: string,
model: string,
prompt: string
) => {
let error = null;
template = titleGenerationTemplate(template, prompt);
console.log(template);
const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
prompt: template,
stream: false,
options: {
// Restrict the number of tokens generated to 50
num_predict: 50
}
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
return res?.response.replace(/["']/g, '') ?? 'New Chat';
};
export const generatePrompt = async (token: string = '', model: string, conversation: string) => {
let error = null;
......
import { OPENAI_API_BASE_URL } from '$lib/constants';
import { titleGenerationTemplate } from '$lib/utils';
import { type Model, models, settings } from '$lib/stores';
export const getOpenAIConfig = async (token: string = '') => {
let error = null;
......@@ -260,7 +258,7 @@ export const getOpenAIModelsDirect = async (
throw error;
}
const models = Array.isArray(res) ? res : res?.data ?? null;
const models = Array.isArray(res) ? res : (res?.data ?? null);
return models
.map((model) => ({ id: model.id, name: model.name ?? model.id, external: true }))
......@@ -330,126 +328,3 @@ export const synthesizeOpenAISpeech = async (
return res;
};
export const generateTitle = async (
token: string = '',
template: string,
model: string,
prompt: string,
chat_id?: string,
url: string = OPENAI_API_BASE_URL
) => {
let error = null;
template = titleGenerationTemplate(template, prompt);
console.log(template);
const res = await fetch(`${url}/chat/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
messages: [
{
role: 'user',
content: template
}
],
stream: false,
// Restricting the max tokens to 50 to avoid long titles
max_tokens: 50,
...(chat_id && { chat_id: chat_id }),
title: true
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return null;
});
if (error) {
throw error;
}
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? 'New Chat';
};
export const generateSearchQuery = async (
token: string = '',
model: string,
previousMessages: string[],
prompt: string,
url: string = OPENAI_API_BASE_URL
): Promise<string | undefined> => {
let error = null;
// TODO: Allow users to specify the prompt
// Get the current date in the format "January 20, 2024"
const currentDate = new Intl.DateTimeFormat('en-US', {
year: 'numeric',
month: 'long',
day: '2-digit'
}).format(new Date());
const res = await fetch(`${url}/chat/completions`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
model: model,
// Few shot prompting
messages: [
{
role: 'assistant',
content: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}.`
},
{
role: 'user',
content: prompt
}
// {
// role: 'user',
// content:
// (previousMessages.length > 0
// ? `Previous Questions:\n${previousMessages.join('\n')}\n\n`
// : '') + `Current Question: ${prompt}`
// }
],
stream: false,
// Restricting the max tokens to 30 to avoid long search queries
max_tokens: 30
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
}
return undefined;
});
if (error) {
throw error;
}
return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? undefined;
};
......@@ -75,12 +75,12 @@
class="font-semibold uppercase text-xs {section === 'added'
? 'text-white bg-blue-600'
: section === 'fixed'
? 'text-white bg-green-600'
: section === 'changed'
? 'text-white bg-yellow-600'
: section === 'removed'
? 'text-white bg-red-600'
: ''} w-fit px-3 rounded-full my-2.5"
? 'text-white bg-green-600'
: section === 'changed'
? 'text-white bg-yellow-600'
: section === 'removed'
? 'text-white bg-red-600'
: ''} w-fit px-3 rounded-full my-2.5"
>
{section}
</div>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment