Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
985fdca5
Commit
985fdca5
authored
May 27, 2024
by
Jun Siang Cheah
Browse files
refac: move things around, uplift oauth endpoints
parent
06dbf597
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
142 additions
and
100 deletions
+142
-100
backend/apps/webui/main.py
backend/apps/webui/main.py
+0
-6
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+21
-2
backend/apps/webui/routers/auths.py
backend/apps/webui/routers/auths.py
+0
-85
backend/main.py
backend/main.py
+118
-4
src/routes/auth/+page.svelte
src/routes/auth/+page.svelte
+3
-3
No files found.
backend/apps/webui/main.py
View file @
985fdca5
...
...
@@ -58,12 +58,6 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
# SessionMiddleware is used by authlib for oauth
if
len
(
OAUTH_PROVIDERS
)
>
0
:
app
.
add_middleware
(
SessionMiddleware
,
secret_key
=
WEBUI_SECRET_KEY
,
session_cookie
=
"oui-session"
)
app
.
include_router
(
auths
.
router
,
prefix
=
"/auths"
,
tags
=
[
"auths"
])
app
.
include_router
(
users
.
router
,
prefix
=
"/users"
,
tags
=
[
"users"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
...
...
backend/apps/webui/models/users.py
View file @
985fdca5
...
...
@@ -112,9 +112,16 @@ class UsersTable:
except
:
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_email
(
self
,
email
:
str
,
oauth_user
:
bool
=
False
)
->
Optional
[
UserModel
]:
try
:
user
=
User
.
get
((
User
.
email
==
email
,
User
.
oauth_sub
.
is_null
()))
conditions
=
(
(
User
.
email
==
email
,
User
.
oauth_sub
.
is_null
())
if
not
oauth_user
else
(
User
.
email
==
email
)
)
user
=
User
.
get
(
conditions
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
return
None
...
...
@@ -177,6 +184,18 @@ class UsersTable:
except
:
return
None
def
update_user_oauth_sub_by_id
(
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
oauth_sub
=
oauth_sub
).
where
(
User
.
id
==
id
)
query
.
execute
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
**
updated
).
where
(
User
.
id
==
id
)
...
...
backend/apps/webui/routers/auths.py
View file @
985fdca5
import
logging
from
authlib.integrations.starlette_client
import
OAuth
from
authlib.oidc.core
import
UserInfo
from
fastapi
import
Request
,
UploadFile
,
File
from
fastapi
import
Depends
,
HTTPException
,
status
...
...
@@ -11,8 +9,6 @@ import re
import
uuid
import
csv
from
starlette.responses
import
RedirectResponse
from
apps.webui.models.auths
import
(
SigninForm
,
SignupForm
,
...
...
@@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from
config
import
(
WEBUI_AUTH
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
OAUTH_PROVIDERS
,
ENABLE_OAUTH_SIGNUP
,
)
router
=
APIRouter
()
...
...
@@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
}
else
:
raise
HTTPException
(
404
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
############################
# OAuth Login & Callback
############################
oauth
=
OAuth
()
for
provider_name
,
provider_config
in
OAUTH_PROVIDERS
.
items
():
oauth
.
register
(
name
=
provider_name
,
client_id
=
provider_config
[
"client_id"
],
client_secret
=
provider_config
[
"client_secret"
],
server_metadata_url
=
provider_config
[
"server_metadata_url"
],
client_kwargs
=
{
"scope"
:
provider_config
[
"scope"
],
},
)
@
router
.
get
(
"/oauth/{provider}/login"
)
async
def
oauth_login
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
redirect_uri
=
request
.
url_for
(
"oauth_callback"
,
provider
=
provider
)
return
await
oauth
.
create_client
(
provider
).
authorize_redirect
(
request
,
redirect_uri
)
@
router
.
get
(
"/oauth/{provider}/callback"
)
async
def
oauth_callback
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
client
=
oauth
.
create_client
(
provider
)
token
=
await
client
.
authorize_access_token
(
request
)
user_data
:
UserInfo
=
token
[
"userinfo"
]
sub
=
user_data
.
get
(
"sub"
)
if
not
sub
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
provider_sub
=
f
"
{
provider
}
@
{
sub
}
"
# Check if the user exists
user
=
Users
.
get_user_by_oauth_sub
(
provider_sub
)
if
not
user
:
# If the user does not exist, create a new user if signup is enabled
if
ENABLE_OAUTH_SIGNUP
.
value
:
user
=
Auths
.
insert_new_auth
(
email
=
user_data
.
get
(
"email"
,
""
).
lower
(),
password
=
get_password_hash
(
str
(
uuid
.
uuid4
())
),
# Random password, not used
name
=
user_data
.
get
(
"name"
,
"User"
),
profile_image_url
=
user_data
.
get
(
"picture"
,
"/user.png"
),
role
=
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
,
oauth_sub
=
provider_sub
,
)
if
request
.
app
.
state
.
config
.
WEBHOOK_URL
:
post_webhook
(
request
.
app
.
state
.
config
.
WEBHOOK_URL
,
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
{
"action"
:
"signup"
,
"message"
:
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
"user"
:
user
.
model_dump_json
(
exclude_none
=
True
),
},
)
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
jwt_token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
# Redirect back to the frontend with the JWT token
redirect_url
=
f
"
{
request
.
base_url
}
auth#token=
{
jwt_token
}
"
return
RedirectResponse
(
url
=
redirect_url
)
backend/main.py
View file @
985fdca5
import
uuid
from
contextlib
import
asynccontextmanager
from
authlib.integrations.starlette_client
import
OAuth
from
authlib.oidc.core
import
UserInfo
from
bs4
import
BeautifulSoup
import
json
import
markdown
...
...
@@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
from
starlette.middleware.sessions
import
SessionMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
apps.ollama.main
import
app
as
ollama_app
,
get_all_models
as
get_ollama_models
from
apps.openai.main
import
app
as
openai_app
,
get_all_models
as
get_openai_models
...
...
@@ -31,8 +36,16 @@ import asyncio
from
pydantic
import
BaseModel
from
typing
import
List
,
Optional
from
apps.webui.models.models
import
Models
,
ModelModel
from
utils.utils
import
get_admin_user
,
get_verified_user
from
apps.webui.models.auths
import
Auths
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
utils.misc
import
parse_duration
from
utils.utils
import
(
get_admin_user
,
get_verified_user
,
get_password_hash
,
create_token
,
)
from
apps.rag.utils
import
rag_messages
from
config
import
(
...
...
@@ -56,8 +69,12 @@ from config import (
ENABLE_ADMIN_EXPORT
,
AppConfig
,
OAUTH_PROVIDERS
,
ENABLE_OAUTH_SIGNUP
,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL
,
WEBUI_SECRET_KEY
,
)
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
from
utils.webhook
import
post_webhook
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
GLOBAL_LOG_LEVEL
)
log
=
logging
.
getLogger
(
__name__
)
...
...
@@ -453,6 +470,103 @@ async def get_app_latest_release_version():
)
############################
# OAuth Login & Callback
############################
oauth
=
OAuth
()
for
provider_name
,
provider_config
in
OAUTH_PROVIDERS
.
items
():
oauth
.
register
(
name
=
provider_name
,
client_id
=
provider_config
[
"client_id"
],
client_secret
=
provider_config
[
"client_secret"
],
server_metadata_url
=
provider_config
[
"server_metadata_url"
],
client_kwargs
=
{
"scope"
:
provider_config
[
"scope"
],
},
)
# SessionMiddleware is used by authlib for oauth
if
len
(
OAUTH_PROVIDERS
)
>
0
:
app
.
add_middleware
(
SessionMiddleware
,
secret_key
=
WEBUI_SECRET_KEY
,
session_cookie
=
"oui-session"
)
@
app
.
get
(
"/oauth/{provider}/login"
)
async
def
oauth_login
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
redirect_uri
=
request
.
url_for
(
"oauth_callback"
,
provider
=
provider
)
return
await
oauth
.
create_client
(
provider
).
authorize_redirect
(
request
,
redirect_uri
)
@
app
.
get
(
"/oauth/{provider}/callback"
)
async
def
oauth_callback
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
client
=
oauth
.
create_client
(
provider
)
token
=
await
client
.
authorize_access_token
(
request
)
user_data
:
UserInfo
=
token
[
"userinfo"
]
sub
=
user_data
.
get
(
"sub"
)
if
not
sub
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
provider_sub
=
f
"
{
provider
}
@
{
sub
}
"
# Check if the user exists
user
=
Users
.
get_user_by_oauth_sub
(
provider_sub
)
if
not
user
:
# If the user does not exist, check if merging is enabled
if
OAUTH_MERGE_ACCOUNTS_BY_EMAIL
:
# Check if the user exists by email
email
=
user_data
.
get
(
"email"
,
""
).
lower
()
if
not
email
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
user
=
Users
.
get_user_by_email
(
user_data
.
get
(
"email"
,
""
).
lower
(),
True
)
if
user
:
# Update the user with the new oauth sub
Users
.
update_user_oauth_sub_by_id
(
user
.
id
,
provider_sub
)
if
not
user
:
# If the user does not exist, check if signups are enabled
if
ENABLE_OAUTH_SIGNUP
.
value
:
user
=
Auths
.
insert_new_auth
(
email
=
user_data
.
get
(
"email"
,
""
).
lower
(),
password
=
get_password_hash
(
str
(
uuid
.
uuid4
())
),
# Random password, not used
name
=
user_data
.
get
(
"name"
,
"User"
),
profile_image_url
=
user_data
.
get
(
"picture"
,
"/user.png"
),
role
=
webui_app
.
state
.
config
.
DEFAULT_USER_ROLE
,
oauth_sub
=
provider_sub
,
)
if
webui_app
.
state
.
config
.
WEBHOOK_URL
:
post_webhook
(
webui_app
.
state
.
config
.
WEBHOOK_URL
,
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
{
"action"
:
"signup"
,
"message"
:
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
"user"
:
user
.
model_dump_json
(
exclude_none
=
True
),
},
)
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
jwt_token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
webui_app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
# Redirect back to the frontend with the JWT token
redirect_url
=
f
"
{
request
.
base_url
}
auth#token=
{
jwt_token
}
"
return
RedirectResponse
(
url
=
redirect_url
)
@
app
.
get
(
"/manifest.json"
)
async
def
get_manifest_json
():
return
{
...
...
src/routes/auth/+page.svelte
View file @
985fdca5
...
...
@@ -259,7 +259,7 @@
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_
API_
BASE_URL}/
auths/
oauth/google/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/google/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
...
...
@@ -284,7 +284,7 @@
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_
API_
BASE_URL}/
auths/
oauth/microsoft/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/microsoft/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
...
...
@@ -309,7 +309,7 @@
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_
API_
BASE_URL}/
auths/
oauth/oidc/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/oidc/login`;
}}
>
<svg
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment