Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
981f3841
Commit
981f3841
authored
Jun 21, 2024
by
Jun Siang Cheah
Browse files
refac: modify oauth login logic for unique email addresses
parent
e011e7b6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
15 deletions
+20
-15
backend/apps/webui/models/users.py
backend/apps/webui/models/users.py
+2
-9
backend/main.py
backend/main.py
+18
-6
No files found.
backend/apps/webui/models/users.py
View file @
981f3841
...
@@ -122,16 +122,9 @@ class UsersTable:
...
@@ -122,16 +122,9 @@ class UsersTable:
except
:
except
:
return
None
return
None
def
get_user_by_email
(
def
get_user_by_email
(
self
,
email
:
str
)
->
Optional
[
UserModel
]:
self
,
email
:
str
,
oauth_user
:
bool
=
False
)
->
Optional
[
UserModel
]:
try
:
try
:
conditions
=
(
user
=
User
.
get
(
User
.
email
==
email
)
(
User
.
email
==
email
,
User
.
oauth_sub
.
is_null
())
if
not
oauth_user
else
(
User
.
email
==
email
)
)
user
=
User
.
get
(
*
conditions
)
return
UserModel
(
**
model_to_dict
(
user
))
return
UserModel
(
**
model_to_dict
(
user
))
except
:
except
:
return
None
return
None
...
...
backend/main.py
View file @
981f3841
...
@@ -1869,6 +1869,12 @@ async def oauth_login(provider: str, request: Request):
...
@@ -1869,6 +1869,12 @@ async def oauth_login(provider: str, request: Request):
return
await
oauth
.
create_client
(
provider
).
authorize_redirect
(
request
,
redirect_uri
)
return
await
oauth
.
create_client
(
provider
).
authorize_redirect
(
request
,
redirect_uri
)
# OAuth login logic is as follows:
# 1. Attempt to find a user with matching subject ID, tied to the provider
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
@
app
.
get
(
"/oauth/{provider}/callback"
)
@
app
.
get
(
"/oauth/{provider}/callback"
)
async
def
oauth_callback
(
provider
:
str
,
request
:
Request
,
response
:
Response
):
async
def
oauth_callback
(
provider
:
str
,
request
:
Request
,
response
:
Response
):
if
provider
not
in
OAUTH_PROVIDERS
:
if
provider
not
in
OAUTH_PROVIDERS
:
...
@@ -1885,6 +1891,10 @@ async def oauth_callback(provider: str, request: Request, response: Response):
...
@@ -1885,6 +1891,10 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if
not
sub
:
if
not
sub
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
provider_sub
=
f
"
{
provider
}
@
{
sub
}
"
provider_sub
=
f
"
{
provider
}
@
{
sub
}
"
email
=
user_data
.
get
(
"email"
,
""
).
lower
()
# We currently mandate that email addresses are provided
if
not
email
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
# Check if the user exists
# Check if the user exists
user
=
Users
.
get_user_by_oauth_sub
(
provider_sub
)
user
=
Users
.
get_user_by_oauth_sub
(
provider_sub
)
...
@@ -1893,10 +1903,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
...
@@ -1893,10 +1903,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
# If the user does not exist, check if merging is enabled
# If the user does not exist, check if merging is enabled
if
OAUTH_MERGE_ACCOUNTS_BY_EMAIL
.
value
:
if
OAUTH_MERGE_ACCOUNTS_BY_EMAIL
.
value
:
# Check if the user exists by email
# Check if the user exists by email
email
=
user_data
.
get
(
"email"
,
""
).
lower
()
user
=
Users
.
get_user_by_email
(
email
)
if
not
email
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_CRED
)
user
=
Users
.
get_user_by_email
(
user_data
.
get
(
"email"
,
""
).
lower
(),
True
)
if
user
:
if
user
:
# Update the user with the new oauth sub
# Update the user with the new oauth sub
Users
.
update_user_oauth_sub_by_id
(
user
.
id
,
provider_sub
)
Users
.
update_user_oauth_sub_by_id
(
user
.
id
,
provider_sub
)
...
@@ -1904,6 +1911,11 @@ async def oauth_callback(provider: str, request: Request, response: Response):
...
@@ -1904,6 +1911,11 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if
not
user
:
if
not
user
:
# If the user does not exist, check if signups are enabled
# If the user does not exist, check if signups are enabled
if
ENABLE_OAUTH_SIGNUP
.
value
:
if
ENABLE_OAUTH_SIGNUP
.
value
:
# Check if an existing user with the same email already exists
existing_user
=
Users
.
get_user_by_email
(
user_data
.
get
(
"email"
,
""
).
lower
())
if
existing_user
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
picture_url
=
user_data
.
get
(
"picture"
,
""
)
picture_url
=
user_data
.
get
(
"picture"
,
""
)
if
picture_url
:
if
picture_url
:
# Download the profile image into a base64 string
# Download the profile image into a base64 string
...
@@ -1920,12 +1932,12 @@ async def oauth_callback(provider: str, request: Request, response: Response):
...
@@ -1920,12 +1932,12 @@ async def oauth_callback(provider: str, request: Request, response: Response):
guessed_mime_type
=
"image/jpeg"
guessed_mime_type
=
"image/jpeg"
picture_url
=
f
"data:
{
guessed_mime_type
}
;base64,
{
base64_encoded_picture
}
"
picture_url
=
f
"data:
{
guessed_mime_type
}
;base64,
{
base64_encoded_picture
}
"
except
Exception
as
e
:
except
Exception
as
e
:
log
.
error
(
f
"
Profile image download error
:
{
e
}
"
)
log
.
error
(
f
"
Error downloading profile image '
{
picture_url
}
'
:
{
e
}
"
)
picture_url
=
""
picture_url
=
""
if
not
picture_url
:
if
not
picture_url
:
picture_url
=
"/user.png"
picture_url
=
"/user.png"
user
=
Auths
.
insert_new_auth
(
user
=
Auths
.
insert_new_auth
(
email
=
user_data
.
get
(
"email"
,
""
).
lower
()
,
email
=
email
,
password
=
get_password_hash
(
password
=
get_password_hash
(
str
(
uuid
.
uuid4
())
str
(
uuid
.
uuid4
())
),
# Random password, not used
),
# Random password, not used
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment