Unverified commit e6c65088 authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #381 from ollama-webui/openai-backend

feat: openai backend support
parents 31fcb9d6 2fedd42e
@@ -16,6 +16,10 @@ ARG OLLAMA_API_BASE_URL='/ollama/api'
ENV ENV=prod
ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL
ENV OPENAI_API_BASE_URL ""
ENV OPENAI_API_KEY ""
ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY"
WORKDIR /app
......
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = OPENAI_API_KEY
class UrlUpdateForm(BaseModel):
    url: str


class KeyUpdateForm(BaseModel):
    key: str
@app.get("/url")
async def get_openai_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
if user and user.role == "admin":
app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
if user and user.role == "admin":
app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
    print(target_url)  # log the target only; never log the API key

    if user.role not in ["user", "admin"]:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    if app.state.OPENAI_API_KEY == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    body = await request.body()

    # Replace the incoming headers entirely: the upstream request authenticates
    # with the server-side key, never with the user's WebUI JWT.
    headers = {}
    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
    headers["Content-Type"] = "application/json"

    r = None
    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            # For non-SSE, read the response and return it
            response_data = r.json()

            # When proxying the official OpenAI endpoint, only expose GPT models
            if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
                response_data["data"] = list(
                    filter(lambda model: "gpt" in model["id"], response_data["data"])
                )

            return response_data
    except Exception as e:
        print(e)
        error_detail = "Ollama WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
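Every other route falls through to this catch-all, so the frontend never needs the upstream key: it authenticates with its own WebUI JWT and the backend swaps in the server-side key. A minimal sketch of a call through the proxy, assuming a same-origin deployment and that `localStorage.token` holds the JWT issued at login:

// Hedged sketch: path and response shape follow the OpenAI models endpoint.
const token: string = localStorage.token;

const res = await fetch('/openai/api/models', {
	method: 'GET',
	headers: {
		Accept: 'application/json',
		Authorization: `Bearer ${token}` // WebUI JWT, not the OpenAI key
	}
});

if (res.ok) {
	// When the configured base URL contains "openai", the backend has
	// already filtered this list down to "gpt" models.
	const { data } = await res.json();
	console.log(data.map((model: { id: string }) => model.id));
}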
@@ -27,11 +27,22 @@ if ENV == "prod":
    if OLLAMA_API_BASE_URL == "/ollama/api":
        OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
####################################
# OPENAI_API
####################################
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI_VERSION
####################################
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50")
####################################
# WEBUI_AUTH (Required for security)
......
@@ -33,4 +33,5 @@ class ERROR_MESSAGES(str, Enum):
)
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
@@ -6,6 +6,8 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.web.main import app as webui_app
import time
@@ -46,7 +48,7 @@ async def check_url(request: Request, call_next):
app.mount("/api/v1", webui_app)
# app.mount("/ollama/api", WSGIMiddleware(ollama_app))
app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)
app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files")
# If you're serving both the frontend and backend (Recommended)
# Set the public API base URL for seamless communication
PUBLIC_API_BASE_URL='/ollama/api'
# If you're serving only the frontend (Not recommended and not fully supported)
# Comment out the line above and uncomment the line below
# You can use the default value or specify a custom path, e.g., '/api'
# PUBLIC_API_BASE_URL='http://{location.hostname}:11434/api'
# Ollama URL for the backend to connect
# The path '/ollama/api' will be redirected to the specified backend URL
OLLAMA_API_BASE_URL='http://localhost:11434/api'
OPENAI_API_BASE_URL=''
OPENAI_API_KEY=''
\ No newline at end of file
import { OPENAI_API_BASE_URL } from '$lib/constants';
export const getOpenAIUrl = async (token: string = '') => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/url`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_BASE_URL;
};
export const updateOpenAIUrl = async (token: string = '', url: string) => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
url: url
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_BASE_URL;
};
export const getOpenAIKey = async (token: string = '') => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/key`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_KEY;
};
export const updateOpenAIKey = async (token: string = '', key: string) => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
key: key
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_KEY;
};
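Together these four helpers back the admin settings panel. A usage sketch, assuming the caller is signed in as an admin (non-admins get a 401 whose detail is thrown as the error):

import { getOpenAIKey, getOpenAIUrl, updateOpenAIKey, updateOpenAIUrl } from '$lib/apis/openai';

// Read the current values...
const url = await getOpenAIUrl(localStorage.token);
const key = await getOpenAIKey(localStorage.token);

// ...then point the backend at a compatible endpoint and rotate the key.
await updateOpenAIUrl(localStorage.token, 'https://api.openai.com/v1');
await updateOpenAIKey(localStorage.token, 'sk-…'); // placeholder key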
export const getOpenAIModels = async (token: string = '') => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/models`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
return [];
});
if (error) {
throw error;
}
const models = Array.isArray(res) ? res : res?.data ?? null;
return models
? models
.map((model) => ({ name: model.id, external: true }))
.sort((a, b) => {
return a.name.localeCompare(b.name);
})
: models;
};
export const getOpenAIModelsDirect = async (
base_url: string = 'https://api.openai.com/v1',
api_key: string = ''
) => {
@@ -34,3 +206,26 @@ export const getOpenAIModels = async (
return a.name.localeCompare(b.name);
});
};
export const generateOpenAIChatCompletion = async (token: string = '', body: object) => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/chat/completions`, {
method: 'POST',
headers: {
Authorization: `Bearer ${token}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(body)
}).catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
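Unlike the other helpers, this one returns the raw Response so callers can consume the SSE stream themselves (the chat pages do this with a splitStream utility). A plain reader-based sketch, assuming stream: true and OpenAI's `data: ...` line format:

const res = await generateOpenAIChatCompletion(localStorage.token, {
	model: 'gpt-3.5-turbo', // illustrative model id
	stream: true,
	messages: [{ role: 'user', content: 'Hello!' }]
});

if (res && res.ok && res.body) {
	const reader = res.body.getReader();
	const decoder = new TextDecoder();
	while (true) {
		const { value, done } = await reader.read();
		if (done) break;
		// Each chunk carries one or more `data: {...}` lines,
		// terminated by `data: [DONE]`.
		console.log(decoder.decode(value));
	}
}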
@@ -27,7 +27,7 @@
>
{#if model in modelfiles}
<img
src={modelfiles[model]?.imageUrl ?? '/ollama-dark.png'}
alt="modelfile"
class=" w-20 mb-2 rounded-full {models.length > 1
? ' border-[5px] border-white dark:border-gray-800'
......
@@ -24,6 +24,13 @@
import { updateUserPassword } from '$lib/apis/auths';
import { goto } from '$app/navigation';
import Page from '../../../routes/(app)/+page.svelte';
import {
getOpenAIKey,
getOpenAIModels,
getOpenAIUrl,
updateOpenAIKey,
updateOpenAIUrl
} from '$lib/apis/openai';
export let show = false;
@@ -153,6 +160,13 @@
}
};
const updateOpenAIHandler = async () => {
OPENAI_API_BASE_URL = await updateOpenAIUrl(localStorage.token, OPENAI_API_BASE_URL);
OPENAI_API_KEY = await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
await models.set(await getModels());
};
const toggleTheme = async () => {
if (theme === 'dark') {
theme = 'light';
@@ -484,7 +498,7 @@
};
const getModels = async (type = 'all') => {
const models = [];
models.push(
...(await getOllamaModels(localStorage.token).catch((error) => {
toast.error(error);
@@ -493,43 +507,13 @@
);
// If OpenAI API Key exists
if (type === 'all' && OPENAI_API_KEY) {
const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
console.log(error);
toast.error(`OpenAI: ${error?.error?.message ?? 'Network Problem'}`);
return null;
});
models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
}
return models;
@@ -564,6 +548,8 @@
console.log('settings', $user.role === 'admin');
if ($user.role === 'admin') {
API_BASE_URL = await getOllamaAPIUrl(localStorage.token);
OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token);
OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
}
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
@@ -584,9 +570,6 @@
options = { ...options, ...settings.options };
options.stop = (settings?.options?.stop ?? []).join(',');
titleAutoGenerate = settings.titleAutoGenerate ?? true;
speechAutoSend = settings.speechAutoSend ?? false;
responseAutoCopy = settings.responseAutoCopy ?? false;
@@ -709,7 +692,6 @@
</div>
<div class=" self-center">Models</div>
</button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
@@ -734,6 +716,7 @@
</div>
<div class=" self-center">External</div>
</button>
{/if}
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
@@ -1415,10 +1398,12 @@
<form
class="flex flex-col h-full justify-between space-y-3 text-sm"
on:submit|preventDefault={() => {
updateOpenAIHandler();
// saveSettings({
// OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
// OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined
// });
show = false;
}}
>
......
import { dev } from '$app/environment';
export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
export const WEB_UI_VERSION = 'v1.0.0-alpha-static';
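The practical effect: every API family is now addressed relative to WEBUI_BASE_URL, so dev builds talk to the backend on port 8080 while production builds stay same-origin. Roughly, assuming a dev host of localhost:

// dev === true  → WEBUI_BASE_URL = 'http://localhost:8080'
//   OPENAI_API_BASE_URL → 'http://localhost:8080/openai/api'
// dev === false → WEBUI_BASE_URL = ''
//   OPENAI_API_BASE_URL → '/openai/api' (same origin as the SPA)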
......
@@ -37,19 +37,17 @@
return [];
}))
);
// If OpenAI API Key exists
if ($settings.OPENAI_API_KEY) {
// $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1',
// $settings.OPENAI_API_KEY
const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
console.log(error);
toast.error(error);
return null;
});
models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
}
return models;
};
......
@@ -16,6 +16,7 @@
import ModelSelector from '$lib/components/chat/ModelSelector.svelte';
import Navbar from '$lib/components/layout/Navbar.svelte';
import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
let stopResponseFlag = false;
let autoScroll = true;
@@ -321,8 +322,6 @@
};
const sendPromptOpenAI = async (model, userPrompt, parentId, _chatId) => {
let responseMessageId = uuidv4();
let responseMessage = {
@@ -345,15 +344,7 @@
window.scrollTo({ top: document.body.scrollHeight });
const res = await generateOpenAIChatCompletion(localStorage.token, {
model: model,
stream: true,
messages: [
@@ -394,11 +385,6 @@
num_ctx: $settings?.options?.num_ctx ?? undefined,
frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
max_tokens: $settings?.options?.num_predict ?? undefined
});
if (res && res.ok) {
@@ -502,8 +488,6 @@
window.history.replaceState(history.state, '', `/c/${_chatId}`);
await setChatTitle(_chatId, userPrompt);
}
};
const submitPrompt = async (userPrompt) => {
......
@@ -9,6 +9,8 @@
import { models, modelfiles, user, settings, chats, chatId } from '$lib/stores';
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import { copyToClipboard, splitStream } from '$lib/utils';
import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -338,8 +340,6 @@
};
const sendPromptOpenAI = async (model, userPrompt, parentId, _chatId) => {
let responseMessageId = uuidv4();
let responseMessage = {
@@ -362,15 +362,7 @@
window.scrollTo({ top: document.body.scrollHeight });
const res = await generateOpenAIChatCompletion(localStorage.token, {
model: model,
stream: true,
messages: [
@@ -411,11 +403,6 @@
num_ctx: $settings?.options?.num_ctx ?? undefined,
frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
max_tokens: $settings?.options?.num_predict ?? undefined
});
if (res && res.ok) {
@@ -519,8 +506,6 @@
window.history.replaceState(history.state, '', `/c/${_chatId}`);
await setChatTitle(_chatId, userPrompt);
}
};
const submitPrompt = async (userPrompt) => {
......