Commit 0210a105 authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

feat: experimental SSO support for Google, Microsoft, and OIDC

parent a842d8d6
"""Peewee migrations -- 011_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user",
oauth_sub=pw.TextField(null=True, unique=True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub")
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from apps.webui.routers import ( from apps.webui.routers import (
auths, auths,
users, users,
...@@ -24,6 +26,8 @@ from config import ( ...@@ -24,6 +26,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
AppConfig, AppConfig,
WEBUI_SECRET_KEY,
OAUTH_PROVIDERS,
) )
app = FastAPI() app = FastAPI()
...@@ -54,6 +58,12 @@ app.add_middleware( ...@@ -54,6 +58,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
)
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
......
...@@ -105,6 +105,7 @@ class AuthsTable: ...@@ -105,6 +105,7 @@ class AuthsTable:
name: str, name: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info("insert_new_auth") log.info("insert_new_auth")
...@@ -115,7 +116,9 @@ class AuthsTable: ...@@ -115,7 +116,9 @@ class AuthsTable:
) )
result = Auth.create(**auth.model_dump()) result = Auth.create(**auth.model_dump())
user = Users.insert_new_user(id, name, email, profile_image_url, role) user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
if result and user: if result and user:
return user return user
......
...@@ -26,6 +26,8 @@ class User(Model): ...@@ -26,6 +26,8 @@ class User(Model):
api_key = CharField(null=True, unique=True) api_key = CharField(null=True, unique=True)
oauth_sub = TextField(null=True, unique=True)
class Meta: class Meta:
database = DB database = DB
...@@ -43,6 +45,8 @@ class UserModel(BaseModel): ...@@ -43,6 +45,8 @@ class UserModel(BaseModel):
api_key: Optional[str] = None api_key: Optional[str] = None
oauth_sub: Optional[str] = None
#################### ####################
# Forms # Forms
...@@ -73,6 +77,7 @@ class UsersTable: ...@@ -73,6 +77,7 @@ class UsersTable:
email: str, email: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
user = UserModel( user = UserModel(
**{ **{
...@@ -84,6 +89,7 @@ class UsersTable: ...@@ -84,6 +89,7 @@ class UsersTable:
"last_active_at": int(time.time()), "last_active_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
"oauth_sub": oauth_sub,
} }
) )
result = User.create(**user.model_dump()) result = User.create(**user.model_dump())
...@@ -113,6 +119,13 @@ class UsersTable: ...@@ -113,6 +119,13 @@ class UsersTable:
except: except:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
user = User.get(User.oauth_sub == sub)
return UserModel(**model_to_dict(user))
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ return [
UserModel(**model_to_dict(user)) UserModel(**model_to_dict(user))
......
import logging import logging
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import Request, UploadFile, File from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
...@@ -9,6 +11,7 @@ import re ...@@ -9,6 +11,7 @@ import re
import uuid import uuid
import csv import csv
from starlette.responses import RedirectResponse
from apps.webui.models.auths import ( from apps.webui.models.auths import (
SigninForm, SigninForm,
...@@ -33,7 +36,12 @@ from utils.utils import ( ...@@ -33,7 +36,12 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER from config import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
)
router = APIRouter() router = APIRouter()
...@@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)): ...@@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)):
} }
else: else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
@router.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
@router.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
token = await client.authorize_access_token(request)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, create a new user if signup is enabled
if ENABLE_OAUTH_SIGNUP.value:
user = Auths.insert_new_auth(
email=user_data.get("email", "").lower(),
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=user_data.get("picture", "/user.png"),
role=request.app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
...@@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig( ...@@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
) )
####################################
# OAuth config
####################################
ENABLE_OAUTH_SIGNUP = PersistentConfig(
"ENABLE_OAUTH_SIGNUP",
"oauth.enable_signup",
os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)
OAUTH_PROVIDERS = {}
if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"):
OAUTH_PROVIDERS["google"] = {
"client_id": os.environ.get("GOOGLE_CLIENT_ID"),
"client_secret": os.environ.get("GOOGLE_CLIENT_SECRET"),
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
"scope": os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
}
if (
os.environ.get("MICROSOFT_CLIENT_ID")
and os.environ.get("MICROSOFT_CLIENT_SECRET")
and os.environ.get("MICROSOFT_CLIENT_TENANT_ID")
):
OAUTH_PROVIDERS["microsoft"] = {
"client_id": os.environ.get("MICROSOFT_CLIENT_ID"),
"client_secret": os.environ.get("MICROSOFT_CLIENT_SECRET"),
"server_metadata_url": f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_CLIENT_TENANT_ID')}/v2.0/.well-known/openid-configuration",
"scope": os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
}
if (
os.environ.get("OPENID_CLIENT_ID")
and os.environ.get("OPENID_CLIENT_SECRET")
and os.environ.get("OPENID_PROVIDER_URL")
):
OAUTH_PROVIDERS["oidc"] = {
"client_id": os.environ.get("OPENID_CLIENT_ID"),
"client_secret": os.environ.get("OPENID_CLIENT_SECRET"),
"server_metadata_url": os.environ.get("OPENID_PROVIDER_URL"),
"scope": os.environ.get("OPENID_SCOPE", "openid email profile"),
"name": os.environ.get("OPENID_PROVIDER_NAME", "SSO"),
}
#################################### ####################################
# Static DIR # Static DIR
#################################### ####################################
......
...@@ -55,6 +55,7 @@ from config import ( ...@@ -55,6 +55,7 @@ from config import (
WEBHOOK_URL, WEBHOOK_URL,
ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_EXPORT,
AppConfig, AppConfig,
OAUTH_PROVIDERS,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -364,6 +365,13 @@ async def get_app_config(): ...@@ -364,6 +365,13 @@ async def get_app_config():
"default_locale": default_locale, "default_locale": default_locale,
"default_models": webui_app.state.config.DEFAULT_MODELS, "default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"oauth": {
"providers": {
name: config.get("name", name)
for name, config in OAUTH_PROVIDERS.items()
}
},
} }
......
...@@ -134,7 +134,12 @@ type Config = { ...@@ -134,7 +134,12 @@ type Config = {
default_models?: string[]; default_models?: string[];
default_prompt_suggestions?: PromptSuggestion[]; default_prompt_suggestions?: PromptSuggestion[];
auth_trusted_header?: boolean; auth_trusted_header?: boolean;
model_config?: GlobalModelConfig; auth: boolean;
oauth: {
providers: {
[key: string]: string;
};
};
}; };
type PromptSuggestion = { type PromptSuggestion = {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import { onMount, tick, setContext } from 'svelte'; import { onMount, tick, setContext } from 'svelte';
import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores'; import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/stores';
import { Toaster, toast } from 'svelte-sonner'; import { Toaster, toast } from 'svelte-sonner';
import { getBackendConfig } from '$lib/apis'; import { getBackendConfig } from '$lib/apis';
...@@ -75,9 +76,13 @@ ...@@ -75,9 +76,13 @@
await goto('/auth'); await goto('/auth');
} }
} else { } else {
// Don't redirect if we're already on the auth page
// Needed because we pass in tokens from OAuth logins via URL fragments
if ($page.url.pathname !== '/auth') {
await goto('/auth'); await goto('/auth');
} }
} }
}
} else { } else {
// Redirect to /error when Backend Not Detected // Redirect to /error when Backend Not Detected
await goto(`/error`); await goto(`/error`);
......
<script> <script>
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { userSignIn, userSignUp } from '$lib/apis/auths'; import { getSessionUser, userSignIn, userSignUp } from '$lib/apis/auths';
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { WEBUI_NAME, config, user } from '$lib/stores'; import { WEBUI_NAME, config, user } from '$lib/stores';
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { generateInitialsImage, canvasPixelTest } from '$lib/utils'; import { generateInitialsImage, canvasPixelTest } from '$lib/utils';
import { page } from '$app/stores';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
...@@ -21,7 +22,9 @@ ...@@ -21,7 +22,9 @@
if (sessionUser) { if (sessionUser) {
console.log(sessionUser); console.log(sessionUser);
toast.success($i18n.t(`You're now logged in.`)); toast.success($i18n.t(`You're now logged in.`));
if (sessionUser.token) {
localStorage.token = sessionUser.token; localStorage.token = sessionUser.token;
}
await user.set(sessionUser); await user.set(sessionUser);
goto('/'); goto('/');
} }
...@@ -55,10 +58,35 @@ ...@@ -55,10 +58,35 @@
} }
}; };
const checkOauthCallback = async () => {
if (!$page.url.hash) {
return;
}
const hash = $page.url.hash.substring(1);
if (!hash) {
return;
}
const params = new URLSearchParams(hash);
const token = params.get('token');
if (!token) {
return;
}
const sessionUser = await getSessionUser(token).catch((error) => {
toast.error(error);
return null;
});
if (!sessionUser) {
return;
}
localStorage.token = token;
await setSessionUser(sessionUser);
};
onMount(async () => { onMount(async () => {
if ($user !== undefined) { if ($user !== undefined) {
await goto('/'); await goto('/');
} }
await checkOauthCallback();
loaded = true; loaded = true;
if (($config?.auth_trusted_header ?? false) || $config?.auth === false) { if (($config?.auth_trusted_header ?? false) || $config?.auth === false) {
await signInHandler(); await signInHandler();
...@@ -217,6 +245,97 @@ ...@@ -217,6 +245,97 @@
{/if} {/if}
</div> </div>
</form> </form>
{#if Object.keys($config?.oauth?.providers ?? {}).length > 0 }
<div class="inline-flex items-center justify-center w-full">
<hr class="w-64 h-px my-8 bg-gray-200 border-0 dark:bg-gray-700" />
<span
class="absolute px-3 font-medium text-gray-900 -translate-x-1/2 bg-white left-1/2 dark:text-white dark:bg-gray-950"
>{$i18n.t('or')}</span
>
</div>
<div class="flex flex-col space-y-2">
{#if $config?.oauth?.providers?.google }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
<path
fill="#EA4335"
d="M24 9.5c3.54 0 6.71 1.22 9.21 3.6l6.85-6.85C35.9 2.38 30.47 0 24 0 14.62 0 6.51 5.38 2.56 13.22l7.98 6.19C12.43 13.72 17.74 9.5 24 9.5z"
/><path
fill="#4285F4"
d="M46.98 24.55c0-1.57-.15-3.09-.38-4.55H24v9.02h12.94c-.58 2.96-2.26 5.48-4.78 7.18l7.73 6c4.51-4.18 7.09-10.36 7.09-17.65z"
/><path
fill="#FBBC05"
d="M10.53 28.59c-.48-1.45-.76-2.99-.76-4.59s.27-3.14.76-4.59l-7.98-6.19C.92 16.46 0 20.12 0 24c0 3.88.92 7.54 2.56 10.78l7.97-6.19z"
/><path
fill="#34A853"
d="M24 48c6.48 0 11.93-2.13 15.89-5.81l-7.73-6c-2.15 1.45-4.92 2.3-8.16 2.3-6.26 0-11.57-4.22-13.47-9.91l-7.98 6.19C6.51 42.62 14.62 48 24 48z"
/><path fill="none" d="M0 0h48v48H0z" />
</svg>
<span>{$i18n.t('Continue with {{provider}}', { provider: 'Google' })}</span>
</button>
{/if}
{#if $config?.oauth?.providers?.microsoft }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
<rect x="1" y="1" width="9" height="9" fill="#f25022" /><rect
x="1"
y="11"
width="9"
height="9"
fill="#00a4ef"
/><rect x="11" y="1" width="9" height="9" fill="#7fba00" /><rect
x="11"
y="11"
width="9"
height="9"
fill="#ffb900"
/>
</svg>
<span>{$i18n.t('Continue with {{provider}}', { provider: 'Microsoft' })}</span>
</button>
{/if}
{#if $config?.oauth?.providers?.oidc }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="size-6 mr-3"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M15.75 5.25a3 3 0 0 1 3 3m3 0a6 6 0 0 1-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1 1 21.75 8.25Z"
/>
</svg>
<span
>{$i18n.t('Continue with {{provider}}', {
provider: $config?.oauth?.providers?.oidc ?? 'SSO'
})}</span
>
</button>
{/if}
</div>
{/if}
</div> </div>
{/if} {/if}
</div> </div>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment