auths.py 8.24 KB
Newer Older
1
from fastapi import Response, Request
2
3
4
5
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union

6
from fastapi import APIRouter, status
7
8
9
from pydantic import BaseModel
import time
import uuid
Timothy J. Baek's avatar
Timothy J. Baek committed
10
import re
11
12
13
14

from apps.web.models.auths import (
    SigninForm,
    SignupForm,
15
    UpdateProfileForm,
16
    UpdatePasswordForm,
17
18
19
20
21
22
    UserResponse,
    SigninResponse,
    Auths,
)
from apps.web.models.users import Users

Timothy J. Baek's avatar
Timothy J. Baek committed
23
24
25
26
27
28
from utils.utils import (
    get_password_hash,
    get_current_user,
    get_admin_user,
    create_token,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
29
from utils.misc import parse_duration, validate_email_format
Timothy J. Baek's avatar
Timothy J. Baek committed
30
31
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
32
from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
33

Timothy J. Baek's avatar
Timothy J. Baek committed
34
35
# Router collecting the authentication / account endpoints defined in this module.
router = APIRouter()

36
37
38
39
40
############################
# GetSessionUser
############################


41
@router.get("/", response_model=UserResponse)
async def get_session_user(user=Depends(get_current_user)):
    """Return the profile of the currently authenticated user.

    The user object comes from the ``get_current_user`` dependency; only the
    identity/profile attributes listed below are copied into the response.
    """
    exposed = ("id", "email", "name", "role", "profile_image_url")
    return {attr: getattr(user, attr) for attr in exposed}
50
51


52
############################
53
# Update Profile
54
55
56
57
############################


@router.post("/update/profile", response_model=UserResponse)
async def update_profile(
    form_data: UpdateProfileForm, session_user=Depends(get_current_user)
):
    """Update the signed-in user's display name and profile image URL.

    Returns:
        The updated user as returned by ``Users.update_user_by_id``.

    Raises:
        HTTPException(400): ``INVALID_CRED`` when there is no session user, or
            ``DEFAULT()`` when the update call returns a falsy value.
    """
    if not session_user:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    changes = {
        "profile_image_url": form_data.profile_image_url,
        "name": form_data.name,
    }
    updated = Users.update_user_by_id(session_user.id, changes)
    if not updated:
        raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
    return updated


74
75
76
77
78
############################
# Update Password
############################


79
@router.post("/update/password", response_model=bool)
async def update_password(
    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
):
    """Change the signed-in user's password after re-verifying the current one.

    Refused outright when ``WEBUI_AUTH_TRUSTED_EMAIL_HEADER`` is configured:
    in that mode sign-in goes through the header and the stored password is
    only a random placeholder (see ``signup``).

    Returns:
        Whatever ``Auths.update_user_password_by_id`` returns (declared ``bool``).

    Raises:
        HTTPException(400): ``ACTION_PROHIBITED`` in trusted-header mode,
            ``INVALID_CRED`` with no session user, ``INVALID_PASSWORD`` when the
            current password does not authenticate.
    """
    if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
        raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
    if not session_user:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    # Re-authenticate with the current password before accepting a new one.
    verified = Auths.authenticate_user(session_user.email, form_data.password)
    if not verified:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)

    new_hash = get_password_hash(form_data.new_password)
    return Auths.update_user_password_by_id(verified.id, new_hash)


97
98
99
100
101
102
############################
# SignIn
############################


@router.post("/signin", response_model=SigninResponse)
async def signin(request: Request, form_data: SigninForm):
    """Authenticate a user and issue a bearer token.

    With ``WEBUI_AUTH_TRUSTED_EMAIL_HEADER`` configured, the identity is the
    lower-cased value of that request header and the submitted password is
    ignored. Otherwise the lower-cased form e-mail and the password are checked.

    Raises:
        HTTPException(400): ``INVALID_TRUSTED_HEADER`` when the configured
            header is absent, ``INVALID_CRED`` when authentication fails.
    """
    header_name = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
    if not header_name:
        user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
    elif header_name in request.headers:
        trusted_email = request.headers[header_name].lower()
        user = Auths.authenticate_user_by_trusted_header(trusted_email)
    else:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)

    if not user:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    # Token lifetime is the admin-configurable app setting (e.g. "1h", "-1").
    lifetime = parse_duration(request.app.state.JWT_EXPIRES_IN)
    token = create_token(data={"id": user.id}, expires_delta=lifetime)

    return {
        "token": token,
        "token_type": "Bearer",
        "id": user.id,
        "email": user.email,
        "name": user.name,
        "role": user.role,
        "profile_image_url": user.profile_image_url,
    }
131
132
133
134
135
136
137
138


############################
# SignUp
############################


@router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm):
    """Register a new account and return a bearer token for it.

    E-mail addresses are lower-cased before validation, lookup and storage.
    When no users exist yet the new account gets the ``admin`` role; otherwise
    it gets ``request.app.state.DEFAULT_USER_ROLE``.

    In trusted-header mode (``WEBUI_AUTH_TRUSTED_EMAIL_HEADER`` set) the header
    must be present and must match the submitted e-mail (case-insensitively),
    and the submitted password is replaced with a random placeholder.

    If ``request.app.state.WEBHOOK_URL`` is set, a signup notification is
    posted to it via ``post_webhook``.

    Raises:
        HTTPException(403): sign-up is disabled.
        HTTPException(400): invalid e-mail format, e-mail already taken,
            trusted header missing, or header / form e-mail mismatch.
        HTTPException(500): ``CREATE_USER_ERROR`` when the insert returns no
            user, or ``DEFAULT(err)`` for any other unexpected failure.
    """
    if not request.app.state.ENABLE_SIGNUP:
        raise HTTPException(
            status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
        )

    # Normalise once; every check and the insert below use this form.
    email = form_data.email.lower()

    if not validate_email_format(email):
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
        )

    if Users.get_user_by_email(email):
        raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

    if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
        if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
        trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
        # Both sides must be lower-cased: comparing against the raw form value
        # would reject a legitimate user who typed their address in mixed case.
        if trusted_email != email:
            raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_MISMATCH)
        # TODO: placeholder password. Sign-in in this mode goes through the
        # header, so the stored password is never used; assign a random one.
        form_data.password = str(uuid.uuid4())

    try:
        role = (
            "admin"
            if Users.get_num_users() == 0
            else request.app.state.DEFAULT_USER_ROLE
        )
        hashed = get_password_hash(form_data.password)
        user = Auths.insert_new_auth(email, hashed, form_data.name, role)

        if not user:
            raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)

        token = create_token(
            data={"id": user.id},
            expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
        )

        if request.app.state.WEBHOOK_URL:
            # NOTE(review): a failure here is reported as a 500 even though the
            # user has already been created; confirm whether that is intended.
            message = WEBHOOK_MESSAGES.USER_SIGNUP(user.name)
            post_webhook(
                request.app.state.WEBHOOK_URL,
                message,
                {
                    "action": "signup",
                    "message": message,
                    "user": user.model_dump_json(exclude_none=True),
                },
            )

        return {
            "token": token,
            "token_type": "Bearer",
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
            "profile_image_url": user.profile_image_url,
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. CREATE_USER_ERROR) must reach the client
        # unchanged instead of being rewrapped as a generic DEFAULT error.
        raise
    except Exception as err:
        raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) from err

207
208
209
210
211
212
213

############################
# ToggleSignUp
############################


@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
    """Report whether self sign-up is currently enabled.

    Guarded by the ``get_admin_user`` dependency.
    """
    state = request.app.state
    return state.ENABLE_SIGNUP
216
217
218


@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
    """Flip ``ENABLE_SIGNUP`` on the application state and return the new value.

    Guarded by the ``get_admin_user`` dependency.

    NOTE(review): this changes server state on a GET request, which browsers
    and proxies may prefetch or replay; consider a POST variant.
    """
    state = request.app.state
    state.ENABLE_SIGNUP = not state.ENABLE_SIGNUP
    return state.ENABLE_SIGNUP
Timothy J. Baek's avatar
Timothy J. Baek committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244


############################
# Default User Role
############################


@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
    """Return the role that ``signup`` assigns to new (non-first) accounts.

    Guarded by the ``get_admin_user`` dependency.
    """
    state = request.app.state
    return state.DEFAULT_USER_ROLE


class UpdateRoleForm(BaseModel):
    # Request body for POST /signup/user/role.
    # role: requested default role for new sign-ups; update_default_user_role
    # only applies "pending", "user" or "admin" and ignores anything else.
    role: str


@router.post("/signup/user/role")
async def update_default_user_role(
    request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
):
    """Set the role given to newly signed-up users and return the active value.

    Only ``pending``, ``user`` and ``admin`` are accepted. Any other value is
    ignored silently, and the unchanged current default is returned.
    """
    allowed_roles = {"pending", "user", "admin"}
    state = request.app.state
    if form_data.role in allowed_roles:
        state.DEFAULT_USER_ROLE = form_data.role
    return state.DEFAULT_USER_ROLE
Timothy J. Baek's avatar
Timothy J. Baek committed
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274


############################
# JWT Expiration
############################


@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
    """Return the configured JWT lifetime string (e.g. ``"-1"`` or ``"1h"``).

    This raw value is what sign-in / sign-up pass to ``parse_duration``.
    """
    configured = request.app.state.JWT_EXPIRES_IN
    return configured


class UpdateJWTExpiresDurationForm(BaseModel):
    # Request body for POST /token/expires/update.
    # duration: JWT lifetime string such as "-1", "0", "30m" or "1.5h"; it is
    # validated by update_token_expires_duration, not by this model.
    duration: str


@router.post("/token/expires/update")
async def update_token_expires_duration(
    request: Request,
    form_data: UpdateJWTExpiresDurationForm,
    user=Depends(get_admin_user),
):
    """Set the lifetime of newly issued JWTs and return the active setting.

    ``form_data.duration`` must be exactly ``-1``, ``0``, or a (possibly
    fractional) number followed by one of ``ms``, ``s``, ``m``, ``h``, ``d``,
    ``w`` (e.g. ``"30m"``, ``"1.5h"``). Non-matching values are ignored, and
    the current setting is returned unchanged.
    """
    pattern = r"(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))"
    # fullmatch rather than match with "^...$": "$" also matches just before a
    # trailing newline, which let values like "1h\n" be stored verbatim.
    # NOTE(review): negative amounts such as "-5m" are still accepted, as
    # before; confirm how parse_duration treats them.
    if re.fullmatch(pattern, form_data.duration):
        request.app.state.JWT_EXPIRES_IN = form_data.duration
    return request.app.state.JWT_EXPIRES_IN