Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
0210a105
Commit
0210a105
authored
May 26, 2024
by
Jun Siang Cheah
Browse files
feat: experimental SSO support for Google, Microsoft, and OIDC
parent
a842d8d6
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
351 additions
and
6 deletions
+351
-6
backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py
.../apps/webui/internal/migrations/011_add_user_oauth_sub.py
+49
-0
backend/apps/webui/main.py
backend/apps/webui/main.py
+10
-0
backend/apps/webui/models/auths.py
backend/apps/webui/models/auths.py
+4
-1
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+13
-0
backend/apps/webui/routers/auths.py
backend/apps/webui/routers/auths.py
+88
-1
backend/config.py
backend/config.py
+46
-0
backend/main.py
backend/main.py
+8
-0
src/lib/stores/index.ts
src/lib/stores/index.ts
+6
-1
src/routes/+layout.svelte
src/routes/+layout.svelte
+6
-1
src/routes/auth/+page.svelte
src/routes/auth/+page.svelte
+121
-2
No files found.
backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py
0 → 100644
View file @
0210a105
"""Peewee migrations -- 011_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from
contextlib
import
suppress
import
peewee
as
pw
from
peewee_migrate
import
Migrator
with
suppress
(
ImportError
):
import
playhouse.postgres_ext
as
pw_pext
def
migrate
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your migrations here."""
migrator
.
add_fields
(
"user"
,
oauth_sub
=
pw
.
TextField
(
null
=
True
,
unique
=
True
),
)
def
rollback
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your rollback migrations here."""
migrator
.
remove_fields
(
"user"
,
"oauth_sub"
)
backend/apps/webui/main.py
View file @
0210a105
from
fastapi
import
FastAPI
,
Depends
from
fastapi.routing
import
APIRoute
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.middleware.sessions
import
SessionMiddleware
from
apps.webui.routers
import
(
auths
,
users
,
...
...
@@ -24,6 +26,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
JWT_EXPIRES_IN
,
AppConfig
,
WEBUI_SECRET_KEY
,
OAUTH_PROVIDERS
,
)
app
=
FastAPI
()
...
...
@@ -54,6 +58,12 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
# SessionMiddleware is used by authlib for oauth
if
len
(
OAUTH_PROVIDERS
)
>
0
:
app
.
add_middleware
(
SessionMiddleware
,
secret_key
=
WEBUI_SECRET_KEY
,
session_cookie
=
"oui-session"
)
app
.
include_router
(
auths
.
router
,
prefix
=
"/auths"
,
tags
=
[
"auths"
])
app
.
include_router
(
users
.
router
,
prefix
=
"/users"
,
tags
=
[
"users"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
...
...
backend/apps/webui/models/auths.py
View file @
0210a105
...
...
@@ -105,6 +105,7 @@ class AuthsTable:
name
:
str
,
profile_image_url
:
str
=
"/user.png"
,
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
log
.
info
(
"insert_new_auth"
)
...
...
@@ -115,7 +116,9 @@ class AuthsTable:
)
result
=
Auth
.
create
(
**
auth
.
model_dump
())
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
profile_image_url
,
role
)
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
profile_image_url
,
role
,
oauth_sub
)
if
result
and
user
:
return
user
...
...
backend/apps/webui/models/users.py
View file @
0210a105
...
...
@@ -26,6 +26,8 @@ class User(Model):
api_key
=
CharField
(
null
=
True
,
unique
=
True
)
oauth_sub
=
TextField
(
null
=
True
,
unique
=
True
)
class
Meta
:
database
=
DB
...
...
@@ -43,6 +45,8 @@ class UserModel(BaseModel):
api_key
:
Optional
[
str
]
=
None
oauth_sub
:
Optional
[
str
]
=
None
####################
# Forms
...
...
@@ -73,6 +77,7 @@ class UsersTable:
email
:
str
,
profile_image_url
:
str
=
"/user.png"
,
role
:
str
=
"pending"
,
oauth_sub
:
Optional
[
str
]
=
None
,
)
->
Optional
[
UserModel
]:
user
=
UserModel
(
**
{
...
...
@@ -84,6 +89,7 @@ class UsersTable:
"last_active_at"
:
int
(
time
.
time
()),
"created_at"
:
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
"oauth_sub"
:
oauth_sub
,
}
)
result
=
User
.
create
(
**
user
.
model_dump
())
...
...
@@ -113,6 +119,13 @@ class UsersTable:
except
:
return
None
def
get_user_by_oauth_sub
(
self
,
sub
:
str
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
(
User
.
oauth_sub
==
sub
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
return
None
def
get_users
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
UserModel
]:
return
[
UserModel
(
**
model_to_dict
(
user
))
...
...
backend/apps/webui/routers/auths.py
View file @
0210a105
import
logging
from
authlib.integrations.starlette_client
import
OAuth
from
authlib.oidc.core
import
UserInfo
from
fastapi
import
Request
,
UploadFile
,
File
from
fastapi
import
Depends
,
HTTPException
,
status
...
...
@@ -9,6 +11,7 @@ import re
import
uuid
import
csv
from
starlette.responses
import
RedirectResponse
from
apps.webui.models.auths
import
(
SigninForm
,
...
...
@@ -33,7 +36,12 @@ from utils.utils import (
from
utils.misc
import
parse_duration
,
validate_email_format
from
utils.webhook
import
post_webhook
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
from
config
import
WEBUI_AUTH
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
from
config
import
(
WEBUI_AUTH
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
OAUTH_PROVIDERS
,
ENABLE_OAUTH_SIGNUP
,
)
router
=
APIRouter
()
...
...
@@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)):
}
else
:
raise
HTTPException
(
404
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
############################
# OAuth Login & Callback
############################
oauth
=
OAuth
()
for
provider_name
,
provider_config
in
OAUTH_PROVIDERS
.
items
():
oauth
.
register
(
name
=
provider_name
,
client_id
=
provider_config
[
"client_id"
],
client_secret
=
provider_config
[
"client_secret"
],
server_metadata_url
=
provider_config
[
"server_metadata_url"
],
client_kwargs
=
{
"scope"
:
provider_config
[
"scope"
],
},
)
@
router
.
get
(
"/oauth/{provider}/login"
)
async
def
oauth_login
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
redirect_uri
=
request
.
url_for
(
"oauth_callback"
,
provider
=
provider
)
return
await
oauth
.
create_client
(
provider
).
authorize_redirect
(
request
,
redirect_uri
)
@
router
.
get
(
"/oauth/{provider}/callback"
)
async
def
oauth_callback
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
client
=
oauth
.
create_client
(
provider
)
token
=
await
client
.
authorize_access_token
(
request
)
user_data
:
UserInfo
=
token
[
"userinfo"
]
sub
=
user_data
.
get
(
"sub"
)
if
not
sub
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
provider_sub
=
f
"
{
provider
}
@
{
sub
}
"
# Check if the user exists
user
=
Users
.
get_user_by_oauth_sub
(
provider_sub
)
if
not
user
:
# If the user does not exist, create a new user if signup is enabled
if
ENABLE_OAUTH_SIGNUP
.
value
:
user
=
Auths
.
insert_new_auth
(
email
=
user_data
.
get
(
"email"
,
""
).
lower
(),
password
=
get_password_hash
(
str
(
uuid
.
uuid4
())
),
# Random password, not used
name
=
user_data
.
get
(
"name"
,
"User"
),
profile_image_url
=
user_data
.
get
(
"picture"
,
"/user.png"
),
role
=
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
,
oauth_sub
=
provider_sub
,
)
if
request
.
app
.
state
.
config
.
WEBHOOK_URL
:
post_webhook
(
request
.
app
.
state
.
config
.
WEBHOOK_URL
,
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
{
"action"
:
"signup"
,
"message"
:
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
"user"
:
user
.
model_dump_json
(
exclude_none
=
True
),
},
)
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
jwt_token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
# Redirect back to the frontend with the JWT token
redirect_url
=
f
"
{
request
.
base_url
}
auth#token=
{
jwt_token
}
"
return
RedirectResponse
(
url
=
redirect_url
)
backend/config.py
View file @
0210a105
...
...
@@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN"
,
"auth.jwt_expiry"
,
os
.
environ
.
get
(
"JWT_EXPIRES_IN"
,
"-1"
)
)
####################################
# OAuth config
####################################
ENABLE_OAUTH_SIGNUP
=
PersistentConfig
(
"ENABLE_OAUTH_SIGNUP"
,
"oauth.enable_signup"
,
os
.
environ
.
get
(
"ENABLE_OAUTH_SIGNUP"
,
"False"
).
lower
()
==
"true"
,
)
OAUTH_PROVIDERS
=
{}
if
os
.
environ
.
get
(
"GOOGLE_CLIENT_ID"
)
and
os
.
environ
.
get
(
"GOOGLE_CLIENT_SECRET"
):
OAUTH_PROVIDERS
[
"google"
]
=
{
"client_id"
:
os
.
environ
.
get
(
"GOOGLE_CLIENT_ID"
),
"client_secret"
:
os
.
environ
.
get
(
"GOOGLE_CLIENT_SECRET"
),
"server_metadata_url"
:
"https://accounts.google.com/.well-known/openid-configuration"
,
"scope"
:
os
.
environ
.
get
(
"GOOGLE_OAUTH_SCOPE"
,
"openid email profile"
),
}
if
(
os
.
environ
.
get
(
"MICROSOFT_CLIENT_ID"
)
and
os
.
environ
.
get
(
"MICROSOFT_CLIENT_SECRET"
)
and
os
.
environ
.
get
(
"MICROSOFT_CLIENT_TENANT_ID"
)
):
OAUTH_PROVIDERS
[
"microsoft"
]
=
{
"client_id"
:
os
.
environ
.
get
(
"MICROSOFT_CLIENT_ID"
),
"client_secret"
:
os
.
environ
.
get
(
"MICROSOFT_CLIENT_SECRET"
),
"server_metadata_url"
:
f
"https://login.microsoftonline.com/
{
os
.
environ
.
get
(
'MICROSOFT_CLIENT_TENANT_ID'
)
}
/v2.0/.well-known/openid-configuration"
,
"scope"
:
os
.
environ
.
get
(
"MICROSOFT_OAUTH_SCOPE"
,
"openid email profile"
),
}
if
(
os
.
environ
.
get
(
"OPENID_CLIENT_ID"
)
and
os
.
environ
.
get
(
"OPENID_CLIENT_SECRET"
)
and
os
.
environ
.
get
(
"OPENID_PROVIDER_URL"
)
):
OAUTH_PROVIDERS
[
"oidc"
]
=
{
"client_id"
:
os
.
environ
.
get
(
"OPENID_CLIENT_ID"
),
"client_secret"
:
os
.
environ
.
get
(
"OPENID_CLIENT_SECRET"
),
"server_metadata_url"
:
os
.
environ
.
get
(
"OPENID_PROVIDER_URL"
),
"scope"
:
os
.
environ
.
get
(
"OPENID_SCOPE"
,
"openid email profile"
),
"name"
:
os
.
environ
.
get
(
"OPENID_PROVIDER_NAME"
,
"SSO"
),
}
####################################
# Static DIR
####################################
...
...
backend/main.py
View file @
0210a105
...
...
@@ -55,6 +55,7 @@ from config import (
WEBHOOK_URL
,
ENABLE_ADMIN_EXPORT
,
AppConfig
,
OAUTH_PROVIDERS
,
)
from
constants
import
ERROR_MESSAGES
...
...
@@ -364,6 +365,13 @@ async def get_app_config():
"default_locale"
:
default_locale
,
"default_models"
:
webui_app
.
state
.
config
.
DEFAULT_MODELS
,
"default_prompt_suggestions"
:
webui_app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
,
"trusted_header_auth"
:
bool
(
webui_app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
),
"oauth"
:
{
"providers"
:
{
name
:
config
.
get
(
"name"
,
name
)
for
name
,
config
in
OAUTH_PROVIDERS
.
items
()
}
},
}
...
...
src/lib/stores/index.ts
View file @
0210a105
...
...
@@ -134,7 +134,12 @@ type Config = {
default_models
?:
string
[];
default_prompt_suggestions
?:
PromptSuggestion
[];
auth_trusted_header
?:
boolean
;
model_config
?:
GlobalModelConfig
;
auth
:
boolean
;
oauth
:
{
providers
:
{
[
key
:
string
]:
string
;
};
};
};
type
PromptSuggestion
=
{
...
...
src/routes/+layout.svelte
View file @
0210a105
...
...
@@ -2,6 +2,7 @@
import { onMount, tick, setContext } from 'svelte';
import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores';
import { goto } from '$app/navigation';
import { page } from '$app/stores';
import { Toaster, toast } from 'svelte-sonner';
import { getBackendConfig } from '$lib/apis';
...
...
@@ -75,9 +76,13 @@
await goto('/auth');
}
} else {
// Don't redirect if we're already on the auth page
// Needed because we pass in tokens from OAuth logins via URL fragments
if ($page.url.pathname !== '/auth') {
await goto('/auth');
}
}
}
} else {
// Redirect to /error when Backend Not Detected
await goto(`/error`);
...
...
src/routes/auth/+page.svelte
View file @
0210a105
<script>
import { goto } from '$app/navigation';
import { userSignIn, userSignUp } from '$lib/apis/auths';
import {
getSessionUser,
userSignIn, userSignUp } from '$lib/apis/auths';
import Spinner from '$lib/components/common/Spinner.svelte';
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { WEBUI_NAME, config, user } from '$lib/stores';
import { onMount, getContext } from 'svelte';
import { toast } from 'svelte-sonner';
import { generateInitialsImage, canvasPixelTest } from '$lib/utils';
import { page } from '$app/stores';
const i18n = getContext('i18n');
...
...
@@ -21,7 +22,9 @@
if (sessionUser) {
console.log(sessionUser);
toast.success($i18n.t(`You're now logged in.`));
if (sessionUser.token) {
localStorage.token = sessionUser.token;
}
await user.set(sessionUser);
goto('/');
}
...
...
@@ -55,10 +58,35 @@
}
};
const checkOauthCallback = async () => {
if (!$page.url.hash) {
return;
}
const hash = $page.url.hash.substring(1);
if (!hash) {
return;
}
const params = new URLSearchParams(hash);
const token = params.get('token');
if (!token) {
return;
}
const sessionUser = await getSessionUser(token).catch((error) => {
toast.error(error);
return null;
});
if (!sessionUser) {
return;
}
localStorage.token = token;
await setSessionUser(sessionUser);
};
onMount(async () => {
if ($user !== undefined) {
await goto('/');
}
await checkOauthCallback();
loaded = true;
if (($config?.auth_trusted_header ?? false) || $config?.auth === false) {
await signInHandler();
...
...
@@ -217,6 +245,97 @@
{/if}
</div>
</form>
{#if Object.keys($config?.oauth?.providers ?? {}).length > 0 }
<div class="inline-flex items-center justify-center w-full">
<hr class="w-64 h-px my-8 bg-gray-200 border-0 dark:bg-gray-700" />
<span
class="absolute px-3 font-medium text-gray-900 -translate-x-1/2 bg-white left-1/2 dark:text-white dark:bg-gray-950"
>{$i18n.t('or')}</span
>
</div>
<div class="flex flex-col space-y-2">
{#if $config?.oauth?.providers?.google }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
<path
fill="#EA4335"
d="M24 9.5c3.54 0 6.71 1.22 9.21 3.6l6.85-6.85C35.9 2.38 30.47 0 24 0 14.62 0 6.51 5.38 2.56 13.22l7.98 6.19C12.43 13.72 17.74 9.5 24 9.5z"
/><path
fill="#4285F4"
d="M46.98 24.55c0-1.57-.15-3.09-.38-4.55H24v9.02h12.94c-.58 2.96-2.26 5.48-4.78 7.18l7.73 6c4.51-4.18 7.09-10.36 7.09-17.65z"
/><path
fill="#FBBC05"
d="M10.53 28.59c-.48-1.45-.76-2.99-.76-4.59s.27-3.14.76-4.59l-7.98-6.19C.92 16.46 0 20.12 0 24c0 3.88.92 7.54 2.56 10.78l7.97-6.19z"
/><path
fill="#34A853"
d="M24 48c6.48 0 11.93-2.13 15.89-5.81l-7.73-6c-2.15 1.45-4.92 2.3-8.16 2.3-6.26 0-11.57-4.22-13.47-9.91l-7.98 6.19C6.51 42.62 14.62 48 24 48z"
/><path fill="none" d="M0 0h48v48H0z" />
</svg>
<span>{$i18n.t('Continue with {{provider}}', { provider: 'Google' })}</span>
</button>
{/if}
{#if $config?.oauth?.providers?.microsoft }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
<rect x="1" y="1" width="9" height="9" fill="#f25022" /><rect
x="1"
y="11"
width="9"
height="9"
fill="#00a4ef"
/><rect x="11" y="1" width="9" height="9" fill="#7fba00" /><rect
x="11"
y="11"
width="9"
height="9"
fill="#ffb900"
/>
</svg>
<span>{$i18n.t('Continue with {{provider}}', { provider: 'Microsoft' })}</span>
</button>
{/if}
{#if $config?.oauth?.providers?.oidc }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="size-6 mr-3"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M15.75 5.25a3 3 0 0 1 3 3m3 0a6 6 0 0 1-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1 1 21.75 8.25Z"
/>
</svg>
<span
>{$i18n.t('Continue with {{provider}}', {
provider: $config?.oauth?.providers?.oidc ?? 'SSO'
})}</span
>
</button>
{/if}
</div>
{/if}
</div>
{/if}
</div>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment