Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
985fdca5
Commit
985fdca5
authored
May 27, 2024
by
Jun Siang Cheah
Browse files
refac: move things around, uplift oauth endpoints
parent
06dbf597
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
142 additions
and
100 deletions
+142
-100
backend/apps/webui/main.py
backend/apps/webui/main.py
+0
-6
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+21
-2
backend/apps/webui/routers/auths.py
backend/apps/webui/routers/auths.py
+0
-85
backend/main.py
backend/main.py
+118
-4
src/routes/auth/+page.svelte
src/routes/auth/+page.svelte
+3
-3
No files found.
backend/apps/webui/main.py
View file @
985fdca5
...
@@ -58,12 +58,6 @@ app.add_middleware(
...
@@ -58,12 +58,6 @@ app.add_middleware(
allow_headers
=
[
"*"
],
allow_headers
=
[
"*"
],
)
)
# SessionMiddleware is used by authlib for oauth
if
len
(
OAUTH_PROVIDERS
)
>
0
:
app
.
add_middleware
(
SessionMiddleware
,
secret_key
=
WEBUI_SECRET_KEY
,
session_cookie
=
"oui-session"
)
app
.
include_router
(
auths
.
router
,
prefix
=
"/auths"
,
tags
=
[
"auths"
])
app
.
include_router
(
auths
.
router
,
prefix
=
"/auths"
,
tags
=
[
"auths"
])
app
.
include_router
(
users
.
router
,
prefix
=
"/users"
,
tags
=
[
"users"
])
app
.
include_router
(
users
.
router
,
prefix
=
"/users"
,
tags
=
[
"users"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
...
...
backend/apps/webui/models/users.py
View file @
985fdca5
...
@@ -112,9 +112,16 @@ class UsersTable:
...
@@ -112,9 +112,16 @@ class UsersTable:
except
:
except
:
return
None
return
None
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
def
get_user_by_email
(
self
,
email
:
str
,
oauth_user
:
bool
=
False
)
->
Optional
[
UserModel
]:
try
:
try
:
user
=
User
.
get
((
User
.
email
==
email
,
User
.
oauth_sub
.
is_null
()))
conditions
=
(
(
User
.
email
==
email
,
User
.
oauth_sub
.
is_null
())
if
not
oauth_user
else
(
User
.
email
==
email
)
)
user
=
User
.
get
(
conditions
)
return
UserModel
(
**
model_to_dict
(
user
))
return
UserModel
(
**
model_to_dict
(
user
))
except
:
except
:
return
None
return
None
...
@@ -177,6 +184,18 @@ class UsersTable:
...
@@ -177,6 +184,18 @@ class UsersTable:
except
:
except
:
return
None
return
None
def
update_user_oauth_sub_by_id
(
self
,
id
:
str
,
oauth_sub
:
str
)
->
Optional
[
UserModel
]:
try
:
query
=
User
.
update
(
oauth_sub
=
oauth_sub
).
where
(
User
.
id
==
id
)
query
.
execute
()
user
=
User
.
get
(
User
.
id
==
id
)
return
UserModel
(
**
model_to_dict
(
user
))
except
:
return
None
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
def
update_user_by_id
(
self
,
id
:
str
,
updated
:
dict
)
->
Optional
[
UserModel
]:
try
:
try
:
query
=
User
.
update
(
**
updated
).
where
(
User
.
id
==
id
)
query
=
User
.
update
(
**
updated
).
where
(
User
.
id
==
id
)
...
...
backend/apps/webui/routers/auths.py
View file @
985fdca5
import
logging
import
logging
from
authlib.integrations.starlette_client
import
OAuth
from
authlib.oidc.core
import
UserInfo
from
fastapi
import
Request
,
UploadFile
,
File
from
fastapi
import
Request
,
UploadFile
,
File
from
fastapi
import
Depends
,
HTTPException
,
status
from
fastapi
import
Depends
,
HTTPException
,
status
...
@@ -11,8 +9,6 @@ import re
...
@@ -11,8 +9,6 @@ import re
import
uuid
import
uuid
import
csv
import
csv
from
starlette.responses
import
RedirectResponse
from
apps.webui.models.auths
import
(
from
apps.webui.models.auths
import
(
SigninForm
,
SigninForm
,
SignupForm
,
SignupForm
,
...
@@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
...
@@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from
config
import
(
from
config
import
(
WEBUI_AUTH
,
WEBUI_AUTH
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
OAUTH_PROVIDERS
,
ENABLE_OAUTH_SIGNUP
,
)
)
router
=
APIRouter
()
router
=
APIRouter
()
...
@@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
...
@@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
}
}
else
:
else
:
raise
HTTPException
(
404
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
raise
HTTPException
(
404
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
############################
# OAuth Login & Callback
############################
oauth
=
OAuth
()
for
provider_name
,
provider_config
in
OAUTH_PROVIDERS
.
items
():
oauth
.
register
(
name
=
provider_name
,
client_id
=
provider_config
[
"client_id"
],
client_secret
=
provider_config
[
"client_secret"
],
server_metadata_url
=
provider_config
[
"server_metadata_url"
],
client_kwargs
=
{
"scope"
:
provider_config
[
"scope"
],
},
)
@
router
.
get
(
"/oauth/{provider}/login"
)
async
def
oauth_login
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
redirect_uri
=
request
.
url_for
(
"oauth_callback"
,
provider
=
provider
)
return
await
oauth
.
create_client
(
provider
).
authorize_redirect
(
request
,
redirect_uri
)
@
router
.
get
(
"/oauth/{provider}/callback"
)
async
def
oauth_callback
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
client
=
oauth
.
create_client
(
provider
)
token
=
await
client
.
authorize_access_token
(
request
)
user_data
:
UserInfo
=
token
[
"userinfo"
]
sub
=
user_data
.
get
(
"sub"
)
if
not
sub
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
provider_sub
=
f
"
{
provider
}
@
{
sub
}
"
# Check if the user exists
user
=
Users
.
get_user_by_oauth_sub
(
provider_sub
)
if
not
user
:
# If the user does not exist, create a new user if signup is enabled
if
ENABLE_OAUTH_SIGNUP
.
value
:
user
=
Auths
.
insert_new_auth
(
email
=
user_data
.
get
(
"email"
,
""
).
lower
(),
password
=
get_password_hash
(
str
(
uuid
.
uuid4
())
),
# Random password, not used
name
=
user_data
.
get
(
"name"
,
"User"
),
profile_image_url
=
user_data
.
get
(
"picture"
,
"/user.png"
),
role
=
request
.
app
.
state
.
config
.
DEFAULT_USER_ROLE
,
oauth_sub
=
provider_sub
,
)
if
request
.
app
.
state
.
config
.
WEBHOOK_URL
:
post_webhook
(
request
.
app
.
state
.
config
.
WEBHOOK_URL
,
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
{
"action"
:
"signup"
,
"message"
:
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
"user"
:
user
.
model_dump_json
(
exclude_none
=
True
),
},
)
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
jwt_token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
request
.
app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
# Redirect back to the frontend with the JWT token
redirect_url
=
f
"
{
request
.
base_url
}
auth#token=
{
jwt_token
}
"
return
RedirectResponse
(
url
=
redirect_url
)
backend/main.py
View file @
985fdca5
import
uuid
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
authlib.integrations.starlette_client
import
OAuth
from
authlib.oidc.core
import
UserInfo
from
bs4
import
BeautifulSoup
from
bs4
import
BeautifulSoup
import
json
import
json
import
markdown
import
markdown
...
@@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
...
@@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.exceptions
import
HTTPException
as
StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
from
starlette.middleware.sessions
import
SessionMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
,
RedirectResponse
from
apps.ollama.main
import
app
as
ollama_app
,
get_all_models
as
get_ollama_models
from
apps.ollama.main
import
app
as
ollama_app
,
get_all_models
as
get_ollama_models
from
apps.openai.main
import
app
as
openai_app
,
get_all_models
as
get_openai_models
from
apps.openai.main
import
app
as
openai_app
,
get_all_models
as
get_openai_models
...
@@ -31,8 +36,16 @@ import asyncio
...
@@ -31,8 +36,16 @@ import asyncio
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
apps.webui.models.models
import
Models
,
ModelModel
from
apps.webui.models.auths
import
Auths
from
utils.utils
import
get_admin_user
,
get_verified_user
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
utils.misc
import
parse_duration
from
utils.utils
import
(
get_admin_user
,
get_verified_user
,
get_password_hash
,
create_token
,
)
from
apps.rag.utils
import
rag_messages
from
apps.rag.utils
import
rag_messages
from
config
import
(
from
config
import
(
...
@@ -56,8 +69,12 @@ from config import (
...
@@ -56,8 +69,12 @@ from config import (
ENABLE_ADMIN_EXPORT
,
ENABLE_ADMIN_EXPORT
,
AppConfig
,
AppConfig
,
OAUTH_PROVIDERS
,
OAUTH_PROVIDERS
,
ENABLE_OAUTH_SIGNUP
,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL
,
WEBUI_SECRET_KEY
,
)
)
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
,
WEBHOOK_MESSAGES
from
utils.webhook
import
post_webhook
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
GLOBAL_LOG_LEVEL
)
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
GLOBAL_LOG_LEVEL
)
log
=
logging
.
getLogger
(
__name__
)
log
=
logging
.
getLogger
(
__name__
)
...
@@ -453,6 +470,103 @@ async def get_app_latest_release_version():
...
@@ -453,6 +470,103 @@ async def get_app_latest_release_version():
)
)
############################
# OAuth Login & Callback
############################
oauth
=
OAuth
()
for
provider_name
,
provider_config
in
OAUTH_PROVIDERS
.
items
():
oauth
.
register
(
name
=
provider_name
,
client_id
=
provider_config
[
"client_id"
],
client_secret
=
provider_config
[
"client_secret"
],
server_metadata_url
=
provider_config
[
"server_metadata_url"
],
client_kwargs
=
{
"scope"
:
provider_config
[
"scope"
],
},
)
# SessionMiddleware is used by authlib for oauth
if
len
(
OAUTH_PROVIDERS
)
>
0
:
app
.
add_middleware
(
SessionMiddleware
,
secret_key
=
WEBUI_SECRET_KEY
,
session_cookie
=
"oui-session"
)
@
app
.
get
(
"/oauth/{provider}/login"
)
async
def
oauth_login
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
redirect_uri
=
request
.
url_for
(
"oauth_callback"
,
provider
=
provider
)
return
await
oauth
.
create_client
(
provider
).
authorize_redirect
(
request
,
redirect_uri
)
@
app
.
get
(
"/oauth/{provider}/callback"
)
async
def
oauth_callback
(
provider
:
str
,
request
:
Request
):
if
provider
not
in
OAUTH_PROVIDERS
:
raise
HTTPException
(
404
)
client
=
oauth
.
create_client
(
provider
)
token
=
await
client
.
authorize_access_token
(
request
)
user_data
:
UserInfo
=
token
[
"userinfo"
]
sub
=
user_data
.
get
(
"sub"
)
if
not
sub
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
provider_sub
=
f
"
{
provider
}
@
{
sub
}
"
# Check if the user exists
user
=
Users
.
get_user_by_oauth_sub
(
provider_sub
)
if
not
user
:
# If the user does not exist, check if merging is enabled
if
OAUTH_MERGE_ACCOUNTS_BY_EMAIL
:
# Check if the user exists by email
email
=
user_data
.
get
(
"email"
,
""
).
lower
()
if
not
email
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
user
=
Users
.
get_user_by_email
(
user_data
.
get
(
"email"
,
""
).
lower
(),
True
)
if
user
:
# Update the user with the new oauth sub
Users
.
update_user_oauth_sub_by_id
(
user
.
id
,
provider_sub
)
if
not
user
:
# If the user does not exist, check if signups are enabled
if
ENABLE_OAUTH_SIGNUP
.
value
:
user
=
Auths
.
insert_new_auth
(
email
=
user_data
.
get
(
"email"
,
""
).
lower
(),
password
=
get_password_hash
(
str
(
uuid
.
uuid4
())
),
# Random password, not used
name
=
user_data
.
get
(
"name"
,
"User"
),
profile_image_url
=
user_data
.
get
(
"picture"
,
"/user.png"
),
role
=
webui_app
.
state
.
config
.
DEFAULT_USER_ROLE
,
oauth_sub
=
provider_sub
,
)
if
webui_app
.
state
.
config
.
WEBHOOK_URL
:
post_webhook
(
webui_app
.
state
.
config
.
WEBHOOK_URL
,
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
{
"action"
:
"signup"
,
"message"
:
WEBHOOK_MESSAGES
.
USER_SIGNUP
(
user
.
name
),
"user"
:
user
.
model_dump_json
(
exclude_none
=
True
),
},
)
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
jwt_token
=
create_token
(
data
=
{
"id"
:
user
.
id
},
expires_delta
=
parse_duration
(
webui_app
.
state
.
config
.
JWT_EXPIRES_IN
),
)
# Redirect back to the frontend with the JWT token
redirect_url
=
f
"
{
request
.
base_url
}
auth#token=
{
jwt_token
}
"
return
RedirectResponse
(
url
=
redirect_url
)
@
app
.
get
(
"/manifest.json"
)
@
app
.
get
(
"/manifest.json"
)
async
def
get_manifest_json
():
async
def
get_manifest_json
():
return
{
return
{
...
...
src/routes/auth/+page.svelte
View file @
985fdca5
...
@@ -259,7 +259,7 @@
...
@@ -259,7 +259,7 @@
<button
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
on:click={() => {
window.location.href = `${WEBUI_
API_
BASE_URL}/
auths/
oauth/google/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/google/login`;
}}
}}
>
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
...
@@ -284,7 +284,7 @@
...
@@ -284,7 +284,7 @@
<button
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
on:click={() => {
window.location.href = `${WEBUI_
API_
BASE_URL}/
auths/
oauth/microsoft/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/microsoft/login`;
}}
}}
>
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
...
@@ -309,7 +309,7 @@
...
@@ -309,7 +309,7 @@
<button
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
on:click={() => {
window.location.href = `${WEBUI_
API_
BASE_URL}/
auths/
oauth/oidc/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/oidc/login`;
}}
}}
>
>
<svg
<svg
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment